球与盒子问题大杂烩

下面我们对这 $12$ 个问题简略分析。

I

球之间互不相同,盒子之间互不相同。

每个球都可以任意选择一个盒子,答案为 $m^n$。

II

球之间互不相同,盒子之间互不相同,每个盒子至多装一个球。

每个球选择一个盒子,已经被选择过的盒子不能再被选择,答案为 $m^{\underline{n}}$。

III

球之间互不相同,盒子之间互不相同,每个盒子至少装一个球。

考虑容斥,枚举空盒子的个数,不难得到答案为
$$
\sum _ {i = 0} ^ m (-1) ^ i {m \choose i} (m - i) ^ n
$$

IV

球之间互不相同,盒子全部相同。

枚举非空盒子的个数,不难得到答案为
$$
\sum_{i = 0} ^ m \begin{Bmatrix} n \\ i \end{Bmatrix}
$$
使用 第二类斯特林数·行 的方法计算即可。具体可以看这里

V

球之间互不相同,盒子全部相同,每个盒子至多装一个球。

答案为 $[n \leq m]$。

VI

球之间互不相同,盒子全部相同,每个盒子至少装一个球。

答案为 $\begin{Bmatrix} n \\ m \end{Bmatrix}$。

VII

球全部相同,盒子之间互不相同。

使用插板法计算,答案为 ${n + m - 1 \choose m - 1}$。

VIII

球全部相同,盒子之间互不相同,每个盒子至多装一个球。

相当于选 $n$ 个盒子放球,答案为 ${m \choose n}$。

IX

球全部相同,盒子之间互不相同,每个盒子至少装一个球。

使用插板法计算,答案为 ${n - 1 \choose m - 1}$。

X

球全部相同,盒子全部相同。

定义划分数 $p_{n, m}$ 为 $n$ 划分为 $m$ 个自然数的方案数。有一个很经典的转移
$$
p_{n, m} = p_{n, m - 1} + p_{n - m, m}
$$
意思是:加一个 $0$,或者把所有数 $+1$。

设生成函数 $F_m(x) = \sum_{i \geq 0} p_{i, m} x ^ i$,那么有
$$
F_m(x) = F_{m-1}(x) (1 + x ^ m + x ^ {2m} + \cdots) = \frac{F_{m - 1}(x)}{1 - x^m}
$$
从而
$$
F_m(x) = \prod_{i = 1} ^ m \frac{1}{1 - x^i}
$$
考虑取对数。通过一些推导可以知道
$$
\ln \frac{1}{1 - x^i} = \sum_{j\geq 1} \frac{x^{ij}}{j}
$$
这样子就可以 $\mathcal{O}(m\ln m)$ 求出 $\ln F_m(x)$,再多项式 exp 回去即可。

XI

球全部相同,盒子全部相同,每个盒子至多装一个球。

答案为 $[n \leq m]$。

XII

球全部相同,盒子全部相同,每个盒子至少装一个球。

相当于先往每个盒子里装一个球,答案为 $p_{n - m, m}$。

代码

// ====================================
//   author: M_sea
//   website: https://m-sea-blog.com/
// ====================================
#include <bits/stdc++.h>
#define file(x) freopen(#x".in", "r", stdin), freopen(#x".out", "w", stdout)
#define debug(...) fprintf(stderr, __VA_ARGS__)
using namespace std;
typedef long long ll;

int read() {
    int X = 0, w = 1; char c = getchar();
    while (c < '0' || c > '9') { if (c == '-') w = -1; c = getchar(); }
    while (c >= '0' && c <= '9') X = X * 10 + c - '0', c = getchar();
    return X * w;
}

const int N = 524288 + 10;
const int mod = 998244353;
int qpow(int a, int b) {
    int c = 1;
    for (; b; b >>= 1, a = 1ll * a * a % mod)
        if (b & 1) c = 1ll * c * a % mod;
    return c;
}

int n, m;
int fac[N], ifac[N], S[N], P[N];

