Luogu

LOJ

分析

首先考虑原树为外向树的情况。此时 $u$ 子树中所有点选择时间都应比 $u$ 晚。

考虑这个概率。设 $u$ 子树 $w_i$ 的和为 $s$,整棵树 $w_i$ 的和为 $S$,则 $u$ 子树中所有点选择时间都比 $u$ 晚的概率为
$$
\frac{w_i}{S}\sum_{i=0}^{+\infty}\left(\frac{S-s}{S}\right)^i=\frac{w_i}{s}
$$
这个概率只和子树有关,因此可以设 $dp_{i,j}$ 表示以 $i$ 为根的子树 $w_i$ 和为 $j$ 的概率,转移即为树形背包。

现在有内向边,可以考虑容斥,即对于每条内向边包含它的方案数等于忽略它的方案数减去它为正向边时的方案数。前者不难计算,后者套用上面的 DP 即可。

代码

// ===================================
//   author: M_sea
//   website: https://m-sea-blog.com/
// ===================================
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;

int read() {
    int X=0,w=1; char c=getchar();
    while (c<'0'||c>'9') { if (c=='-') w=-1; c=getchar(); }
    while (c>='0'&&c<='9') X=X*10+c-'0',c=getchar();
    return X*w;
}

const int N=3000+10;
const int mod=998244353;
int qpow(int a,int b) { int c=1;
    for (;b;b>>=1,a=1ll*a*a%mod) if (b&1) c=1ll*c*a%mod;
    return c;
}

int n,a[N][4];

struct edge { int v,w,nxt; } e[N<<1];
int head[N];
void addEdge(int u,int v,int w) {
    static int cnt=0;
    e[++cnt]=(edge){v,w,head[u]},head[u]=cnt;
}

int sz[N],dp[N][N],tp[N];
void dfs(int u,int f) {
    sz[u]=1;
    dp[u][1]=a[u][1],dp[u][2]=2ll*a[u][2]%mod,dp[u][3]=3ll*a[u][3]%mod;
    for (int i=head[u];i;i=e[i].nxt) {
        int v=e[i].v; if (v==f) continue;
        dfs(v,u);
        for (int j=1;j<=(sz[u]+sz[v])*3;++j) tp[j]=0;
        if (!e[i].w) {
            for (int j=1;j<=sz[u]*3;++j)
                for (int k=1;k<=sz[v]*3;++k)
                    tp[j+k]=(tp[j+k]+1ll*dp[u][j]*dp[v][k])%mod;
        } else {
            int s=0;
            for (int j=1;j<=sz[v]*3;++j) s=(s+dp[v][j])%mod;
            for (int j=1;j<=sz[u]*3;++j) tp[j]=1ll*dp[u][j]*s%mod;
            for (int j=1;j<=sz[u]*3;++j)
                for (int k=1;k<=sz[v]*3;++k)
                    tp[j+k]=(tp[j+k]-1ll*dp[u][j]*dp[v][k]%mod+mod)%mod;
        }
        sz[u]+=sz[v];
        for (int j=1;j<=sz[u]*3;++j) dp[u][j]=tp[j];
    }
    for (int i=1;i<=sz[u]*3;++i) dp[u][i]=1ll*dp[u][i]*qpow(i,mod-2)%mod;
}

int main() {
    n=read();
    for (int i=1;i<=n;++i) {
        int x=read(),y=read(),z=read(),s=x+y+z;
        a[i][1]=1ll*x*qpow(s,mod-2)%mod;
        a[i][2]=1ll*y*qpow(s,mod-2)%mod;
        a[i][3]=1ll*z*qpow(s,mod-2)%mod;
    }
    for (int i=1;i<n;++i) {
        int u=read(),v=read();
        addEdge(u,v,0),addEdge(v,u,1);
    }
    dfs(1,0);
    int ans=0;
    for (int i=1;i<=n*3;++i) ans=(ans+dp[1][i])%mod;
    printf("%d\n",ans);
    return 0;
}
最后修改:2021 年 03 月 24 日 05 : 10 PM