AtCoder

分析

我们先把所有只有只会从一个出口出去的机器人删掉,这些机器人不影响答案。

对于剩下的机器人,记它到左边第一个出口的距离为 $a$ ,到右边第一个出口的距离为 $b$ 。那么每个机器人可以看成一个点 $(a,b)$。

设 $x$ 为所有机器人往左移动的最远点到初始位置的距离,$y$ 为所有机器人往右移动的最远点到初始位置的距离。那么每次相当于把 $(x,y)$ 变成 $(x+1,y)$ 或者 $(x,y+1)$ 。

可以发现当 $x=a_i$ 时,机器人 $i$ 会走左边第一个出口出去;当 $y=b_i$ 时,机器人 $i$ 会走右边第一个出口出去。

于是可以看成从原点开始走,每次可以往上或者往右走,那么最后走的折线的上下两边分别代表走左边的和走右边的集合。

而两个方案不同当且仅当某一边的点集不同,于是我们可以强制令折线为所有下方点集的上轮廓线。

设 $f_i$ 表示折线最后经过的点是 $i$ 的方案数,则有
$$
f[i]=1+\sum_{x_j<x_i,y_j<y_i}f[j]
$$

树状数组优化即可。

代码

// ===================================
//   author: M_sea
//   website: http://m-sea-blog.com/
// ===================================
#include <algorithm>
#include <iostream>
#include <cstdlib>
#include <cstring>
#include <cstdio>
#include <cmath>
#define re register
using namespace std;

inline int read() {
    int X=0,w=1; char c=getchar();
    while (c<'0'||c>'9') { if (c=='-') w=-1; c=getchar(); }
    while (c>='0'&&c<='9') X=X*10+c-'0',c=getchar();
    return X*w;
}

const int N=100000+10;
const int mod=1e9+7;

int n,m,cnt,top;
int a[N],b[N],S[N];
pair<int,int> p[N];
int f[N];

inline int cmp(pair<int,int> a,pair<int,int> b) {
    if (a.first!=b.first) return a.first<b.first;
    else return a.second>b.second;
}

int c[N];
inline int lowbit(int x) { return x&-x; }
inline int query(int x) { int res=0;
    for (;x;x-=lowbit(x)) res=(res+c[x])%mod;
    return res;
}
inline void add(int x,int y) {
    for (;x<=top;x+=lowbit(x)) c[x]=(c[x]+y)%mod;
}

int main() {
    n=read(),m=read();
    for (re int i=1;i<=n;++i) a[i]=read();
    for (re int i=1;i<=m;++i) b[i]=read();
    for (re int i=1;i<=n;++i) {
        if (a[i]<=b[1]||a[i]>=b[m]) continue;
        int x=lower_bound(b+1,b+m+1,a[i])-b;
        if (b[x]==a[i]) continue;
        p[++cnt]=make_pair(a[i]-b[x-1],b[x]-a[i]);
        S[++top]=b[x]-a[i];
    }
    sort(S+1,S+top+1); top=unique(S+1,S+top+1)-S-1;
    for (re int i=1;i<=cnt;++i)
        p[i].second=lower_bound(S+1,S+top+1,p[i].second)-S;
    sort(p+1,p+cnt+1,cmp); cnt=unique(p+1,p+cnt+1)-p-1;
    f[0]=1;
    for (re int i=1;i<=cnt;++i)
        f[i]=query(p[i].second-1)+1,add(p[i].second,f[i]);
    int ans=0;
    for (re int i=0;i<=cnt;++i) ans=(ans+f[i])%mod;
    printf("%d\n",ans);
    return 0;
}
最后修改:2021 年 03 月 23 日 09 : 28 PM