分析
考虑计算每种颜色的贡献,即计算每种颜色包含在多少条路径中。
这个东西似乎还是不好算,可以考虑计算每种颜色不包含在多少条路径中。
我们把这种颜色染成白色,其它颜色染成黑色,那么只有黑色连通块内部的点对才满足条件。
于是我们需要支持修改颜色、求黑色连通块大小平方和,使用 QTREE6 一题的方法把颜色放到父边上,LCT 维护子树信息即可。具体细节可以看一看官方题解中的图示。
具体实现时,一开始维护一颗全黑的树,然后把每种颜色单独拿出来做一遍,然后再改回去,就可以算出每种颜色的贡献了。每种颜色每次算的是对一段区间的贡献,所以需要差分一下。
代码
// ====================================
// author: M_sea
// website: https://m-sea-blog.com/
// ====================================
#include <bits/stdc++.h>
#define file(x) freopen(#x".in","r",stdin); freopen(#x".out","w",stdout)
#define debug(...) fprintf(stderr,__VA_ARGS__)
#define sqr(x) (1ll*x*x)
using namespace std;
typedef long long ll;
int read() {
int X=0,w=1; char c=getchar();
while (c<'0'||c>'9') { if (c=='-') w=-1; c=getchar(); }
while (c>='0'&&c<='9') X=X*10+c-'0',c=getchar();
return X*w;
}
const int N=400000+10;
int n,m,c[N];
vector<int> E[N];
vector<pair<int,int>> q[N];
ll ans[N];
int fa[N];
void dfs(int u,int f) {
fa[u]=f;
for (int v:E[u]) if (v!=f) dfs(v,u);
}
namespace L { // Link-Cut Tree
#define ls(o) ch[o][0]
#define rs(o) ch[o][1]
int ch[N][2],fa[N],sz[N],vsz[N]; ll vsz2[N],sum=0;
bool nroot(int x) { return ls(fa[x])==x||rs(fa[x])==x; }
void pushup(int o) { sz[o]=sz[ls(o)]+sz[rs(o)]+vsz[o]+1; }
void rotate(int x) {
int y=fa[x],z=fa[y],k=x==rs(y),w=ch[x][!k];
if (nroot(y)) ch[z][y==rs(z)]=x;
ch[x][!k]=y,ch[y][k]=w;
if (w) fa[w]=y; fa[y]=x,fa[x]=z;
pushup(y);
}
void splay(int x) {
while (nroot(x)) {
int y=fa[x],z=fa[y];
if (nroot(y)) rotate((x==rs(y))^(y==rs(z))?x:y);
rotate(x);
}
pushup(x);
}
void access(int x) {
for (int y=0;x;x=fa[y=x]) {
splay(x);
vsz[x]+=sz[rs(x)],vsz2[x]+=sqr(sz[rs(x)]);
rs(x)=y;
vsz[x]-=sz[rs(x)],vsz2[x]-=sqr(sz[rs(x)]);
}
}
int findroot(int x) {
access(x),splay(x);
while (ls(x)) x=ls(x);
splay(x); return x;
}
void link(int x,int y) {
splay(x); sum-=vsz2[x]+sqr(sz[rs(x)]);
int z=findroot(y); splay(z),sum-=sqr(sz[rs(z)]);
splay(y),fa[x]=y,vsz[y]+=sz[x],vsz2[y]+=sqr(sz[x]),pushup(y);
access(x),splay(z),sum+=sqr(sz[rs(z)]);
}
void cut(int x,int y) {
access(x); sum+=vsz2[x];
int z=findroot(y); access(x),splay(z); sum-=sqr(sz[rs(z)]);
splay(x),ls(x)=fa[ls(x)]=0,pushup(x);
splay(z),sum+=sqr(sz[rs(z)]);
}
#undef ls
#undef rs
}
int o[N];
void calc(int c) {
ll lst=0;
for (auto i:q[c]) {
int u=i.first,t=i.second;
if (o[u]) L::link(u,fa[u]);
else L::cut(u,fa[u]);
o[u]^=1;
ans[t]+=sqr(n)-L::sum-lst;
lst=sqr(n)-L::sum;
}
for (auto i:q[c]) {
int u=i.first;
if (o[u]) L::link(u,fa[u]),o[u]=0;
}
}
int main() {
n=read(),m=read();
for (int i=1;i<=n;++i) c[i]=read();
for (int i=1;i<=n;++i) q[c[i]].emplace_back(i,0);
for (int i=1;i<n;++i) {
int u=read(),v=read();
E[u].emplace_back(v),E[v].emplace_back(u);
}
for (int i=1;i<=m;++i) {
int u=read(),w=read();
q[c[u]].emplace_back(u,i);
c[u]=w;
q[c[u]].emplace_back(u,i);
}
dfs(1,n+1);
for (int i=1;i<=n+1;++i) L::sz[i]=1;
for (int i=1;i<=n;++i) L::link(i,fa[i]);
for (int i=1;i<=n;++i) calc(i);
for (int i=1;i<=m;++i) ans[i]+=ans[i-1];
for (int i=0;i<=m;++i) printf("%lld\n",ans[i]);
return 0;
}