分析
完全图的生成树个数为 $n^{n-2}$ ,每条边会出现 $2n^{n-3}$ 次。
容易得到答案为
$$
\frac{2n^{n-3}}{n^{n-2}}\sum_{i=1}^{n-1}\sum_{j=i+1}^n(i+j)^k
$$
前面的等于 $\frac{2}{n}$ ,考虑怎么计算 $\sum_{i=1}^{n-1}\sum_{j=i+1}^n(i+j)^k$ 。
设
$$
a_n=\sum_{i=1}^{n-1}\sum_{j=i+1}^n(i+j)^k
$$
可以发现 $a_n$ 是一个关于 $n$ 的 $k+2$ 次多项式,我们只考虑求前 $k+3$ 项,后面的可以拉格朗日插值求解。
差分一下可以得到
$$
a_n-a_{n-1}=\sum_{i=n+1}^{2n-1}i^k
$$
考虑怎么计算 $\sum_{i=n+1}^{2n-1}i^k$ 。可以发现 $i^k$ 是完全积性函数,直接线性筛后求前缀和即可。
整理一下思路:
- 首先线性筛求出 $i^k$ ,然后求出它的前缀和。此部分时间复杂度为 $O(k)$。
- 然后利用算出的前缀和,递推计算出 $a_n$ 的前 $k+3$ 项。此部分时间复杂度也为 $O(k)$。
- 最后使用拉格朗日插值求出 $a_n$ ,然后计算答案即可。此部分时间复杂度仍为 $O(k)$。
总时间复杂度 $O(k)$。
代码
// ===================================
// author: M_sea
// website: http://m-sea-blog.com/
// ===================================
#include <algorithm>
#include <iostream>
#include <cstdlib>
#include <cstring>
#include <cstdio>
#include <cmath>
#define re register
using namespace std;
inline int read() {
int X=0,w=1; char c=getchar();
while (c<'0'||c>'9') { if (c=='-') w=-1; c=getchar(); }
while (c>='0'&&c<='9') X=X*10+c-'0',c=getchar();
return X*w;
}
const int N=20000006+10;
const int mod=998244353;
inline int qpow(int a,int b) {
int c=1;
for (;b;b>>=1,a=1ll*a*a%mod)
if (b&1) c=1ll*c*a%mod;
return c;
}
int n,k;
int a[N];
int inv[N];
int p[N],v[N],s[N],cnt=0;
int pre[N],suf[N];
inline void init(int n) {
inv[0]=inv[1]=1;
for (re int i=2;i<=n;++i) inv[i]=1ll*inv[mod%i]*(mod-mod/i)%mod;
}
inline void sieve(int n) {
s[1]=1;
for (re int i=2;i<=n;++i) {
if (!v[i]) p[++cnt]=i,s[i]=qpow(i,k);
for (re int j=1;j<=cnt&&i*p[j]<=n;++j) {
v[i*p[j]]=1;
s[i*p[j]]=1ll*s[i]*s[p[j]]%mod;
if (i%p[j]==0) break;
}
}
for (re int i=1;i<=n;++i) s[i]=(s[i]+s[i-1])%mod;
}
inline int calc(int d) {
if (d<=k+3) return a[d];
pre[0]=suf[k+3]=1; int res=0;
for (re int i=1;i<=k+3;++i)
pre[i]=1ll*pre[i-1]*(d-i+1+mod)%mod*inv[i]%mod;
for (re int i=k+2;~i;--i)
suf[i]=1ll*suf[i+1]*(d-i-1+mod)%mod*(mod-inv[k+3-i])%mod;
for (re int i=0;i<=k+3;++i)
res=(res+1ll*pre[i]*suf[i]%mod*a[i])%mod;
return res;
}
int main() {
n=read(),k=read();
init((k+3)<<1),sieve((k+3)<<1);
for (re int i=2;i<=k+3;++i)
a[i]=(1ll*a[i-1]+s[(i<<1)-1]-s[i]+mod)%mod;
printf("%d\n",2ll*calc(n)*qpow(n,mod-2)%mod);
return 0;
}