分析
题目中的式子等价于
$$
\frac{1}{2}\left(\operatorname{dep}(u)+\operatorname{dep}(v)+\operatorname{dis}(u,v)\right)-\operatorname{dep}'(\operatorname{LCA}'(u,v))
$$
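先把第一棵树二叉化(多叉节点通过边权为 $0$ 的新建点拆开,这样边分治的复杂度才有保证),再对它边分治。设当前分治边的边权为 $w$,对两边的点分别计算它到该边在本侧端点的距离 $d_u$,则 $\operatorname{dis}(u,v)=d_u+d_v+w$,上面那个式子就变成了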
$$
\frac{1}{2}(\operatorname{dep}(u)+\operatorname{dep}(v)+d_u+d_v+w)-\operatorname{dep}'(\operatorname{LCA}'(u,v))
$$
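我们把一边的点染成黑色,另一边的点染成白色,然后在第二棵树上对这些点建虚树。第二棵树上的一个点 $x$ 是 $\operatorname{LCA}'(u,v)$,当且仅当 $u,v$ 分别位于 $x$ 的两棵不同子树中(或其中一个就是 $x$ 本身);我们只关心 $u,v$ 颜色不同的情况。于是在虚树上树形 DP,对每个点分别维护子树内黑点和白点的 $\operatorname{dep}(u)+d_u$ 的最大值,合并子树时更新答案即可。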
虽然复杂度是 $\mathcal{O}(n\log^2 n)$,但跑得挺快,不需要卡常就可以过。
代码
// ====================================
// author: M_sea
// website: https://m-sea-blog.com/
// ====================================
#include <bits/stdc++.h>
#define file(x) freopen(#x".in","r",stdin); freopen(#x".out","w",stdout)
using namespace std;
typedef long long ll;
// Reads one decimal integer from `in` (stdin by default), skipping every
// non-digit character before it; a '-' seen while skipping makes the result
// negative. If the stream ends before any digit is found, returns 0 instead
// of spinning forever on EOF.
int read(std::FILE* in = stdin) {
    int X = 0, w = 1;
    int c = std::getc(in);  // int, not char: EOF must stay distinguishable
    while (c != EOF && (c < '0' || c > '9')) {
        if (c == '-') w = -1;
        c = std::getc(in);
    }
    while (c >= '0' && c <= '9') X = X * 10 + c - '0', c = std::getc(in);
    return X * w;
}
// Capacity of every per-node array (4 * 366666 plus slack).
// NOTE(review): presumably sized for the largest n of the problem — confirm.
constexpr int N = 1466664 + 10;

// Vertex count of each input tree.
int n;
// Best candidate found so far; every candidate is stored doubled and main()
// halves it on output. Starts far below any reachable value.
ll ans = -5000000000000000000LL;
// v[c]: original vertices on side c of the tree-1 edge currently being divided.
vector<int> v[2];
// w[i]: key of vertex i for the current divide step (reset to 0 afterwards).
ll w[N];
// tp[i]: side / colour (0 or 1) of vertex i in the current divide step.
int tp[N];
// Tree 2: heavy-light decomposition (for LCA queries and DFS order) plus the
// virtual-tree DP that is run once per edge-divide step of tree 1.
namespace T2 {
// Forward-star adjacency list. It holds the input tree until main() clears
// `head`; afterwards solve() reuses it for each virtual tree (parent->child
// edges only), and dfs() clears `head` again as it walks.
struct edge { int v, w, nxt; } e[N << 1];
int head[N], ecnt;
void addEdge(int u, int v, int w) {
    e[++ecnt] = edge{v, w, head[u]};  // ISO C++ aggregate init, not a GNU compound literal
    head[u] = ecnt;
}

// dep: unweighted depth (root 1 has depth 1); fa: parent; sz: subtree size;
// hson: heavy child; top: head of the heavy chain; dis: weighted depth.
int dep[N], fa[N], sz[N], hson[N], top[N]; ll dis[N];
int dfn[N], tim = 0;

// First HLD pass: parents, depths, weighted depths, sizes, heavy children.
void dfs1(int u, int f) {
    fa[u] = f;
    dep[u] = dep[f] + 1;
    sz[u] = 1;
    for (int i = head[u]; i; i = e[i].nxt) {
        int v = e[i].v, w = e[i].w;
        if (v == f) continue;
        dis[v] = dis[u] + w;
        dfs1(v, u);
        sz[u] += sz[v];
        if (sz[v] > sz[hson[u]]) hson[u] = v;
    }
}

// Second HLD pass: chain heads and DFS order (heavy child visited first).
void dfs2(int u, int anc) {
    top[u] = anc, dfn[u] = ++tim;
    if (hson[u]) dfs2(hson[u], anc);
    for (int i = head[u]; i; i = e[i].nxt)
        if (e[i].v != fa[u] && e[i].v != hson[u]) dfs2(e[i].v, e[i].v);
}

// Lowest common ancestor by climbing heavy chains.
int LCA(int u, int v) {
    while (top[u] != top[v]) {
        if (dep[top[u]] < dep[top[v]]) swap(u, v);
        u = fa[top[u]];
    }
    return dep[u] < dep[v] ? u : v;
}

// vis: node is a query point of the current step; p: query points sorted by
// dfn; sta/stop: rightmost-chain stack used to build the virtual tree.
int vis[N], p[N], tot, sta[N], stop;
// dp[x][c]: maximum w[] over colour-c query points in x's virtual subtree.
ll dp[N][2];
bool cmp(int x, int y) { return dfn[x] < dfn[y]; }

// DP over the virtual tree rooted at u. `we` is the weight of the tree-1
// divide edge; a colour-0 point a and colour-1 point b whose tree-2 LCA is u
// contribute w[a] + w[b] + we - 2*dis[u] to `ans`. Also resets head[u] and
// vis[u] so the arrays are clean for the next step.
void dfs(int u, int we) {
    // -2e18 acts as -infinity; the sum of two of them (-4e18) still fits in ll.
    dp[u][0] = dp[u][1] = -2e18;
    if (vis[u]) dp[u][tp[u]] = w[u];  // u itself may be one of the endpoints
    for (int i = head[u]; i; i = e[i].nxt) {
        int v = e[i].v;
        dfs(v, we);
        // Pair v's subtree with everything merged into u so far.
        ans = max(ans, dp[u][0] + dp[v][1] + we - 2 * dis[u]);
        ans = max(ans, dp[u][1] + dp[v][0] + we - 2 * dis[u]);
        dp[u][0] = max(dp[u][0], dp[v][0]);
        dp[u][1] = max(dp[u][1], dp[v][1]);
    }
    head[u] = vis[u] = 0;
}

// Builds the virtual tree (rooted at 1) over the points of v[0] and v[1] and
// runs the DP. Root 1 is always the bottom of the stack; if it is itself a
// query point it is skipped in the loop so it is never pushed twice (which
// would create a self-loop).
void solve(int we) {
    ecnt = 0, tot = stop = 0;
    for (int i : v[0]) p[++tot] = i, vis[i] = 1;
    for (int i : v[1]) p[++tot] = i, vis[i] = 1;
    sort(p + 1, p + tot + 1, cmp);
    sta[stop = 1] = 1;
    for (int i = 1; i <= tot; ++i) {
        int u = p[i];
        if (u == 1) continue;  // already on the stack as the root
        int f = LCA(u, sta[stop]);
        while (stop > 1 && dep[sta[stop - 1]] >= dep[f])
            addEdge(sta[stop - 1], sta[stop], 0), --stop;
        if (f != sta[stop]) addEdge(f, sta[stop], 0), sta[stop] = f;
        sta[++stop] = u;
    }
    while (stop > 1) addEdge(sta[stop - 1], sta[stop], 0), --stop;
    dfs(1, we);
}
}
// Tree 1: binarisation followed by edge divide-and-conquer. Each divide step
// colours the two sides 0/1 and hands them to T2::solve().
namespace T1 {
// Forward-star adjacency list. ecnt starts at 1 so paired directed edges get
// indices 2k and 2k+1: `i ^ 1` is the reverse edge and `i >> 1` names the
// undirected edge.
struct edge { int v, w, nxt; } e[N << 1];
int head[N], ecnt = 1;
void addEdge(int u, int v, int w) {
    e[++ecnt] = edge{v, w, head[u]};  // ISO C++ aggregate init, not a GNU compound literal
    head[u] = ecnt;
}

// Node count, including auxiliary nodes created by transform().
int tot;
// son: children in the rooted input tree; fw: weight of the edge to the
// parent; dis: weighted depth from root 1 in the ORIGINAL tree.
vector<int> son[N]; int fw[N]; ll dis[N];

// Roots the input tree and fills son / fw / dis.
void dfs(int u, int f) {
    for (int i = head[u]; i; i = e[i].nxt) {
        int v = e[i].v, w = e[i].w;
        if (v == f) continue;
        fw[v] = w, son[u].emplace_back(v), dis[v] = dis[u] + w;
        dfs(v, u);
    }
}

// Rebuilds the adjacency list so every node has at most two children
// (degree <= 3), as edge divide-and-conquer requires. A node with more than
// two children gets two new auxiliary nodes (ids > n) on 0-weight edges and
// its children are dealt alternately to them. The loop bound `tot` grows as
// nodes are added, so the auxiliary nodes are split in turn. 0-weight edges
// keep distances between original nodes unchanged.
void transform() {
    memset(head, 0, sizeof(head)), ecnt = 1;
    for (int i = 1; i <= tot; ++i) {
        if (son[i].size() <= 2) {
            for (int j : son[i]) addEdge(i, j, fw[j]), addEdge(j, i, fw[j]);
        } else {
            int ls = ++tot; addEdge(i, ls, 0), addEdge(ls, i, 0);
            int rs = ++tot; addEdge(i, rs, 0), addEdge(rs, i, 0);
            for (size_t j = 0; j < son[i].size(); ++j) {
                if (j & 1) son[ls].emplace_back(son[i][j]);
                else son[rs].emplace_back(son[i][j]);
            }
        }
    }
}

// rt: chosen divide edge (directed, from getroot's parent side to child
// side); mxsz: larger side size for rt; sz: subtree sizes from getroot;
// vis[k]: undirected edge k already cut; d: distance to the divide edge's
// endpoint on the same side.
int rt, mxsz, sz[N], vis[N]; ll d[N];

// Computes subtree sizes of the current component (sumsz nodes) rooted at u
// and selects the uncut edge minimising the larger of its two sides.
void getroot(int u, int f, int sumsz) {
    sz[u] = 1;
    for (int i = head[u]; i; i = e[i].nxt) {
        int v = e[i].v;
        if (vis[i >> 1] || v == f) continue;
        getroot(v, u, sumsz), sz[u] += sz[v];
        if (max(sz[v], sumsz - sz[v]) < mxsz)
            rt = i, mxsz = max(sz[v], sumsz - sz[v]);
    }
}

// Collects the original nodes (ids <= n) of one side into v[id], colours
// them id, and fills d[] with distances from the side's edge endpoint.
void dfs(int u, int f, int id) {
    if (u <= n) v[id].emplace_back(u), tp[u] = id;
    for (int i = head[u]; i; i = e[i].nxt) {
        if (!vis[i >> 1] && e[i].v != f) {
            d[e[i].v] = d[u] + e[i].w;
            dfs(e[i].v, u, id);
        }
    }
}

// Edge divide-and-conquer on the component containing u (size nodes). For
// each divide edge every original node i gets w[i] = dis[i] + d[i], and
// T2::solve pairs nodes across the edge.
void solve(int u, int size) {
    mxsz = size + 1, getroot(u, 0, size);
    if (mxsz == size + 1) return;  // no uncut edge: single-node component
    vis[rt >> 1] = 1, d[e[rt].v] = d[e[rt ^ 1].v] = 0, v[0].clear(), v[1].clear();
    dfs(e[rt].v, 0, 0), dfs(e[rt ^ 1].v, 0, 1);
    for (int i : v[0]) w[i] = d[i] + dis[i];
    for (int i : v[1]) w[i] = d[i] + dis[i];
    T2::solve(e[rt].w);
    for (int i : v[0]) w[i] = 0;
    for (int i : v[1]) w[i] = 0;
    // rt is overwritten by the recursive calls, so snapshot it first.
    // sz[e[rt].v] is the child-side size computed by getroot.
    int id = rt, su = sz[e[rt].v], sv = size - sz[e[rt].v];
    solve(e[id].v, su), solve(e[id ^ 1].v, sv);
}
}
// Reads both trees, runs edge divide-and-conquer on tree 1 (pairing across
// each divide edge through tree 2's virtual tree), also considers u == v,
// and prints the maximum. Every candidate is accumulated doubled.
int main() {
    n = T1::tot = read();
    for (int i = 1; i < n; ++i) {
        int u = read(), v = read(), w = read();
        T1::addEdge(u, v, w), T1::addEdge(v, u, w);
    }
    for (int i = 1; i < n; ++i) {
        int u = read(), v = read(), w = read();
        T2::addEdge(u, v, w), T2::addEdge(v, u, w);
    }
    // HLD on tree 2, then clear its adjacency so T2::solve can reuse it.
    T2::dfs1(1, 0), T2::dfs2(1, 1);
    memset(T2::head, 0, sizeof(T2::head));
    T1::dfs(1, 0), T1::transform();
    T1::solve(1, T1::tot);
    // Pair (u, u): doubled value is 2 * (dep(u) - dep'(u)).
    for (int i = 1; i <= n; ++i) ans = max(ans, 2 * (T1::dis[i] - T2::dis[i]));
    // Every candidate is twice an integer, so /2 is exact; unlike >>, it is
    // well-defined for negative values on every C++ standard.
    printf("%lld\n", ans / 2);
    return 0;
}