Luogu

LOJ

分析

考虑容斥,问题变为计算至少有 $i$ 组人讨论 cxk 的方案数。

首先我们需要选出 $i$ 个位置 $p_{1..i}$ 满足 $[p_j,p_j+3]$ 不交,容易算出方案数为 ${n-3i\choose i}$。

对于剩下的位置,枚举每种学生的个数,方案数是一个多重集排列数,式子是
$$
\sum_{\begin{matrix}x+y+z+w=n-4i,\\x\leq a-i,y\leq b-i,z\leq c-i,w\leq d-i\end{matrix}}\frac{(n-4i)!}{x!y!z!w!}
$$
可以看成四个生成函数卷起来,用 NTT 可以 $\mathcal{O}(n\log n)$ 计算。总时间复杂度 $\mathcal{O}(n^2\log n)$。

代码

// ====================================
//   author: M_sea
//   website: https://m-sea-blog.com/
// ====================================
#include <bits/stdc++.h>
#define file(x) freopen(#x".in","r",stdin),freopen(#x".out","w",stdout)
#define debug(...) fprintf(stderr,__VA_ARGS__)
using namespace std;
typedef long long ll;

int read() {
    int X=0,w=1; char c=getchar();
    while (c<'0'||c>'9') { if (c=='-') w=-1; c=getchar(); }
    while (c>='0'&&c<='9') X=X*10+c-'0',c=getchar();
    return X*w;
}

const int N=4000+10;
const int mod=998244353;
int qpow(int a,int b) { int c=1;
    for (;b;b>>=1,a=1ll*a*a%mod) if (b&1) c=1ll*c*a%mod;
    return c;
}

int n,a,b,c,d;
int fac[N],ifac[N];

void init(int n) {
    fac[0]=1;
    for (int i=1;i<=n;++i) fac[i]=1ll*fac[i-1]*i%mod;
    ifac[n]=qpow(fac[n],mod-2);
    for (int i=n;i;--i) ifac[i-1]=1ll*ifac[i]*i%mod;
}
int C(int n,int m) {
    if (n<m) return 0;
    return 1ll*fac[n]*ifac[m]%mod*ifac[n-m]%mod;
}

int r[N];
void NTT(int *A,int n,int op) {
    for (int i=0;i<n;++i) if (i<r[i]) swap(A[i],A[r[i]]);
    for (int i=1;i<n;i<<=1) {
        int rot=qpow(op==1?3:332748118,(mod-1)/(i<<1));
        for (int j=0;j<n;j+=i<<1)
            for (int k=0,w=1;k<i;++k,w=1ll*w*rot%mod) {
                int x=A[j+k],y=1ll*w*A[j+k+i]%mod;
                A[j+k]=(x+y)%mod,A[j+k+i]=(x-y+mod)%mod;
            }
    }
    if (op==-1) {
        int inv=qpow(n,mod-2);
        for (int i=0;i<n;++i) A[i]=1ll*A[i]*inv%mod;
    }
}

int calc(int n,int a,int b,int c,int d) {
    static int A[N],B[N],C[N],D[N];
    int lim=1,l=0;
    for (;lim<=a+b+c+d;lim<<=1,++l);
    for (int i=0;i<lim;++i) r[i]=(r[i>>1]>>1)|((i&1)<<(l-1));
    for (int i=0;i<lim;++i) {
        A[i]=i<=a?ifac[i]:0,B[i]=i<=b?ifac[i]:0;
        C[i]=i<=c?ifac[i]:0,D[i]=i<=d?ifac[i]:0;
    }
    NTT(A,lim,1),NTT(B,lim,1),NTT(C,lim,1),NTT(D,lim,1);
    for (int i=0;i<lim;++i) A[i]=1ll*A[i]*B[i]%mod*C[i]%mod*D[i]%mod;
    NTT(A,lim,-1);
    return A[n];
}

int main() {
    n=read(),a=read(),b=read(),c=read(),d=read();
    init(max({n,a,b,c,d}));
    int L=min({n/4,a,b,c,d}),ans=0;
    for (int i=0;i<=L;++i) {
        int w=1ll*C(n-3*i,i)*fac[n-4*i]%mod*calc(n-4*i,a-i,b-i,c-i,d-i)%mod;
        if (i&1) ans=(ans-w+mod)%mod;
        else ans=(ans+w)%mod;
    }
    printf("%d\n",ans);
    return 0;
}
最后修改:2021 年 01 月 24 日 05 : 28 PM