分析
可以发现每种颜色一定是一条链,则整棵树相当于被剖分成了若干条链。
考虑 LCT 维护,则 1
操作相当于 access(x)
,2
操作相当于询问一条路径上的链数,3
操作相当于求子树中到根路径上链数最大的点。
考虑将 2
操作差分,则只需要用一个数据结构维护每个点到根路径的链数,要求支持区间求最大值。
考虑 access(x)
时对每个点到根路径上的链数的影响。对于当前的 $x$,它的重儿子所在子树的答案会加上 $1$(因为这条链被断开了),它的新重儿子所在子树的答案会减去 $1$(因为这条链连上了)。用 DFS 序+线段树维护即可。
剩下的问题是如何求某个点的重儿子所在子树的根节点。我们改一下 findroot
,让它跳左儿子之前不 access
和 splay
,这样子 findroot(rs(x))
找到的就是 $x$ 的重儿子所在子树的根节点了。
因为魔改了 findroot
所以复杂度不会算。
代码
// ====================================
// author: M_sea
// website: https://m-sea-blog.com/
// ====================================
#include <bits/stdc++.h>
#define file(x) freopen(x".in","r",stdin); freopen(x".out","w",stdout)
using namespace std;
typedef long long ll;
int read() {
int X=0,w=1; char c=getchar();
while (c<'0'||c>'9') { if (c=='-') w=-1; c=getchar(); }
while (c>='0'&&c<='9') X=X*10+c-'0',c=getchar();
return X*w;
}
const int N=100000+10;
int n,m;
vector<int> E[N];
int dep[N],fa[N],sz[N],hson[N],top[N];
int dfn[N],low[N],pos[N],tim=0;
void dfs1(int u,int f) {
dep[u]=dep[fa[u]=f]+1,sz[u]=1;
for (int v:E[u]) {
if (v==f) continue;
dfs1(v,u); sz[u]+=sz[v];
if (sz[v]>sz[hson[u]]) hson[u]=v;
}
}
void dfs2(int u,int anc) {
top[u]=anc,dfn[u]=++tim,pos[tim]=u;
if (hson[u]) dfs2(hson[u],anc);
for (int v:E[u])
if (v!=fa[u]&&v!=hson[u]) dfs2(v,v);
low[u]=tim;
}
int getlca(int u,int v) {
while (top[u]!=top[v]) {
if (dep[top[u]]<dep[top[v]]) swap(u,v);
u=fa[top[u]];
}
return dep[u]<dep[v]?u:v;
}
#define ls (o<<1)
#define rs (o<<1|1)
int maxv[N<<2],addv[N<<2];
void pushup(int o) { maxv[o]=max(maxv[ls],maxv[rs]); }
void pushdown(int o) {
if (addv[o]) {
maxv[ls]+=addv[o],addv[ls]+=addv[o];
maxv[rs]+=addv[o],addv[rs]+=addv[o];
addv[o]=0;
}
}
void build(int o,int l,int r) {
if (l==r) { maxv[o]=dep[pos[l]]; return; }
int mid=(l+r)>>1;
build(ls,l,mid),build(rs,mid+1,r);
pushup(o);
}
void modify(int o,int l,int r,int ql,int qr,int w) {
if (ql<=l&&r<=qr) { maxv[o]+=w,addv[o]+=w; return; }
int mid=(l+r)>>1; pushdown(o);
if (ql<=mid) modify(ls,l,mid,ql,qr,w);
if (qr>mid) modify(rs,mid+1,r,ql,qr,w);
pushup(o);
}
int query(int o,int l,int r,int ql,int qr) {
if (ql<=l&&r<=qr) return maxv[o];
int mid=(l+r)>>1,res=0; pushdown(o);
if (ql<=mid) res=max(res,query(ls,l,mid,ql,qr));
if (qr>mid) res=max(res,query(rs,mid+1,r,ql,qr));
pushup(o); return res;
}
#undef ls
#undef rs
namespace LCT {
#define ls(o) ch[o][0]
#define rs(o) ch[o][1]
int fa[N],ch[N][2];
int nroot(int x) { return ls(fa[x])==x||rs(fa[x])==x; }
void rotate(int x) {
int y=fa[x],z=fa[y],k=(x==rs(y)),w=ch[x][!k];
if (nroot(y)) ch[z][y==rs(z)]=x;
ch[x][!k]=y,ch[y][k]=w;
if (w) fa[w]=y; fa[y]=x,fa[x]=z;
}
void splay(int x) {
while (nroot(x)) {
int y=fa[x],z=fa[y];
if (nroot(y)) rotate((x==ls(y))^(y==ls(z))?x:y);
rotate(x);
}
}
int findroot(int x) {while (ls(x)) x=ls(x); return x; }
void access(int x) {
for (int y=0;x;x=fa[y=x]) {
splay(x);
if (rs(x)) {
int t=findroot(rs(x));
modify(1,1,n,dfn[t],low[t],1);
}
rs(x)=y;
if (rs(x)) {
int t=findroot(rs(x));
modify(1,1,n,dfn[t],low[t],-1);
}
}
}
#undef ls
#undef rs
} // namespace LCT
int main() {
n=read(),m=read();
for (int i=1;i<n;++i) {
int u=read(),v=read();
E[u].emplace_back(v),E[v].emplace_back(u);
}
dfs1(1,0),dfs2(1,1); build(1,1,n);
for (int i=1;i<=n;++i) LCT::fa[i]=fa[i];
while (m--) {
int op=read();
if (op==1) LCT::access(read());
if (op==2) {
int u=read(),v=read(),t=getlca(u,v);
printf("%d\n",query(1,1,n,dfn[u],dfn[u])
+query(1,1,n,dfn[v],dfn[v])
-(query(1,1,n,dfn[t],dfn[t])<<1)+1);
}
if (op==3) {
int u=read();
printf("%d\n",query(1,1,n,dfn[u],low[u]));
}
}
return 0;
}