LOJ

分析

考虑反着看整个过程,那么相当于每个点有 $y_i$ 只老鼠和 $x_i$ 个洞,所有老鼠都必须进洞。

$u,v$ 匹配的代价是 $dep_u+dep_v-2dep_{\operatorname{LCA}(u,v)}$。因此我们开两个堆维护子树中老鼠和洞的 $dep$,每次把儿子的堆合并,然后取堆顶匹配,再把反悔操作加入堆中即可。

然而这样并不能保证所有老鼠都进洞,所以我们把老鼠的 $dep$ 减去 $+\infty$,最后再把答案加上 $+\infty\times\sum y_i$ 即可。

这样子当匹配的老鼠和洞来自同一棵子树答案可能会不对,但是这样子一定不会是最优解,所以不会有问题。

因为一对匹配的老鼠和洞不会同时反悔,所以只会操作 $X=\sum x_i$ 次,所以时间复杂度 $\mathcal{O}(X\log X)$。可以记一个 pair 表示代价和出现次数从而做到 $\mathcal{O}(n\log n)$。

代码

因为懒所以用了 __gnu_pbds::priority_queue

// ====================================
//   author: M_sea
//   website: https://m-sea-blog.com/
// ====================================
#include <bits/stdc++.h>
#include <ext/pb_ds/priority_queue.hpp>
#define file(x) freopen(#x".in","r",stdin); freopen(#x".out","w",stdout)
#define mp make_pair
using namespace std;
typedef long long ll;

int read() {
    int X=0,w=1; char c=getchar();
    while (c<'0'||c>'9') { if (c=='-') w=-1; c=getchar(); }
    while (c>='0'&&c<='9') X=X*10+c-'0',c=getchar();
    return X*w;
}

const int N=250000+10;
const ll inf=1e12;

int n,x[N],y[N]; ll ans=0;
vector<pair<int,int>> E[N];
struct node { ll cost; mutable int cnt; };
bool operator <(node x,node y) { return x.cost<y.cost; }
bool operator >(node x,node y) { return x.cost>y.cost; }
__gnu_pbds::priority_queue<node,greater<node>> M[N],H[N];

void dfs(int u,int fa,ll dep) {
    H[u].push((node){dep,x[u]}),M[u].push((node){dep-inf,y[u]});
    for (auto t:E[u]) {
        int v=t.first,w=t.second; if (v==fa) continue;
        dfs(v,u,dep+w); H[u].join(H[v]),M[u].join(M[v]);
    }
    while (!M[u].empty()&&!H[u].empty()) {
        auto m=M[u].top(),h=H[u].top();
        ll cost=m.cost+h.cost-2*dep; int f=min(m.cnt,h.cnt);
        if (cost>=0) break;
        ans+=cost*f;
        M[u].top().cnt-=f; if (!M[u].top().cnt) M[u].pop();
        H[u].top().cnt-=f; if (!H[u].top().cnt) H[u].pop();
        M[u].push((node){-cost+m.cost,f});
        H[u].push((node){-cost+h.cost,f});
    }
}

int main() {
    n=read(); int s=0;
    for (int i=1;i<n;++i) {
        int u=read(),v=read(),w=read();
        E[u].emplace_back(mp(v,w)),E[v].emplace_back(mp(u,w));
    }
    for (int i=1;i<=n;++i) x[i]=read(),y[i]=read(),s+=y[i];
    dfs(1,0,0);
    printf("%lld\n",ans+1ll*s*inf);
    return 0;
}
最后修改:2020 年 08 月 13 日 09 : 38 AM