int r[N];
void NTT(int *A, int n, int op) {
    for (int i = 0; i < n; ++i)
        if (i < r[i]) swap(A[i], A[r[i]]);
    for (int i = 1; i < n; i <<= 1) {
        int rot = qpow(op == 1 ? 3 : 332748118, (mod - 1) / (i << 1));
        for (int j = 0; j < n; j += i << 1)
            for (int k = 0, w = 1; k < i; ++k, w = 1ll * w * rot % mod) {
                int x = A[j + k], y = 1ll * A[j + k + i] * w % mod;
                A[j + k] = (x + y) % mod, A[j + k + i] = (x - y + mod) % mod;
            }
    }
    if (op == -1) {
        int inv = qpow(n, mod - 2);
        for (int i = 0; i < n; ++i) A[i] = 1ll * A[i] * inv % mod;
    }
}
int NTT_init(int n) {
    int lim = 1, l = 0;
    for (; lim < n; lim <<= 1, ++l);
    for (int i = 0; i < lim; ++i) r[i] = (r[i >> 1] >> 1) | ((i & 1) << (l - 1));
    return lim;
}
void PolyInv(int *F, int *G, int n) {
    static int A[N], B[N];
    if (n == 1) { G[0] = qpow(F[0], mod - 2); return; }
    PolyInv(F, G, n >> 1);
    for (int i = 0; i < n; ++i) A[i] = F[i], B[i] = G[i];
    int lim = NTT_init(n << 1);
    NTT(A, lim, 1), NTT(B, lim, 1);
    for (int i = 0; i < lim; ++i) A[i] = 1ll * A[i] * B[i] % mod * B[i] % mod;
    NTT(A, lim, -1);
    for (int i = 0; i < n; ++i) G[i] = (2ll * G[i] - A[i] + mod) % mod;
    for (int i = 0; i < lim; ++i) A[i] = B[i] = 0;
}
void PolyDeri(int *F, int *G, int n) {
    for (int i = 1; i < n; ++i) G[i - 1] = 1ll * F[i] * i %mod;
    G[n - 1] = 0;
}
void PolyInte(int *F, int *G, int n) {
    for (int i = 1; i < n; ++i) G[i] = 1ll * F[i-1] * qpow(i, mod - 2) % mod;
    G[0] = 0;
}
void PolyLn(int *F, int *G, int n) {
    static int A[N], B[N];
    PolyDeri(F, A, n), PolyInv(F, B, n);
    int lim = NTT_init(n << 1);
    NTT(A, lim, 1), NTT(B, lim, 1);
    for (int i = 0; i < lim; ++i) A[i] = 1ll * A[i] * B[i] % mod;
    NTT(A, lim, -1);
    PolyInte(A, G, n);
    for (int i = 0; i < lim; ++i) A[i] = B[i] = 0;
}
void PolyExp(int *F, int *G, int n) {
    static int A[N], B[N];
    if (n == 1) { G[0] = 1; return; }
    PolyExp(F, G, n >> 1);
    for (int i = 0; i < n; ++i) A[i] = G[i];
    PolyLn(G, B, n);
    for (int i = 0; i < n; ++i) B[i] = (mod - B[i] + F[i]) % mod;
    B[0] = (B[0] + 1) % mod;
    int lim = NTT_init(n << 1);
    NTT(A, lim, 1), NTT(B, lim, 1);
    for (int i = 0; i < lim; ++i) A[i] = 1ll * A[i] * B[i] % mod;
    NTT(A, lim, -1);
    for (int i = 0; i < n; ++i) G[i] = A[i];
    for (int i = 0; i < lim; ++i) A[i] = B[i] = 0;
}

void init() {
    static int F[N], G[N]; int lim;

    fac[0] = 1;
    for (int i = 1; i <= n + m; ++i) fac[i] = 1ll * fac[i - 1] * i % mod;
    ifac[n + m] = qpow(fac[n + m], mod - 2);
    for (int i = n + m; i; --i) ifac[i - 1] = 1ll * ifac[i] * i % mod;

    for (int i = 0; i <= n; ++i) {
        F[i] = i & 1 ? mod - ifac[i] : ifac[i];
        G[i] = 1ll * qpow(i, n) * ifac[i] % mod;
    }
    lim = NTT_init(n << 1 | 1);
    NTT(F, lim, 1), NTT(G, lim, 1);
    for (int i = 0; i < lim; ++i) F[i] = 1ll * F[i] * G[i] % mod;
    NTT(F, lim, -1);
    for (int i = 0; i <= n; ++i) S[i] = F[i];
    for (int i = 0; i < lim; ++i) F[i] = G[i] = 0;

    lim = 1;
    for (; lim <= n; lim <<= 1);
    for (int i = 1; i <= m; ++i)
        for (int j = i; j <= n; j += i)
            F[j] = (F[j] + qpow(j / i, mod - 2)) % mod;
    PolyExp(F, G, lim);
    for (int i = 0; i <= n; ++i) P[i] = G[i];
}

int C(int n, int m) {
    if (n < m) return 0;
    return 1ll * fac[n] * ifac[m] % mod * ifac[n - m] % mod;
}

int I() { return qpow(m, n); }
int II() { return n > m ? 0 : 1ll * fac[m] * ifac[m - n] % mod; }
int III() {
    int res = 0;
    for (int i = 0; i <= m; ++i) {
        int w = 1ll * C(m, i) * qpow(m - i, n) % mod;
        if (i & 1) res = (res - w + mod) % mod;
        else res = (res + w) % mod;
    }
    return res;
}
int IV() {
    int res = 0;
    for (int i = 1; i <= m; ++i) res = (res + S[i]) % mod;
    return res;
}
int V() { return n <= m; }
int VI() { return S[m]; }
int VII() { return C(n + m - 1, m - 1); }
int VIII() { return C(m, n); }
int IX() { return C(n - 1, m - 1); }
int X() { return P[n]; }
int XI() { return n <= m; }
int XII() { return n < m ? 0 : P[n - m]; }

int main() {
    n = read(), m = read(); init();
    printf("%d\n%d\n%d\n", I(), II(), III());
    printf("%d\n%d\n%d\n", IV(), V(), VI());
    printf("%d\n%d\n%d\n", VII(), VIII(), IX());
    printf("%d\n%d\n%d\n", X(), XI(), XII());
    return 0;
}
最后修改:2021 年 04 月 07 日 03 : 12 PM