Luogu

LOJ

分析

规定根节点的深度为 $1$。

考虑 DP。设 $dp_{u,i}$ 表示以 $u$ 为根的子树,下端在子树中的未被覆盖的链的上端的最深深度为 $i$ 时的方案数($i=0$ 表示全部被覆盖)。这个“最深”的好处在于我们覆盖了深的就一定会覆盖浅的,便于计数。

考虑转移,每次把一棵子树合并进来,考虑子树的父边填 $1$ 还是 $0$ 可以得到

$$ dp_{u,i}\leftarrow\sum_{j=0}^{dep_u}dp_{u,i}\times dp_{v,j}+\sum_{j=0}^idp_{u,i}\times dp_{v,j}+\sum_{j=0}^{i-1}dp_{u,j}\times dp_{v,i} $$

最后一项上界是 $i-1$ 而不是 $i$ 的原因是前面已经算过了。

设 $s_{u,i}=\sum_{j=0}^i dp_{u,j}$,则上式可以改写为

$$ dp_{u,i}\leftarrow dp_{u,i}\times(s_{v,dep_u}+s_{v,i})+dp_{v,i}\times s_{u,i-1} $$

可以想到整体 DP,用线段树维护每个节点的 DP 值,则我们需要考虑如何合并。注意到上面这个式子中只有 $s_{v,dep_u}$ 是和下标无关的,所以我们先求一遍这个东西,然后在线段树合并时维护 $s_{v,i}$ 和 $s_{u,i-1}$ 即可,区间乘打一个乘法标记即可。具体实现可以参考代码。

代码

// ====================================
//   author: M_sea
//   website: https://m-sea-blog.com/
// ====================================
#include <bits/stdc++.h>
#define file(x) freopen(#x".in","r",stdin); freopen(#x".out","w",stdout)
using namespace std;
typedef long long ll;

int read() {
    int X=0,w=1; char c=getchar();
    while (c<'0'||c>'9') { if (c=='-') w=-1; c=getchar(); }
    while (c>='0'&&c<='9') X=X*10+c-'0',c=getchar();
    return X*w;
}

const int N=500000+10;
const int mod=998244353;

int n,m,dep[N];
vector<int> E[N],top[N];

#define ls(o) t[o].ls
#define rs(o) t[o].rs
int rt[N],tot=0;
struct node { int ls,rs,sumv,mulv; } t[N*30];
void pushup(int o) { t[o].sumv=(t[ls(o)].sumv+t[rs(o)].sumv)%mod; }
void pushdown(int o) {
    if (t[o].mulv!=1) {
        t[ls(o)].mulv=1ll*t[ls(o)].mulv*t[o].mulv%mod;
        t[ls(o)].sumv=1ll*t[ls(o)].sumv*t[o].mulv%mod;
        t[rs(o)].mulv=1ll*t[rs(o)].mulv*t[o].mulv%mod;
        t[rs(o)].sumv=1ll*t[rs(o)].sumv*t[o].mulv%mod;
        t[o].mulv=1;
    }
}
void modify(int& o,int l,int r,int p,int w) {
    if (!o) o=++tot,t[o].mulv=1;
    if (l==r) { t[o].mulv=1,t[o].sumv=w; return; }
    int mid=(l+r)>>1; pushdown(o);
    if (p<=mid) modify(ls(o),l,mid,p,w);
    else modify(rs(o),mid+1,r,p,w);
    pushup(o);
}
int query(int o,int l,int r,int ql,int qr) {
    if (!o) return 0;
    if (ql<=l&&r<=qr) return t[o].sumv;
    int mid=(l+r)>>1,res=0; pushdown(o);
    if (ql<=mid) res=(res+query(ls(o),l,mid,ql,qr))%mod;
    if (qr>mid) res=(res+query(rs(o),mid+1,r,ql,qr))%mod;
    pushup(o); return res;
}
int merge(int x,int y,int l,int r,int& su,int& sv) {
    if (!x&&!y) return 0;
    if (!x) {
        sv=(sv+t[y].sumv)%mod;
        t[y].mulv=1ll*t[y].mulv*su%mod;
        t[y].sumv=1ll*t[y].sumv*su%mod;
        return y;
    }
    if (!y) {
        su=(su+t[x].sumv)%mod;
        t[x].mulv=1ll*t[x].mulv*sv%mod;
        t[x].sumv=1ll*t[x].sumv*sv%mod;
        return x;
    }
    if (l==r) {
        int tx=t[x].sumv,ty=t[y].sumv;
        sv=(sv+ty)%mod;
        t[x].sumv=(1ll*t[x].sumv*sv+1ll*t[y].sumv*su)%mod;
        su=(su+tx)%mod;
        return x;
    }
    int mid=(l+r)>>1; pushdown(x),pushdown(y);
    ls(x)=merge(ls(x),ls(y),l,mid,su,sv);
    rs(x)=merge(rs(x),rs(y),mid+1,r,su,sv);
    pushup(x); return x;
}

void dfs(int u,int fa) {
    dep[u]=dep[fa]+1; int d=0;
    for (int i:top[u]) d=max(d,dep[i]);
    modify(rt[u],0,n,d,1);
    for (int v:E[u]) {
        if (v==fa) continue;
        dfs(v,u);
        int su=0,sv=query(rt[v],0,n,0,dep[u]);
        rt[u]=merge(rt[u],rt[v],0,n,su,sv);
    }
}

int main() {
    n=read();
    for (int i=1;i<n;++i) {
        int u=read(),v=read();
        E[u].emplace_back(v),E[v].emplace_back(u);
    }
    m=read();
    for (int i=1;i<=m;++i) {
        int u=read(),v=read();
        top[v].emplace_back(u);
    }
    dfs(1,0);
    printf("%d\n",query(rt[1],0,n,0,0));
    return 0;
}
最后修改:2020 年 08 月 20 日 07 : 52 PM