Luogu

分析

答案显然是
$$
\frac{\sum_{i=1}^n\sum_{j=1}^m(a_i+b_j)^k}{nm}
$$
我们不管分母,只看分子。用二项式定理拆开得到
$$
\sum_{i=1}^n\sum_{j=1}^m\sum_{r=0}^k{k\choose r}a_i\!^rb_j\!^{k-r}
$$
稍微换一下形式得到
$$
\sum_{r=0}^k{k\choose r}\left(\sum_{i=1}^na_i\!^r\right)\left(\sum_{i=1}^mb_i\!^{k-r}\right)
$$
如果我们能够对每个 $r$ 快速求出 $\sum_{i=1}^n a_i\!^r$,那么就可以直接卷了。

这是一个经典问题(鸡贼说叫“等幂和”)。我们相当于要求生成函数
$$
A(x)=\sum_{r\geq 0}x^r\sum_{i=1}^n a_i\!^r
$$
的前若干项系数。

我们把 $A(x)$ 换一个形式写出来
$$
A(x)=\sum_{i=1}^n\frac{1}{1-a_ix}
$$
根据这类问题的方法,设
$$
G(x)=\sum_{i=1}^n\frac{-a_i}{1-a_ix}
$$
那么有
$$
A(x)=n-xG(x)
$$
考虑怎么算 $G(x)$。注意到
$$
\ln(1-a_ix)'=\frac{-a_i}{1-a_ix}
$$
于是可以想到把 $G(x)$ 化成
$$
G(x)=\sum_{i=1}^n\ln(1-a_ix)'
$$
把 $\sum$ 丢到 $\ln$ 里去得到
$$
G(x)=\ln\left(\prod_{i=1}^n(1-a_ix)\right)'
$$
直接分治把所有 $1-a_ix$ 卷起来,然后多项式 $\ln$ + 多项式求导即可求出 $G(x)$,再推回 $A(x)$ 后卷积即可。

代码

不想写 vector,于是从鱼那里蒯了一个分治 /kel

// ====================================
//   author: M_sea
//   website: https://m-sea-blog.com/
// ====================================
#include <bits/stdc++.h>
#define file(x) freopen(#x".in","r",stdin); freopen(#x".out","w",stdout)
#define debug(...) fprintf(stderr,__VA_ARGS__)
using namespace std;
typedef long long ll;

int read() {
    int X=0,w=1; char c=getchar();
    while (c<'0'||c>'9') { if (c=='-') w=-1; c=getchar(); }
    while (c>='0'&&c<='9') X=X*10+c-'0',c=getchar();
    return X*w;
}

const int N=400000+10;
const int mod=998244353;
int qpow(int a,int b) { int c=1;
    for (;b;b>>=1,a=1ll*a*a%mod) if (b&1) c=1ll*c*a%mod;
    return c;
}

int r[N];
void NTT(int *A,int n,int op) {
    for (int i=0;i<n;++i) if (i<r[i]) swap(A[i],A[r[i]]);
    for (int i=1;i<n;i<<=1) {
        int rot=qpow(op==1?3:332748118,(mod-1)/(i<<1));
        for (int j=0;j<n;j+=i<<1)
            for (int k=0,w=1;k<i;++k,w=1ll*w*rot%mod) {
                int x=A[j+k],y=1ll*w*A[j+k+i]%mod;
                A[j+k]=(x+y)%mod,A[j+k+i]=(x-y+mod)%mod;
            }
    }
    if (op==-1) { int inv=qpow(n,mod-2);
        for (int i=0;i<n;++i) A[i]=1ll*A[i]*inv%mod;
    }
}
int NTT_init(int n) {
    int lim=1,l=-1;
    for (;lim<n;lim<<=1,++l);
    for (int i=0;i<lim;++i) r[i]=(r[i>>1]>>1)|((i&1)<<l);
    return lim;
}

