下面我们对这 $12$ 个问题简略分析。
I
球之间互不相同,盒子之间互不相同。
每个球都可以任意选择一个盒子,答案为 $m^n$。
II
球之间互不相同,盒子之间互不相同,每个盒子至多装一个球。
每个球选择一个盒子,已经被选择过的盒子不能再被选择,答案为 $m^{\underline{n}}$。
III
球之间互不相同,盒子之间互不相同,每个盒子至少装一个球。
考虑容斥,枚举空盒子的个数,不难得到答案为
$$
\sum _ {i = 0} ^ m (-1) ^ i {m \choose i} (m - i) ^ n
$$
IV
球之间互不相同,盒子全部相同。
枚举非空盒子的个数,不难得到答案为
$$
\sum_{i = 0} ^ m \begin{Bmatrix} n \\ i \end{Bmatrix}
$$
使用 第二类斯特林数·行 的方法计算即可。具体可以看这里。
V
球之间互不相同,盒子全部相同,每个盒子至多装一个球。
答案为 $[n \leq m]$。
VI
球之间互不相同,盒子全部相同,每个盒子至少装一个球。
答案为 $\begin{Bmatrix} n \\ m \end{Bmatrix}$。
VII
球全部相同,盒子之间互不相同。
使用插板法计算,答案为 ${n + m - 1 \choose m - 1}$。
VIII
球全部相同,盒子之间互不相同,每个盒子至多装一个球。
相当于选 $n$ 个盒子放球,答案为 ${m \choose n}$。
IX
球全部相同,盒子之间互不相同,每个盒子至少装一个球。
使用插板法计算,答案为 ${n - 1 \choose m - 1}$。
X
球全部相同,盒子全部相同。
定义划分数 $p_{n, m}$ 为 $n$ 划分为 $m$ 个自然数的方案数。有一个很经典的转移
$$
p_{n, m} = p_{n, m - 1} + p_{n - m, m}
$$
意思是:加一个 $0$,或者把所有数 $+1$。
设生成函数 $F_m(x) = \sum_{i \geq 0} p_{i, m} x ^ i$,那么有
$$
F_m(x) = F_{m-1}(x) (1 + x ^ m + x ^ {2m} + \cdots) = \frac{F_{m - 1}(x)}{1 - x^m}
$$
从而
$$
F_m(x) = \prod_{i = 1} ^ m \frac{1}{1 - x^i}
$$
考虑取对数。通过一些推导可以知道
$$
\ln \frac{1}{1 - x^i} = \sum_{j\geq 1} \frac{x^{ij}}{j}
$$
这样子就可以 $\mathcal{O}(m\ln m)$ 求出 $\ln F_m(x)$,再多项式 exp 回去即可。
XI
球全部相同,盒子全部相同,每个盒子至多装一个球。
答案为 $[n \leq m]$。
XII
球全部相同,盒子全部相同,每个盒子至少装一个球。
相当于先往每个盒子里装一个球,答案为 $p_{n - m, m}$。
代码
// ====================================
// author: M_sea
// website: https://m-sea-blog.com/
// ====================================
#include <bits/stdc++.h>
#define file(x) freopen(#x".in", "r", stdin), freopen(#x".out", "w", stdout)
#define debug(...) fprintf(stderr, __VA_ARGS__)
using namespace std;
typedef long long ll;
int read() {
int X = 0, w = 1; char c = getchar();
while (c < '0' || c > '9') { if (c == '-') w = -1; c = getchar(); }
while (c >= '0' && c <= '9') X = X * 10 + c - '0', c = getchar();
return X * w;
}
const int N = 524288 + 10;
const int mod = 998244353;
int qpow(int a, int b) {
int c = 1;
for (; b; b >>= 1, a = 1ll * a * a % mod)
if (b & 1) c = 1ll * c * a % mod;
return c;
}
int n, m;
int fac[N], ifac[N], S[N], P[N];
int r[N];
void NTT(int *A, int n, int op) {
for (int i = 0; i < n; ++i)
if (i < r[i]) swap(A[i], A[r[i]]);
for (int i = 1; i < n; i <<= 1) {
int rot = qpow(op == 1 ? 3 : 332748118, (mod - 1) / (i << 1));
for (int j = 0; j < n; j += i << 1)
for (int k = 0, w = 1; k < i; ++k, w = 1ll * w * rot % mod) {
int x = A[j + k], y = 1ll * A[j + k + i] * w % mod;
A[j + k] = (x + y) % mod, A[j + k + i] = (x - y + mod) % mod;
}
}
if (op == -1) {
int inv = qpow(n, mod - 2);
for (int i = 0; i < n; ++i) A[i] = 1ll * A[i] * inv % mod;
}
}
int NTT_init(int n) {
int lim = 1, l = 0;
for (; lim < n; lim <<= 1, ++l);
for (int i = 0; i < lim; ++i) r[i] = (r[i >> 1] >> 1) | ((i & 1) << (l - 1));
return lim;
}
void PolyInv(int *F, int *G, int n) {
static int A[N], B[N];
if (n == 1) { G[0] = qpow(F[0], mod - 2); return; }
PolyInv(F, G, n >> 1);
for (int i = 0; i < n; ++i) A[i] = F[i], B[i] = G[i];
int lim = NTT_init(n << 1);
NTT(A, lim, 1), NTT(B, lim, 1);
for (int i = 0; i < lim; ++i) A[i] = 1ll * A[i] * B[i] % mod * B[i] % mod;
NTT(A, lim, -1);
for (int i = 0; i < n; ++i) G[i] = (2ll * G[i] - A[i] + mod) % mod;
for (int i = 0; i < lim; ++i) A[i] = B[i] = 0;
}
void PolyDeri(int *F, int *G, int n) {
for (int i = 1; i < n; ++i) G[i - 1] = 1ll * F[i] * i %mod;
G[n - 1] = 0;
}
void PolyInte(int *F, int *G, int n) {
for (int i = 1; i < n; ++i) G[i] = 1ll * F[i-1] * qpow(i, mod - 2) % mod;
G[0] = 0;
}
void PolyLn(int *F, int *G, int n) {
static int A[N], B[N];
PolyDeri(F, A, n), PolyInv(F, B, n);
int lim = NTT_init(n << 1);
NTT(A, lim, 1), NTT(B, lim, 1);
for (int i = 0; i < lim; ++i) A[i] = 1ll * A[i] * B[i] % mod;
NTT(A, lim, -1);
PolyInte(A, G, n);
for (int i = 0; i < lim; ++i) A[i] = B[i] = 0;
}
void PolyExp(int *F, int *G, int n) {
static int A[N], B[N];
if (n == 1) { G[0] = 1; return; }
PolyExp(F, G, n >> 1);
for (int i = 0; i < n; ++i) A[i] = G[i];
PolyLn(G, B, n);
for (int i = 0; i < n; ++i) B[i] = (mod - B[i] + F[i]) % mod;
B[0] = (B[0] + 1) % mod;
int lim = NTT_init(n << 1);
NTT(A, lim, 1), NTT(B, lim, 1);
for (int i = 0; i < lim; ++i) A[i] = 1ll * A[i] * B[i] % mod;
NTT(A, lim, -1);
for (int i = 0; i < n; ++i) G[i] = A[i];
for (int i = 0; i < lim; ++i) A[i] = B[i] = 0;
}
void init() {
static int F[N], G[N]; int lim;
fac[0] = 1;
for (int i = 1; i <= n + m; ++i) fac[i] = 1ll * fac[i - 1] * i % mod;
ifac[n + m] = qpow(fac[n + m], mod - 2);
for (int i = n + m; i; --i) ifac[i - 1] = 1ll * ifac[i] * i % mod;
for (int i = 0; i <= n; ++i) {
F[i] = i & 1 ? mod - ifac[i] : ifac[i];
G[i] = 1ll * qpow(i, n) * ifac[i] % mod;
}
lim = NTT_init(n << 1 | 1);
NTT(F, lim, 1), NTT(G, lim, 1);
for (int i = 0; i < lim; ++i) F[i] = 1ll * F[i] * G[i] % mod;
NTT(F, lim, -1);
for (int i = 0; i <= n; ++i) S[i] = F[i];
for (int i = 0; i < lim; ++i) F[i] = G[i] = 0;
lim = 1;
for (; lim <= n; lim <<= 1);
for (int i = 1; i <= m; ++i)
for (int j = i; j <= n; j += i)
F[j] = (F[j] + qpow(j / i, mod - 2)) % mod;
PolyExp(F, G, lim);
for (int i = 0; i <= n; ++i) P[i] = G[i];
}
int C(int n, int m) {
if (n < m) return 0;
return 1ll * fac[n] * ifac[m] % mod * ifac[n - m] % mod;
}
int I() { return qpow(m, n); }
int II() { return n > m ? 0 : 1ll * fac[m] * ifac[m - n] % mod; }
int III() {
int res = 0;
for (int i = 0; i <= m; ++i) {
int w = 1ll * C(m, i) * qpow(m - i, n) % mod;
if (i & 1) res = (res - w + mod) % mod;
else res = (res + w) % mod;
}
return res;
}
int IV() {
int res = 0;
for (int i = 1; i <= m; ++i) res = (res + S[i]) % mod;
return res;
}
int V() { return n <= m; }
int VI() { return S[m]; }
int VII() { return C(n + m - 1, m - 1); }
int VIII() { return C(m, n); }
int IX() { return C(n - 1, m - 1); }
int X() { return P[n]; }
int XI() { return n <= m; }
int XII() { return n < m ? 0 : P[n - m]; }
int main() {
n = read(), m = read(); init();
printf("%d\n%d\n%d\n", I(), II(), III());
printf("%d\n%d\n%d\n", IV(), V(), VI());
printf("%d\n%d\n%d\n", VII(), VIII(), IX());
printf("%d\n%d\n%d\n", X(), XI(), XII());
return 0;
}