Luogu

LOJ

分析

为了方便,将陷阱作为根节点。那么冷静分析一下可以得到一些结论(没有证明)

  • 当老鼠走到一个叶节点时,它将被困住。
  • 当老鼠被困住时,将其它岔路口堵住,然后将到根的路径擦干净是最优的。

设 $f_u$ 表示老鼠进入 $u$ 子树后回到 $u$ 的最小操作次数。显然老鼠会找一个 $f$ 最大的子树进去,所以堵住 $f$ 最大的子树最优,此时老鼠会进入 $f$ 次大的子树,于是有
$$
f_u=\operatorname{2ndmax}\left\{f_v\right\}+deg_u-[u\neq t]
$$
后面的东西可以理解成儿子数。

考虑整个过程中老鼠的移动,可以发现会向上走若干步(可能为 $0$),然后一路向下。

设 $g_u$ 表示老鼠第一次向下走进入 $u$ 子树后回到根的最小操作次数,$sum_u$ 为 $u$ 到根的岔路数,则有
$$
g_u=f_u+sum_{fa_u}-[fa_u=m]
$$
这里减去 $[fa_u=m]$ 的原因是老鼠下来时就堵了一条边。

考虑二分答案 $mid$。那么老鼠会想办法进入一个 $g_u>mid$ 的子树,如果没有才会往上走。于是我们需要把所有这样的子树封上,从下往上模拟一下即可。

代码

// ====================================
//   author: M_sea
//   website: https://m-sea-blog.com/
// ====================================
#include <bits/stdc++.h>
#define file(x) freopen(#x".in","r",stdin); freopen(#x".out","w",stdout)
using namespace std;
typedef long long ll;

int read() {
    int X=0,w=1; char c=getchar();
    while (c<'0'||c>'9') { if (c=='-') w=-1; c=getchar(); }
    while (c>='0'&&c<='9') X=X*10+c-'0',c=getchar();
    return X*w;
}

const int N=1000000+10;

int n,t,s;
vector<int> E[N];

int fa[N],f[N];
void dfs(int u,int p) {
    fa[u]=p; int mx=0,mx2=0;
    for (int v:E[u]) {
        if (v==p) continue;
        dfs(v,u);
        if (f[v]>=mx) mx2=mx,mx=f[v];
        else if (f[v]>=mx2) mx2=f[v];
    }
    f[u]=mx2+E[u].size()-(u!=t);
}

int chain[N],sum[N],top=0;
vector<int> ch[N];
bool check(int mid) {
    for (int i=1,r=0;i<top;++i) {
        ++r; int now=0;
        for (int j:ch[i])
            if (j-now>mid) {
                if (!r||!mid) return 0;
                --r,--mid,++now;
            }
    }
    return 1;
}

int main() {
    n=read(),t=read(),s=read();
    if (s==t) { puts("0"); return 0; }
    for (int i=1;i<n;++i) {
        int u=read(),v=read();
        E[u].emplace_back(v),E[v].emplace_back(u);
    }
    dfs(t,0);
    for (int i=s;i;i=fa[i]) chain[++top]=i;
    for (int i=top-1;i;--i) sum[i]=sum[i+1]+E[chain[i]].size()-1-(i!=1);
    for (int i=1;i<=top;++i) 
        for (int v:E[chain[i]])
            if (v!=chain[i-1]&&v!=chain[i+1]) ch[i].emplace_back(f[v]+sum[i]);
    int L=f[s],R=n<<1;
    while (L<R) {
        int mid=(L+R)>>1;
        if (check(mid)) R=mid;
        else L=mid+1;
    }
    printf("%d\n",L);
    return 0;
}
最后修改:2020 年 06 月 10 日 10 : 32 PM