void polyinv(int *F,int *G,int n) {
    static int A[N],B[N];
    if (n==1) { G[0]=qpow(F[0],mod-2); return; }
    polyinv(F,G,n>>1);
    for (int i=0;i<n;++i) A[i]=F[i],B[i]=G[i];
    int lim=NTT_init(n<<1);
    NTT(A,lim,1),NTT(B,lim,1);
    for (int i=0;i<lim;++i) A[i]=1ll*A[i]*B[i]%mod*B[i]%mod;
    NTT(A,lim,-1);
    for (int i=0;i<n;++i) G[i]=(2ll*G[i]-A[i]+mod)%mod;
    for (int i=0;i<lim;++i) A[i]=B[i]=0;
}

void polyderi(int *F,int *G,int n) {
    for (int i=1;i<n;++i) G[i-1]=1ll*F[i]*i%mod;
    G[n-1]=0;
}
void polyinte(int *F,int *G,int n) {
    for (int i=1;i<n;++i) G[i]=1ll*F[i-1]*qpow(i,mod-2)%mod;
    G[0]=0;
}

void polyln(int *F,int *G,int n) {
    static int A[N],B[N];
    polyderi(F,A,n),polyinv(F,B,n);
    int lim=NTT_init(n<<1);
    NTT(A,lim,1),NTT(B,lim,1);
    for (int i=0;i<lim;++i) A[i]=1ll*A[i]*B[i]%mod;
    NTT(A,lim,-1);
    polyinte(A,G,n);
    for (int i=0;i<lim;++i) A[i]=B[i]=0;
}

#define ls (o<<1)
#define rs (o<<1|1)
int len[N];
void divide(int o,int l,int r,int *F,int *G) {
    len[o]=r-l+1;
    if (l==r) { G[0]=1,G[1]=mod-F[l]; return; }
    int L[N],R[N]; int mid=(l+r)>>1;
    divide(ls,l,mid,F,L),divide(rs,mid+1,r,F,R);
    int lim=NTT_init(len[o]+1);
    for (int i=len[ls]+1;i<lim;++i) L[i]=0;
    for (int i=len[rs]+1;i<lim;++i) R[i]=0;
    NTT(L,lim,1),NTT(R,lim,1);
    for (int i=0;i<lim;++i) G[i]=1ll*L[i]*R[i]%mod;
    NTT(G,lim,-1);
    for (int i=len[o]+1;i<lim;++i) G[i]=0;
}
#undef ls
#undef rs

int n,m,k,a[N],b[N],fac[N],ifac[N],F[N],G[N];

int main() {
    n=read(),m=read();
    for (int i=1;i<=n;++i) a[i]=read();
    for (int i=1;i<=m;++i) b[i]=read();
    divide(1,1,n,a,F),divide(1,1,m,b,G);
    k=read(); int l=max({n,m,k})+1;
    int lim=1; for (;lim<=l;lim<<=1);
    polyln(F,F,lim),polyderi(F,F,lim);
    polyln(G,G,lim),polyderi(G,G,lim);
    for (int i=lim-1;i;--i) F[i]=mod-F[i-1]; F[0]=n;
    for (int i=lim-1;i;--i) G[i]=mod-G[i-1]; G[0]=m;
    fac[0]=1;
    for(int i=1;i<lim;++i) fac[i]=1ll*fac[i-1]*i%mod;
    ifac[lim-1]=qpow(fac[lim-1],mod-2);
    for (int i=lim-1;i;--i) ifac[i-1]=1ll*ifac[i]*i%mod;
    for (int i=0;i<lim;++i) F[i]=1ll*F[i]*ifac[i]%mod;
    for (int i=0;i<lim;++i) G[i]=1ll*G[i]*ifac[i]%mod;
    lim=NTT_init(lim<<1);
    NTT(F,lim,1),NTT(G,lim,1);
    for (int i=0;i<lim;++i) F[i]=1ll*F[i]*G[i]%mod;
    NTT(F,lim,-1);
    int inv=qpow(1ll*n*m%mod,mod-2);
    for (int i=1;i<=k;++i) printf("%d\n",1ll*F[i]*fac[i]%mod*inv%mod);
    return 0;
}
最后修改:2020 年 10 月 26 日 09 : 14 AM