分析
规定根节点的深度为 $1$。
考虑 DP。设 $dp_{u,i}$ 表示以 $u$ 为根的子树,下端在子树中的未被覆盖的链的上端的最深深度为 $i$ 时的方案数($i=0$ 表示全部被覆盖)。这个“最深”的好处在于我们覆盖了深的就一定会覆盖浅的,便于计数。
考虑转移,每次把一棵子树合并进来,考虑子树的父边填 $1$ 还是 $0$ 可以得到
$$
dp_{u,i}\leftarrow\sum_{j=0}^{dep_u}dp_{u,i}\times dp_{v,j}+\sum_{j=0}^idp_{u,i}\times dp_{v,j}+\sum_{j=0}^{i-1}dp_{u,j}\times dp_{v,i}
$$
最后一项上界是 $i-1$ 而不是 $i$ 的原因是前面已经算过了。
设 $s_{u,i}=\sum_{j=0}^i dp_{u,j}$,则上式可以改写为
$$
dp_{u,i}\leftarrow dp_{u,i}\times(s_{v,dep_u}+s_{v,i})+dp_{v,i}\times s_{u,i-1}
$$
可以想到整体 DP,用线段树维护每个节点的 DP 值,则我们需要考虑如何合并。注意到上面这个式子中只有 $s_{v,dep_u}$ 是和下标无关的,所以我们先求一遍这个东西,然后在线段树合并时维护 $s_{v,i}$ 和 $s_{u,i-1}$ 即可,区间乘打一个乘法标记即可。具体实现可以参考代码。
代码
// ====================================
// author: M_sea
// website: https://m-sea-blog.com/
// ====================================
#include <bits/stdc++.h>
#define file(x) freopen(#x".in","r",stdin); freopen(#x".out","w",stdout)
using namespace std;
typedef long long ll;
int read() {
int X=0,w=1; char c=getchar();
while (c<'0'||c>'9') { if (c=='-') w=-1; c=getchar(); }
while (c>='0'&&c<='9') X=X*10+c-'0',c=getchar();
return X*w;
}
const int N=500000+10;
const int mod=998244353;
int n,m,dep[N];
vector<int> E[N],top[N];
#define ls(o) t[o].ls
#define rs(o) t[o].rs
int rt[N],tot=0;
struct node { int ls,rs,sumv,mulv; } t[N*30];
void pushup(int o) { t[o].sumv=(t[ls(o)].sumv+t[rs(o)].sumv)%mod; }
void pushdown(int o) {
if (t[o].mulv!=1) {
t[ls(o)].mulv=1ll*t[ls(o)].mulv*t[o].mulv%mod;
t[ls(o)].sumv=1ll*t[ls(o)].sumv*t[o].mulv%mod;
t[rs(o)].mulv=1ll*t[rs(o)].mulv*t[o].mulv%mod;
t[rs(o)].sumv=1ll*t[rs(o)].sumv*t[o].mulv%mod;
t[o].mulv=1;
}
}
void modify(int& o,int l,int r,int p,int w) {
if (!o) o=++tot,t[o].mulv=1;
if (l==r) { t[o].mulv=1,t[o].sumv=w; return; }
int mid=(l+r)>>1; pushdown(o);
if (p<=mid) modify(ls(o),l,mid,p,w);
else modify(rs(o),mid+1,r,p,w);
pushup(o);
}
int query(int o,int l,int r,int ql,int qr) {
if (!o) return 0;
if (ql<=l&&r<=qr) return t[o].sumv;
int mid=(l+r)>>1,res=0; pushdown(o);
if (ql<=mid) res=(res+query(ls(o),l,mid,ql,qr))%mod;
if (qr>mid) res=(res+query(rs(o),mid+1,r,ql,qr))%mod;
pushup(o); return res;
}
int merge(int x,int y,int l,int r,int& su,int& sv) {
if (!x&&!y) return 0;
if (!x) {
sv=(sv+t[y].sumv)%mod;
t[y].mulv=1ll*t[y].mulv*su%mod;
t[y].sumv=1ll*t[y].sumv*su%mod;
return y;
}
if (!y) {
su=(su+t[x].sumv)%mod;
t[x].mulv=1ll*t[x].mulv*sv%mod;
t[x].sumv=1ll*t[x].sumv*sv%mod;
return x;
}
if (l==r) {
int tx=t[x].sumv,ty=t[y].sumv;
sv=(sv+ty)%mod;
t[x].sumv=(1ll*t[x].sumv*sv+1ll*t[y].sumv*su)%mod;
su=(su+tx)%mod;
return x;
}
int mid=(l+r)>>1; pushdown(x),pushdown(y);
ls(x)=merge(ls(x),ls(y),l,mid,su,sv);
rs(x)=merge(rs(x),rs(y),mid+1,r,su,sv);
pushup(x); return x;
}
void dfs(int u,int fa) {
dep[u]=dep[fa]+1; int d=0;
for (int i:top[u]) d=max(d,dep[i]);
modify(rt[u],0,n,d,1);
for (int v:E[u]) {
if (v==fa) continue;
dfs(v,u);
int su=0,sv=query(rt[v],0,n,0,dep[u]);
rt[u]=merge(rt[u],rt[v],0,n,su,sv);
}
}
int main() {
n=read();
for (int i=1;i<n;++i) {
int u=read(),v=read();
E[u].emplace_back(v),E[v].emplace_back(u);
}
m=read();
for (int i=1;i<=m;++i) {
int u=read(),v=read();
top[v].emplace_back(u);
}
dfs(1,0);
printf("%d\n",query(rt[1],0,n,0,0));
return 0;
}