Luogu

BZOJ

LOJ

分析

考虑容斥,计算至少 $i$ 堆人讨论 cxk 的方案数。

这个方案数等于点 $i$ 堆人出来的方案数乘上剩下的人随便排的方案数。

前面那个东西很容易算,就是 $\dbinom{n-3i}i$ 。

考虑后面这个东西怎么算。

首先有可重集的排列公式: $\large\frac{n!}{a_1!a_2!...a_k!}$ 。

考虑枚举每种的人数,那么这一部分的方案数就是

$$\large\displaystyle\sum_{i=0}^{a-i}\sum_{j=0}^{b-i}\sum_{k=0}^{c-i}\sum_{l=0}^{d-i}[a+b+c+d=n-4i]\frac{(n-4i)!}{i!j!k!l!}$$

把 $(n-4i)!$ 提到前面来

$$\large\displaystyle(n-4i)!\sum_{i=0}^{a-i}\sum_{j=0}^{b-i}\sum_{k=0}^{c-i}\sum_{l=0}^{d-i}[a+b+c+d=n-4i]\frac{1}{i!}\frac{1}{j!}\frac{1}{k!}\frac{1}{l!}$$

后面显然是卷积的形式,可以 $\mathrm{NTT}$ 处理。

还有一种计算方式,就是把每种人的 $\mathrm{EGF}$ 乘起来,也可以得到这个结果。

这样就得到了至少 $i$ 堆人讨论 cxk 的方案数,假设这个东西是 $f_i$ 。

那么根据容斥可以得到 $\large ans=\sum\limits_{i=0}^n(-1)^if_i$ 。

时间复杂度 $O(n^2\log n)$ 。具体实现及细节见代码。

代码

// =================================
//   author: M_sea
//   website: http://m-sea-blog.com/
// =================================
#include <algorithm>
#include <iostream>
#include <cstdlib>
#include <cstring>
#include <cstdio>
#include <cmath>
#define re register
using namespace std;

inline int read() {
    int X=0,w=1; char c=getchar();
    while (c<'0'||c>'9') { if (c=='-') w=-1; c=getchar(); }
    while (c>='0'&&c<='9') X=X*10+c-'0',c=getchar();
    return X*w;
}

const int N=100000+10;
const int mod=998244353,g=3,gi=332748118;

inline void add(int& x,int y) { x=(x+y)%mod; }

inline int qpow(int a,int b) {
    int ans=1;
    for (;b;b>>=1,a=1ll*a*a%mod)
        if (b&1) ans=1ll*ans*a%mod;
    return ans;
}

int n,a,b,c,d;
int A[N],B[N],C[N],D[N],r[N];
int fac[N],ifac[N];

inline int C_(int n,int m) {
    if (n<m) return 0;
    return 1ll*fac[n]*ifac[m]%mod*ifac[n-m]%mod;
}

inline void NTT(int* A,int n,int op) {
    for (re int i=0;i<n;++i) if (i<r[i]) swap(A[i],A[r[i]]);
    for (re int i=1;i<n;i<<=1) {
        int rot=qpow(op==1?g:gi,(mod-1)/(i<<1));
        for (re int j=0;j<n;j+=(i<<1)) {
            int w=1;
            for (re int k=0;k<i;++k,w=1ll*w*rot%mod) {
                int x=A[j+k],y=1ll*w*A[j+k+i]%mod;
                A[j+k]=(x+y)%mod,A[j+k+i]=(x-y+mod)%mod;
            }
        }
    }
    if (op==-1) {
        int iv=qpow(n,mod-2);
        for (re int i=0;i<n;++i) A[i]=1ll*A[i]*iv%mod;
    }
}

inline int calc(int n,int a,int b,int c,int d) {
    if (n>a+b+c+d) return 0;
    int lim=1,l=0;
    for (;lim<a+b+c+d;lim<<=1,++l);
    for (re int i=0;i<lim;++i) r[i]=(r[i>>1]>>1)|((i&1)<<(l-1));
    for (re int i=0;i<lim;++i) A[i]=(i<=a)?ifac[i]:0;
    for (re int i=0;i<lim;++i) B[i]=(i<=b)?ifac[i]:0;
    for (re int i=0;i<lim;++i) C[i]=(i<=c)?ifac[i]:0;
    for (re int i=0;i<lim;++i) D[i]=(i<=d)?ifac[i]:0;
    NTT(A,lim,1),NTT(B,lim,1),NTT(C,lim,1),NTT(D,lim,1);
    for (re int i=0;i<lim;++i) A[i]=1ll*A[i]%mod*B[i]%mod*C[i]%mod*D[i]%mod;
    NTT(A,lim,-1);
    return 1ll*A[n]*fac[n]%mod;
}

int main() {
    n=read(),a=read(),b=read(),c=read(),d=read();
    int lim=min(n/4,min(min(a,b),min(c,d)));

    fac[0]=1;
    for (re int i=1;i<=n;++i) fac[i]=1ll*fac[i-1]*i%mod;
    ifac[n]=qpow(fac[n],mod-2);
    for (re int i=n;i;--i) ifac[i-1]=1ll*ifac[i]*i%mod;
    
    int ans=0;
    for (re int i=0;i<=lim;++i) {
        int now=1ll*C_(n-3*i,i)*calc(n-4*i,a-i,b-i,c-i,d-i)%mod;
        if (i&1) add(ans,mod-now); else add(ans,now);
    }
    printf("%d\n",ans);
    return 0;
}
最后修改:2020 年 07 月 25 日 01 : 27 PM