LOJ

分析

容易想到容斥,计算至少包含 $i$ 个魔术对的序列数 $g_i$。根据二项式反演不难得到
$$
ans_k=\sum_{i=k}^{n-1}(-1)^{i-k}{i\choose k}g_i
$$
发现本质不同这个条件不太好做。可以忽略掉这个条件即给每张卡分配一个标号,只需要最后把答案除掉 $\prod a_i!$ 即可。

考虑一个 DP,即设 $dp_{i,j}$ 表示前 $i$ 类卡、至少有 $j$ 个魔术对的方案数。设 $f_{i,j}$ 表示第 $i$ 类卡至少有 $j$ 个魔术对的方案书,则转移就是 $dp_{i-1}$ 和 $f_i$ 卷积。

考虑如何求 $f_{i,j}$。可以先选出 $a_i-j$ 张卡,然后将剩下的 $j$ 张卡插入到某张卡的右边,这样子一定产生至少 $j$ 个魔术对。而第一张卡有 $a_i-j$ 种选法,第二张卡有 $a_i-j+1$ 种,……第 $j$ 张卡有 $a_i-1$ 种。从而不难得到
$$
f_{i,j}={a_i\choose j}\frac{(a_i-1)!}{(a_i-j-1)!}
$$
分治求出所有 $f_i$ 的卷积即可得到 $dp_m$。

然而 $dp_{m,i}$ 并不等于 $g_i$,因为我们还可以任意排列不构成魔术对的 $n-i$ 张牌,所以还需要乘上一个 $(n-i)!$。

最后二项式反演一下即可得到答案。记得把答案除掉 $\prod a_i!$。

代码

// ====================================
//   author: M_sea
//   website: https://m-sea-blog.com/
// ====================================
#include <bits/stdc++.h>
#define file(x) freopen(#x".in","r",stdin); freopen(#x".out","w",stdout)
using namespace std;
typedef long long ll;

int read() {
    int X=0,w=1; char c=getchar();
    while (c<'0'||c>'9') { if (c=='-') w=-1; c=getchar(); }
    while (c>='0'&&c<='9') X=X*10+c-'0',c=getchar();
    return X*w;
}

const int N=400000+10;
const int mod=998244353;
int qpow(int a,int b) { int c=1;
    for (;b;b>>=1,a=1ll*a*a%mod) if (b&1) c=1ll*c*a%mod;
    return c;
}

int m,n,k,a[N];

int fac[N],ifac[N];
int C(int n,int m) {
    return 1ll*fac[n]*ifac[m]%mod*ifac[n-m]%mod;
}

vector<int> F[N],G;

int r[N];
void NTT(vector<int>& A,int n,int op) {
    for (int i=0;i<n;++i) if (i<r[i]) swap(A[i],A[r[i]]);
    for (int i=1;i<n;i<<=1) {
        int rot=qpow(op==1?3:(mod+1)/3,(mod-1)/(i<<1));
        for (int j=0;j<n;j+=i<<1)
            for (int k=0,w=1;k<i;++k,w=1ll*w*rot%mod) {
                int x=A[j+k],y=1ll*w*A[j+k+i]%mod;
                A[j+k]=(x+y)%mod,A[j+k+i]=(x-y+mod)%mod;
            }
    }
    if (op==-1) { int inv=qpow(n,mod-2);
        for (int i=0;i<n;++i) A[i]=1ll*A[i]*inv%mod;
    }
}

vector<int> solve(int L,int R) {
    if (L==R) return F[L];
    int mid=(L+R)>>1;
    vector<int> A=solve(L,mid),B=solve(mid+1,R);
    int lim=1,l=-1;
    for (;lim<=A.size()+B.size();lim<<=1,++l);
    for (int i=0;i<lim;++i) r[i]=(r[i>>1]>>1)|((i&1)<<l);
    A.resize(lim),B.resize(lim);
    NTT(A,lim,1),NTT(B,lim,1);
    for (int i=0;i<lim;++i) A[i]=1ll*A[i]*B[i]%mod;
    NTT(A,lim,-1);
    while (A.size()&&!A.back()) A.pop_back();
    return A;
}

int main() {
    m=read(),n=read(),k=read();
    for (int i=1;i<=m;++i) a[i]=read();
    fac[0]=1;
    for (int i=1;i<=n;++i) fac[i]=1ll*fac[i-1]*i%mod;
    ifac[n]=qpow(fac[n],mod-2);
    for (int i=n;i;--i) ifac[i-1]=1ll*ifac[i]*i%mod;
    for (int i=1;i<=m;++i) {
        F[i].resize(a[i]);
        for (int j=0;j<a[i];++j)
            F[i][j]=1ll*C(a[i],j)*fac[a[i]-1]%mod*ifac[a[i]-j-1]%mod;
    }
    G=solve(1,m);
    for (int i=0;i<G.size();++i) G[i]=1ll*G[i]*fac[n-i]%mod;
    int ans=0;
    for (int i=k;i<G.size();++i) {
        if ((i-k)&1) ans=(ans-1ll*C(i,k)*G[i]%mod+mod)%mod;
        else ans=(ans+1ll*C(i,k)*G[i])%mod;
    }
    for (int i=1;i<=m;++i) ans=1ll*ans*ifac[a[i]]%mod;
    printf("%d\n",ans);
    return 0;
}
最后修改:2020 年 06 月 02 日 09 : 40 PM