Luogu

分析

首先发现,两个不同的连续区间对应两个不同的删除方案。

那么就可以统计合法区间的数量。

枚举右端点 $r$ ,考虑左端点 $l$ 怎么样才合法。

设颜色 $k$ 最右边的出现位置为 $max[k]$ ,最左边的出现位置为 $min[k]$ 。

显然所有 $max[k]$大于 $r$ 的颜色都会被删去,所以有在 $(l,r)$ 中不存在一定会被删去的颜色。

也就是要找到一个 $max[c[j]]\geq r$ 且 $r-j$ 最小的 $j$ ,然后左端点就可以取 $(j,r]$ 。

另外还有 $max[j]<i$ 的情况,此时 $l$ 一定不在 $\big(min[j],max[j]\big]$ 中。


于是最后的算法如下:

从左往右枚举 $r$ 。

如果 $max[c[r]]=r$ ,那么将 $\big(min[j],max[j]\big]$ 赋值成 $1$ ,表示不能选这些点。这个可以用线段树做到。

然后求出一个 $max[c[l]]\geq r$ 且 $r-l$ 最小的 $l$ 。可以将之前的所有点加到一个栈中,然后一直 pop 直到找到。

然后我们把这个点作为 $l$ ,那么当前的答案就是 $r-j-(j,r]\text{中}1\text{的个数}$ 。

将每个 $r$ 的答案累加就是总的答案。


一个小细节:请使用 for / fill 而不是 memset

代码

//It is made by M_sea
#include <algorithm>
#include <iostream>
#include <cstdlib>
#include <cstring>
#include <cstdio>
#include <cmath>
#define re register
typedef long long ll;
using namespace std;

inline int read() {
    int X=0,w=1; char c=getchar();
    while (c<'0'||c>'9') { if (c=='-') w=-1; c=getchar(); }
    while (c>='0'&&c<='9') X=X*10+c-'0',c=getchar();
    return X*w;
}

const int N=300000+10;

int n,c[N];
int mn[N],mx[N];
struct node { int c,p; } sta[N];
int top=0;

struct segment_tree {
    int sumv[N<<2],setv[N<<2];
#define ls (o<<1)
#define rs (o<<1|1)
    inline void pushup(int o) { sumv[o]=sumv[ls]+sumv[rs]; }
    inline void pushdown(int o,int l,int r) {
        if (setv[o]!=-1) {
            int mid=(l+r)>>1;
            setv[ls]=setv[rs]=setv[o];
            sumv[ls]=setv[o]*(mid-l+1),sumv[rs]=setv[o]*(r-mid);
            setv[o]=-1;
        }
    }

    inline void build(int o,int l,int r) {
        sumv[o]=0,setv[o]=-1;
        if (l==r) return;
        int mid=(l+r)>>1;
        build(ls,l,mid),build(rs,mid+1,r);
    }
    inline void assign(int o,int l,int r,int ql,int qr,int v) {
        if (ql<=l&&r<=qr) { setv[o]=v,sumv[o]=v*(r-l+1); return; }
        int mid=(l+r)>>1; pushdown(o,l,r);
        if (ql<=mid) assign(ls,l,mid,ql,qr,v);
        if (qr>mid) assign(rs,mid+1,r,ql,qr,v);
        pushup(o);
    }
    inline int query(int o,int l,int r,int ql,int qr) {
        if (ql<=l&&r<=qr) return sumv[o];
        int mid=(l+r)>>1,res=0; pushdown(o,l,r);
        if (ql<=mid) res+=query(ls,l,mid,ql,qr);
        if (qr>mid) res+=query(rs,mid+1,r,ql,qr);
        pushup(o); return res;
    }

#undef ls
#undef rs
} T;

int main() {
    int cases=read();
    while (cases--) {
        n=read(),top=0; ll ans=0; T.build(1,0,n);
        fill(mn,mn+n+1,1e9),fill(mx,mx+n+1,0);
        for (re int i=1;i<=n;++i) c[i]=read();
        for (re int i=1;i<=n;++i) mn[c[i]]=min(mn[c[i]],i); 
        for (re int i=1;i<=n;++i) mx[c[i]]=max(mx[c[i]],i);
        for (re int r=1;r<=n;++r) {
            if (r==mx[c[r]]&&mn[c[r]]<mx[c[r]])
                T.assign(1,0,n,mn[c[r]]+1,mx[c[r]],1);
            else sta[++top]=(node){c[r],r};
            while (top&&mx[sta[top].c]<=r) --top;
            int l=top?sta[top].p:0;
            if (l<r) ans+=r-l-T.query(1,0,n,l+1,r);
        }
        printf("%lld\n",ans);
    }
    return 0;
}
最后修改:2021 年 03 月 23 日 05 : 55 PM