分析
考虑反着看整个过程,那么相当于每个点有 $y_i$ 只老鼠和 $x_i$ 个洞,所有老鼠都必须进洞。
$u,v$ 匹配的代价是 $dep_u+dep_v-2dep_{\operatorname{LCA}(u,v)}$。因此我们开两个堆维护子树中老鼠和洞的 $dep$,每次把儿子的堆合并,然后取堆顶匹配,再把反悔操作加入堆中即可。
然而这样并不能保证所有老鼠都进洞,所以我们把老鼠的 $dep$ 减去 $+\infty$,最后再把答案加上 $+\infty\times\sum y_i$ 即可。
这样子当匹配的老鼠和洞来自同一棵子树答案可能会不对,但是这样子一定不会是最优解,所以不会有问题。
因为一对匹配的老鼠和洞不会同时反悔,所以只会操作 $X=\sum x_i$ 次,所以时间复杂度 $\mathcal{O}(X\log X)$。可以记一个 pair
表示代价和出现次数从而做到 $\mathcal{O}(n\log n)$。
代码
因为懒所以用了 __gnu_pbds::priority_queue
(
// ====================================
// author: M_sea
// website: https://m-sea-blog.com/
// ====================================
#include <bits/stdc++.h>
#include <ext/pb_ds/priority_queue.hpp>
#define file(x) freopen(#x".in","r",stdin); freopen(#x".out","w",stdout)
#define mp make_pair
using namespace std;
typedef long long ll;
int read() {
int X=0,w=1; char c=getchar();
while (c<'0'||c>'9') { if (c=='-') w=-1; c=getchar(); }
while (c>='0'&&c<='9') X=X*10+c-'0',c=getchar();
return X*w;
}
const int N=250000+10;
const ll inf=1e12;
int n,x[N],y[N]; ll ans=0;
vector<pair<int,int>> E[N];
struct node { ll cost; mutable int cnt; };
bool operator <(node x,node y) { return x.cost<y.cost; }
bool operator >(node x,node y) { return x.cost>y.cost; }
__gnu_pbds::priority_queue<node,greater<node>> M[N],H[N];
void dfs(int u,int fa,ll dep) {
H[u].push((node){dep,x[u]}),M[u].push((node){dep-inf,y[u]});
for (auto t:E[u]) {
int v=t.first,w=t.second; if (v==fa) continue;
dfs(v,u,dep+w); H[u].join(H[v]),M[u].join(M[v]);
}
while (!M[u].empty()&&!H[u].empty()) {
auto m=M[u].top(),h=H[u].top();
ll cost=m.cost+h.cost-2*dep; int f=min(m.cnt,h.cnt);
if (cost>=0) break;
ans+=cost*f;
M[u].top().cnt-=f; if (!M[u].top().cnt) M[u].pop();
H[u].top().cnt-=f; if (!H[u].top().cnt) H[u].pop();
M[u].push((node){-cost+m.cost,f});
H[u].push((node){-cost+h.cost,f});
}
}
int main() {
n=read(); int s=0;
for (int i=1;i<n;++i) {
int u=read(),v=read(),w=read();
E[u].emplace_back(mp(v,w)),E[v].emplace_back(mp(u,w));
}
for (int i=1;i<=n;++i) x[i]=read(),y[i]=read(),s+=y[i];
dfs(1,0,0);
printf("%lld\n",ans+1ll*s*inf);
return 0;
}