分析
容斥
如果枚举一个点,然后算它被多少个连通块包含,显然是会算重的。
考虑树上连通块的一个性质:$|V|-|E|=1$。
于是如果我们用每个点满足条件的方案数减去每条边满足条件的方案数,就恰好可以得到答案。
DP
先考虑一个暴力 DP。
设 $f_{i,j}-1$ 表示以 $i$ 为根的子树中与 $i$ 距离不超过 $j$ 的连通块个数(这样设是为了方便下面的转移),$g_{i,j}$ 表示以 $i$ 为根的子树外并上 $i$ 中与 $i$ 距离不超过 $j$ 的连通块数。不难得到转移
$$
\begin{aligned}
f_{i,j}&=\prod_{v\in son_i}f_{v,j-1}+1\\
g_{i,j}&=\left(\prod_{v\in son_{fa_i},v\neq i}f_{v,j-2}\right)g_{fa_i,j-1}+1
\end{aligned}
$$
则一个点满足条件的方案数为 $\left[(f_{i,L}-1)g_{i,L}\right]^k$,一条边满足条件的方案数为 $\left[(f_{i,L-1}-1)(g_{i,L}-1)\right]^k$($i$ 是深的点)。
长链剖分优化 $f$
$f$ 的一维下标和深度有关,可以考虑用长链剖分优化。
按照套路,我们把轻儿子向重儿子合并。
因为 DP 状态的定义是 $\leq$,所以 $f_{v,len_v+1..+\infty}$ 都是有值的,如果我们对于这部分也暴力合并,那么复杂度就不对了。
如果 $f_{v,len_v}=0$,那么相当于将 $f_u$ 的一段后缀赋值为 $0$;如果 $f_{v,len_v}\neq 0$,那么相当于将 $f_u$ 的一段后缀乘上一个值。
后缀赋值可以通过记两个标记 $pos,num$ 表示 $f_{u,pos..+\infty}$ 都等于 $num$ 来解决;后缀乘直接做复杂度也不对,考虑替换为整体乘、前缀乘逆元,这样子复杂度就对了。这里可以打一个整体乘法标记解决。
最后还有一个 $+1$,可以再打一个加法标记解决。
综上,我们只需要对每个点打四个标记,即可只维护 $f_{u,1..len_u}$,来做到 $\mathcal{O(n)}$ 的复杂度。
优化 $g$
长链剖分
可以发现 $g_u$ 中只有 $g_{u,L-len_u..L}$ 是有用的,这些状态只有 $\mathcal{O}(len_u)$ 个,所以同样可以考虑长链剖分优化。
具体的,我们把父节点的 $g$ 继承给重儿子,轻儿子暴力转移。重儿子只需要把轻儿子合并进来即可,而计算轻儿子时主要的问题在于如何计算 $\prod_{v\in son_{fa_i},v\neq i}f_{v,j-2}$。
一个想法是将其写成 $\frac{f_{fa_i,j-1}-1}{f_{i,j-2}}$ 的形式,然而当 $f_{i,j-2}\equiv 0\pmod{p}$ 时会有问题。
另一个想法是求出 $f_{i,j-2}$ 的前缀积和后缀积。前缀积可以在枚举子节点时维护,但是后缀积就不是很好维护了。
可回退化
如果我们按照求 $f$ 时的逆序枚举子节点,然后每次将 $f$ 的版本回退一个,就可以求出后缀积了。
具体的,我们对每个节点开一个栈,存下每次修改前的值,撤销时就弹栈即可。
在求 $g$ 时同样有整体加、后缀乘、后缀赋值的需求,用之前优化 $f$ 时的那一套标记维护即可。
这样子求 $g$ 的复杂度也优化到了 $\mathcal{O}(n)$。
逆元
仔细思考一下上面的过程,你会发现在后缀乘时我们需要求逆元,这部分是带 $\log p$ 的。
但是仔细分析一下会发现,我们需要求逆元的一定只会是一些 $f_{u,len_u}$,而这个很好 $\mathcal{O}(n)$ 求出。
于是我们只需要预处理出 $f_{u,len_u}$,然后线性求逆元即可。在维护整体乘标记时,需要同时维护其逆元。
这样子我们就把除了算答案的快速幂外的所有部分优化成了 $\mathcal{O}(n)$,可以通过此题。
代码
码力不太够,基本是蒯的 zsy 的。
// ====================================
// author: M_sea
// website: https://m-sea-blog.com/
// ====================================
#include <bits/stdc++.h>
#define file(x) freopen(#x".in","r",stdin),freopen(#x".out","w",stdout)
#define debug(...) fprintf(stderr,__VA_ARGS__)
using namespace std;
typedef long long ll;
int read() {
int X=0,w=1; char c=getchar();
while (c<'0'||c>'9') { if (c=='-') w=-1; c=getchar(); }
while (c>='0'&&c<='9') X=X*10+c-'0',c=getchar();
return X*w;
}
const int N=1000000+10;
const int mod=998244353;
int qpow(int a,int b) { int c=1;
for (;b;b>>=1,a=1ll*a*a%mod) if (b&1) c=1ll*c*a%mod;
return c;
}
int n,L,k,ans=0;
vector<int> E[N];
namespace T {
int len[N],hson[N],w[N],pre[N],inv[N],sta[N],top=0;
void dfs(int u,int fa) {
w[u]=1;
for (int v:E[u]) {
if (v==fa) continue;
dfs(v,u),w[u]=1ll*w[u]*w[v]%mod;
if (len[v]>len[hson[u]]) hson[u]=v;
}
w[u]=(w[u]+1)%mod,len[u]=len[hson[u]]+1;
}
void main() {
dfs(1,0);
for (int i=1;i<=n;++i) if (w[i]) sta[++top]=i;
for (int i=pre[0]=1;i<=top;++i) pre[i]=1ll*pre[i-1]*w[sta[i]]%mod;
int ipre=qpow(pre[top],mod-2);
for (int i=top;i;--i) inv[sta[i]]=1ll*pre[i-1]*ipre%mod,ipre=1ll*ipre*w[sta[i]]%mod;
}
}
using T::hson;
using T::len;
struct alice {
int add,mul,inv,pos,num;
alice(int a=0,int b=0,int c=0,int d=0,int e=0): add(a),mul(b),inv(c),pos(d),num(e) {}
};
namespace F {
alice t[N<<1];
int o[N<<2],*pos=o,*f[N<<1];
vector<pair<alice,vector<pair<int,int>>>> s[N];
void modify(int u,int i,int w) { f[u][i]=1ll*(w-t[u].add+mod)*t[u].inv%mod; }
int query(int u,int i) { return (1ll*(i<t[u].pos?f[u][i]:t[u].num)*t[u].mul+t[u].add)%mod; }
void merge(int u,int v,int l) {
alice ou=t[u]; vector<pair<int,int>> vc;
for (int i=1;i<=l;++i) {
vc.emplace_back(i,f[u][i]);
if (i==t[u].pos) f[u][i]=t[u].num,++t[u].pos;
modify(u,i,1ll*query(u,i)*query(v,i-1)%mod);
}
if (l<L) {
int s=query(v,l);
if (!s) t[u].pos=l+1,t[u].num=mod-1ll*t[u].add*t[u].inv%mod;
else {
vc.emplace_back(0,f[u][0]); int inv=T::inv[v];
for (int i=0;i<=l;++i) modify(u,i,1ll*query(u,i)*inv%mod);
t[u].add=1ll*t[u].add*s%mod,t[u].mul=1ll*t[u].mul*s%mod,t[u].inv=1ll*t[u].inv*inv%mod;
}
}
if (u<=n) s[u].emplace_back(ou,vc);
}
void dfs(int u,int fa) {
if (hson[u]) f[hson[u]]=f[u]+1,dfs(hson[u],u),t[u]=t[hson[u]];
else t[u]=alice(1,1,1,n,0);
modify(u,0,1);
for (int v:E[u]) {
if (v==fa||v==hson[u]) continue;
f[v]=pos,pos+=len[v]; dfs(v,u);
merge(u,v,min(len[v]-1,L));
}
++t[u].add;
}
void rollback(int u) {
t[u]=s[u].back().first;
for (auto i:s[u].back().second) f[u][i.first]=i.second;
s[u].pop_back();
}
void main() { f[1]=pos,pos+=len[1]; dfs(1,0); }
}
namespace G {
alice t[N];
int o[N<<1],*pos=o,*g[N];
void modify(int u,int i,int w) { g[u][i]=1ll*(w-t[u].add+mod)*t[u].inv%mod; }
int query(int u,int i) { return (1ll*(i<t[u].pos?g[u][i]:t[u].num)*t[u].mul+t[u].add)%mod; }
void dfs(int u,int fa) {
if (len[u]-L>=1) modify(u,len[u]-L-1,1);
ans=(ans+qpow(1ll*(F::query(u,min(len[u]-1,L))-1+mod)*query(u,len[u]-1)%mod,k))%mod;
if (fa) ans=(ans-qpow(1ll*(F::query(u,min(len[u]-1,L-1))-1+mod)*(query(u,len[u]-1)-1+mod)%mod,k)+mod)%mod;
if (!hson[u]) return;
int mxl=0;
for (int v:E[u]) if (v!=fa&&v!=hson[u]) mxl=max(mxl,len[v]);
mxl=min(mxl,L);
F::f[u+n]=F::pos,F::pos+=mxl+1,F::t[u+n]=alice(1,1,1,n,0); F::modify(u+n,0,1);
reverse(E[u].begin(),E[u].end());
for (int v:E[u]) {
if (v==fa||v==hson[u]) continue;
g[v]=pos,pos+=len[v]; F::rollback(u);
for (int i=max(len[v]-L-1,0);i<len[v];++i) {
if (i==len[v]-L-1) g[v][i]=query(u,len[u]-1-len[v]+i);
else g[v][i]=1ll*query(u,len[u]-1-len[v]+i)*F::query(u,min(len[u]-1,L-len[v]+i))%mod*F::query(u+n,min(mxl,L-len[v]+i))%mod;
}
F::merge(u+n,v,min(len[v]-1,L));
t[v]=alice(1,1,1,n,0),dfs(v,u);
}
int hs=hson[u]; g[hs]=g[u],t[hs]=t[u];
for (int i=max(len[hs]-L,0);i<len[hs]+mxl-L;++i) {
if (i==t[hs].pos) g[hs][i]=t[hs].num,++t[hs].pos;
modify(hs,i,1ll*query(hs,i)*F::query(u+n,L-len[hs]+i)%mod);
}
if (mxl<L) {
int s=1,inv=1;
for (int v:E[u])
if (v!=fa&&v!=hs) s=1ll*s*T::w[v]%mod,inv=1ll*inv*T::inv[v]%mod;
if (!s) t[hs].pos=len[hs]+mxl-L,t[hs].num=mod-1ll*t[hs].add*t[hs].inv%mod;
else {
for (int i=max(len[hs]-L,0);i<len[hs]+mxl-L;++i) modify(hs,i,1ll*query(hs,i)*inv%mod);
t[hs].add=1ll*t[hs].add*s%mod,t[hs].mul=1ll*t[hs].mul*s%mod,t[hs].inv=1ll*t[hs].inv*inv%mod;
}
}
++t[hs].add,dfs(hs,u);
}
void main() { g[1]=pos,pos+=len[1],t[1]=alice(1,1,1,n,0); dfs(1,0); }
}
int main() {
n=read(),L=read(),k=read();
for (int i=1;i<n;++i) {
int u=read(),v=read();
E[u].emplace_back(v),E[v].emplace_back(u);
}
T::main(),F::main(),G::main();
printf("%d\n",ans);
return 0;
}