分析
显然每个数作为众数的情况是相互独立的,因此我们考虑枚举众数 $k$,然后计算有多少个区间中 $k$ 的出现次数超过一半。
不妨把等于 $k$ 的数看做 $1$,不等于 $k$ 的数看做 $-1$,那么我们相当于要求有多少个区间的和大于 $0$。这个东西等价于前缀和的逆序对数。
但是直接做的话当 $A_i$ 的数量比较多时显然会 TLE,所以我们需要一些优化。
考虑到 $A_i$ 的数量比较多时,每次数列中的 $1$ 会比较少而 $-1$ 会比较多,而连续的一段 $-1$ 对答案的贡献显然是 $0$。因此我们可以考虑优化一下求顺序对的过程,每次将一段区间加入。
假设 $[x,y]$ 为连续的一段 $-1$,前缀和数组中 $[x,y]$ 的最小值和最大值分别为 $l,r$,那么这段区间的贡献为
$$
\sum_{i=l}^r\sum_{j=-\infty}^{i-1}cnt_j
$$
变形得到
$$
(r-l+1)\sum_{i=\infty}^{l-1}cnt_i+r\times\sum_{i=l}^{r-1}cnt_i-\sum_{i=l}^{r-1}i\times cnt_i
$$
开一棵线段树维护一下 $\sum cnt_i$ 和 $\sum i\times cnt_i$ 即可。
对于每一个枚举的众数 $k$,时间复杂度为 $\mathcal{O}(cnt_k\log n)$,因此总时间复杂度为 $\mathcal{O}(n\log n)$。但是被两只 log 的做法爆踩 /kk
代码
// ===================================
// author: M_sea
// website: http://m-sea-blog.com/
// ===================================
#include <algorithm>
#include <iostream>
#include <cstdlib>
#include <cstring>
#include <cstdio>
#include <vector>
#include <cmath>
#define re register
using namespace std;
typedef long long ll;
inline int read() {
int X=0,w=1; char c=getchar();
while (c<'0'||c>'9') { if (c=='-') w=-1; c=getchar(); }
while (c>='0'&&c<='9') X=X*10+c-'0',c=getchar();
return X*w;
}
const int N=500000+10;
int n,m;
int a[N]; vector<int> v[N];
#define ls (o<<1)
#define rs (o<<1|1)
ll sumv[N<<2][2],addv[N<<2];
inline void pushup(int o) {
sumv[o][0]=sumv[ls][0]+sumv[rs][0];
sumv[o][1]=sumv[ls][1]+sumv[rs][1];
}
inline void pushdown(int o,int l,int r) {
if (addv[o]) { int mid=(l+r)>>1;
sumv[ls][0]+=addv[o]*(mid-l+1);
sumv[ls][1]+=addv[o]*(mid+l-m)*(mid-l+1)/2;
addv[ls]+=addv[o];
sumv[rs][0]+=addv[o]*(r-mid);
sumv[rs][1]+=addv[o]*(mid+r+1-m)*(r-mid)/2;
addv[rs]+=addv[o];
addv[o]=0;
}
}
inline void modify(int o,int l,int r,int ql,int qr,int w) {
if (ql>qr) return;
if (ql<=l&&r<=qr) {
addv[o]+=w;
sumv[o][0]+=(r-l+1)*w;
sumv[o][1]+=1ll*w*(l+r-m)*(r-l+1)/2;
return;
}
int mid=(l+r)>>1; pushdown(o,l,r);
if (ql<=mid) modify(ls,l,mid,ql,qr,w);
if (qr>mid) modify(rs,mid+1,r,ql,qr,w);
pushup(o);
}
inline ll query(int o,int l,int r,int ql,int qr,int op) {
if (ql>qr) return 0;
if (ql<=l&&r<=qr) return sumv[o][op];
int mid=(l+r)>>1; ll res=0; pushdown(o,l,r);
if (ql<=mid) res+=query(ls,l,mid,ql,qr,op);
if (qr>mid) res+=query(rs,mid+1,r,ql,qr,op);
pushup(o); return res;
}
#undef ls
#undef rs
int main() {
n=read(),read(),m=n<<1; ll ans=0;
for (re int i=1;i<=n;++i) a[i]=read(),v[a[i]].push_back(i);
for (re int i=0;i<n;++i) v[i].push_back(n+1);
for (re int i=0;i<n;++i) {
if (v[i].size()==1) continue;
for (re int j=0,l,r=0;j<v[i].size();++j,r=l+1) {
l=2*j+1-v[i][j];
ans+=(r-l+1)*query(1,1,m,1,l-1+n,0)
+r*query(1,1,m,l+n,r-1+n,0)
-query(1,1,m,l+n,r-1+n,1);
modify(1,1,m,l+n,r+n,1);
}
for (re int j=0,l,r=0;j<v[i].size();++j,r=l+1) {
l=2*j+1-v[i][j];
modify(1,1,m,l+n,r+n,-1);
}
}
printf("%lld\n",ans);
return 0;
}