分析
显然一个询问 $(l_1,r_1,l_2,r_2,l_3,r_3)$ 的答案是 $(r_1-l_1+1)+(r_2-l_2+1)+(r_3-l_3+1)-3\times size$ ,这里的 $size$ 表示三个区间内出现了多少个公共的颜色。
那么只需要考虑如何求 $size$。
首先对所有数离散化,令它离散化后的值为小于等于它的数的个数。
然后当加入一个值 $p$ 的时候,把 bitset 中的 $p-cnt_p$ 置为 $1$ 。这样子相邻两个值 $a$ 和 $b$ 的差就是 $a$ 的出现次数。这样子就可以通过取交来得到 $size$ 了。
我们只需要使用莫队提取出每个区间对应的 bitset 即可。
然而开 $10^5$ 个长为 $10^5$ 的 bitset 会 MLE,所以需要把询问拆成 $3$ 组来做。
代码
// ===================================
// author: M_sea
// website: http://m-sea-blog.com/
// ===================================
#include <algorithm>
#include <iostream>
#include <cstdlib>
#include <cstring>
#include <cstdio>
#include <bitset>
#include <cmath>
#define re register
using namespace std;
inline int read() {
int X=0,w=1; char c=getchar();
while (c<'0'||c>'9') { if (c=='-') w=-1; c=getchar(); }
while (c>='0'&&c<='9') X=X*10+c-'0',c=getchar();
return X*w;
}
const int N=100000+10;
const int M=33335;
int n,m,block,tot;
int bl[N],ans[N],cnt[N],vis[N];
int a[N],b[N],L[4][N],R[4][N];
bitset<N> bit[M+10],now;
struct query { int l,r,id; } q[N];
bool operator <(query a,query b) {
if (bl[a.l]!=bl[b.l]) return a.l<b.l;
else return (bl[a.l]&1)?(a.r<b.r):(a.r>b.r);
}
inline void add(int x) { now[x+cnt[x]]=1,++cnt[x]; }
inline void del(int x) { --cnt[x],now[x+cnt[x]]=0; }
inline void solve(int ql,int qr) {
memset(cnt,0,sizeof(cnt)),memset(vis,0,sizeof(vis)); tot=0;
for (re int i=ql;i<=qr;++i)
for (re int j=0;j<3;++j) {
q[++tot]=(query){L[j][i],R[j][i],i};
ans[i]+=R[j][i]-L[j][i]+1;
}
sort(q+1,q+tot+1);
int l=1,r=0; now.reset();
for (re int i=1;i<=tot;++i) {
while (r<q[i].r) add(a[++r]);
while (l>q[i].l) add(a[--l]);
while (r>q[i].r) del(a[r--]);
while (l<q[i].l) del(a[l++]);
if (!vis[q[i].id-ql+1]) vis[q[i].id-ql+1]=1,bit[q[i].id-ql+1]=now;
else bit[q[i].id-ql+1]&=now;
}
for (re int i=ql;i<=qr;++i) ans[i]-=3*bit[i-ql+1].count();
}
int main() {
n=read(),m=read(),block=sqrt(n);
for (re int i=1;i<=n;++i) bl[i]=(i-1)/block+1;
for (re int i=1;i<=n;++i) a[i]=b[i]=read();
sort(b+1,b+n+1);
for (re int i=1;i<=n;++i) a[i]=lower_bound(b+1,b+n+1,a[i])-b;
for (re int i=1;i<=m;++i)
for (re int j=0;j<3;++j)
L[j][i]=read(),R[j][i]=read();
for (re int i=1;i<=m;i+=M) solve(i,min(m,i+M-1));
for (re int i=1;i<=m;++i) printf("%d\n",ans[i]);
return 0;
}
1 条评论
orz M_sea Ynoi大聚聚