分析
通过手玩一些数据可以发现,我们一定会从一个有宝藏的点出发,而且按照 DFS 序遍历所有有宝藏的点会最优。
不妨设现有的有宝藏的点按 DFS 序排序后为 $a_1,a_2,a_3,\cdots,a_k$,那么我们要求的就是 $\operatorname{dis}(a_1,a_2)+\operatorname{dis}(a_2,a_3)+\cdots\operatorname{dis}(a_k,a_1)$。
考虑加入一个点 $u$ 后答案的变化量。设 $a$ 中 $u$ 的DFS 序的前驱为 $L$,后继为 $R$,那么变化量为 $\operatorname{dis}(L,u)+\operatorname{dis}(u,R)-\operatorname{dis}(L,R)$。删除类似。
于是用 std::set
维护所有有宝藏的点就好了。
代码
// ===================================
// author: M_sea
// website: http://m-sea-blog.com/
// ===================================
#include <algorithm>
#include <iostream>
#include <cstdlib>
#include <cstring>
#include <cstdio>
#include <cmath>
#include <set>
#define re register
#define int long long
using namespace std;
inline int read() {
int X=0,w=1; char c=getchar();
while (c<'0'||c>'9') { if (c=='-') w=-1; c=getchar(); }
while (c>='0'&&c<='9') X=X*10+c-'0',c=getchar();
return X*w;
}
const int N=100000+10;
int n,m;
struct edge { int v,w,nxt; } e[N<<1];
int head[N];
inline void addEdge(int u,int v,int w) {
static int cnt=0;
e[++cnt]=(edge){v,w,head[u]},head[u]=cnt;
}
int dep[N],fa[N],sz[N],hson[N],top[N],dis[N];
int dfn[N],pos[N],tim=0;
inline void dfs1(int u,int f) {
dep[u]=dep[fa[u]=f]+1,sz[u]=1;
for (re int i=head[u];i;i=e[i].nxt) {
int v=e[i].v,w=e[i].w; if (v==f) continue;
dis[v]=dis[u]+w,dfs1(v,u),sz[u]+=sz[v];
if (sz[v]>sz[hson[u]]) hson[u]=v;
}
}
inline void dfs2(int u,int anc) {
dfn[u]=++tim,pos[tim]=u,top[u]=anc;
if (hson[u]) dfs2(hson[u],anc);
for (re int i=head[u];i;i=e[i].nxt)
if (e[i].v!=fa[u]&&e[i].v!=hson[u]) dfs2(e[i].v,e[i].v);
}
inline int getlca(int u,int v) {
while (top[u]!=top[v]) {
if (dep[top[u]]<dep[top[v]]) swap(u,v);
u=fa[top[u]];
}
return dep[u]<dep[v]?u:v;
}
inline int getdis(int u,int v) {
return dis[u]+dis[v]-(dis[getlca(u,v)]<<1);
}
set<int> S; set<int>::iterator it;
int in[N];
signed main() {
n=read(),m=read();
for (re int i=1;i<n;++i) {
int u=read(),v=read(),w=read();
addEdge(u,v,w),addEdge(v,u,w);
}
dfs1(1,0),dfs2(1,1);
int ans=0;
while (m--) {
int u=read();
if (!in[u]) S.insert(dfn[u]);
it=S.lower_bound(dfn[u]);
int L=pos[it==S.begin()?*--S.end():*--it];
it=S.upper_bound(dfn[u]);
int R=pos[it==S.end()?*S.begin():*it];
if (in[u]) S.erase(dfn[u]);
int dlt=getdis(L,u)+getdis(u,R)-getdis(L,R);
in[u]?ans-=dlt:ans+=dlt; in[u]^=1;
printf("%lld\n",ans);
}
return 0;
}