Luogu

LOJ

分析

显然每个数作为众数的情况是相互独立的,因此我们考虑枚举众数 $k$,然后计算有多少个区间中 $k$ 的出现次数超过一半。

不妨把等于 $k$ 的数看做 $1$,不等于 $k$ 的数看做 $-1$,那么我们相当于要求有多少个区间的和大于 $0$。这个东西等价于前缀和的逆序对数。

但是直接做的话当 $A_i$ 的数量比较多时显然会 TLE,所以我们需要一些优化。

考虑到 $A_i$ 的数量比较多时,每次数列中的 $1$ 会比较少而 $-1$ 会比较多,而连续的一段 $-1$ 对答案的贡献显然是 $0$。因此我们可以考虑优化一下求顺序对的过程,每次将一段区间加入。

假设 $[x,y]$ 为连续的一段 $-1$,前缀和数组中 $[x,y]$ 的最小值和最大值分别为 $l,r$,那么这段区间的贡献为
$$
\sum_{i=l}^r\sum_{j=-\infty}^{i-1}cnt_j
$$
变形得到
$$
(r-l+1)\sum_{i=\infty}^{l-1}cnt_i+r\times\sum_{i=l}^{r-1}cnt_i-\sum_{i=l}^{r-1}i\times cnt_i
$$
开一棵线段树维护一下 $\sum cnt_i$ 和 $\sum i\times cnt_i$ 即可。

对于每一个枚举的众数 $k$,时间复杂度为 $\mathcal{O}(cnt_k\log n)$,因此总时间复杂度为 $\mathcal{O}(n\log n)$。但是被两只 log 的做法爆踩 /kk

代码

// ===================================
//   author: M_sea
//   website: http://m-sea-blog.com/
// ===================================
#include <algorithm>
#include <iostream>
#include <cstdlib>
#include <cstring>
#include <cstdio>
#include <vector>
#include <cmath>
#define re register
using namespace std;
typedef long long ll;

inline int read() {
    int X=0,w=1; char c=getchar();
    while (c<'0'||c>'9') { if (c=='-') w=-1; c=getchar(); }
    while (c>='0'&&c<='9') X=X*10+c-'0',c=getchar();
    return X*w;
}

const int N=500000+10;

int n,m;
int a[N]; vector<int> v[N];

#define ls (o<<1)
#define rs (o<<1|1)
ll sumv[N<<2][2],addv[N<<2];
inline void pushup(int o) {
    sumv[o][0]=sumv[ls][0]+sumv[rs][0];
    sumv[o][1]=sumv[ls][1]+sumv[rs][1];
}
inline void pushdown(int o,int l,int r) {
    if (addv[o]) { int mid=(l+r)>>1;
        sumv[ls][0]+=addv[o]*(mid-l+1);
        sumv[ls][1]+=addv[o]*(mid+l-m)*(mid-l+1)/2;
        addv[ls]+=addv[o];
        sumv[rs][0]+=addv[o]*(r-mid);
        sumv[rs][1]+=addv[o]*(mid+r+1-m)*(r-mid)/2;
        addv[rs]+=addv[o];
        addv[o]=0;
    }
}
inline void modify(int o,int l,int r,int ql,int qr,int w) {
    if (ql>qr) return;
    if (ql<=l&&r<=qr) {
        addv[o]+=w;
        sumv[o][0]+=(r-l+1)*w;
        sumv[o][1]+=1ll*w*(l+r-m)*(r-l+1)/2;
        return;
    }
    int mid=(l+r)>>1; pushdown(o,l,r);
    if (ql<=mid) modify(ls,l,mid,ql,qr,w);
    if (qr>mid) modify(rs,mid+1,r,ql,qr,w);
    pushup(o);
}
inline ll query(int o,int l,int r,int ql,int qr,int op) {
    if (ql>qr) return 0;
    if (ql<=l&&r<=qr) return sumv[o][op];
    int mid=(l+r)>>1; ll res=0; pushdown(o,l,r);
    if (ql<=mid) res+=query(ls,l,mid,ql,qr,op);
    if (qr>mid) res+=query(rs,mid+1,r,ql,qr,op);
    pushup(o); return res;
}
#undef ls
#undef rs

int main() {
    n=read(),read(),m=n<<1; ll ans=0;
    for (re int i=1;i<=n;++i) a[i]=read(),v[a[i]].push_back(i);
    for (re int i=0;i<n;++i) v[i].push_back(n+1);
    for (re int i=0;i<n;++i) {
        if (v[i].size()==1) continue;
        for (re int j=0,l,r=0;j<v[i].size();++j,r=l+1) {
            l=2*j+1-v[i][j];
            ans+=(r-l+1)*query(1,1,m,1,l-1+n,0)
                +r*query(1,1,m,l+n,r-1+n,0)
                -query(1,1,m,l+n,r-1+n,1);
            modify(1,1,m,l+n,r+n,1);
        }
        for (re int j=0,l,r=0;j<v[i].size();++j,r=l+1) {
            l=2*j+1-v[i][j];
            modify(1,1,m,l+n,r+n,-1);
        }
    }
    printf("%lld\n",ans);
    return 0;
}
最后修改:2021 年 03 月 24 日 02 : 57 PM