LOJ

分析

考虑求出交集大小至少为 $k$ 的方案数 $f(k)$,显然有

$$ f(k) = {n \choose k} (2 ^ {2 ^ {n - k}} - 1) $$

考虑构造容斥系数 $c(k)$,使得

$$ ans = \sum_{k = 0}^n f(k)c(k) $$

对于交集大小恰好为 $k$ 的一组方案,它的贡献应为 $[4 | k]$,而它的贡献实际为 $\sum_{i = 0}^k {k \choose i}f(i)$。所以应有

$$ [4 | k] = \sum_{i = 0}^k {k \choose i} f(i) $$

二项式反演得到

$$ f(k) = \sum_{i = 0}^k (-1) ^ {k - i} {k \choose i} [4 | i] $$

再单位根反演得到

$$ f(k) = \frac{1}{4} \sum_{i = 0}^k (-1) ^ {k - i} {k \choose i} \sum_{j = 0}^3 \omega^{ij} $$

交换求和号

$$ f(k) = \frac{1}{4} \sum_{j = 0}^3 \sum_{i = 0}^k {k \choose i} (-1) ^ {k - i} \omega ^ {ij} $$

根据二项式定理得到

$$ f(k) = \frac{1}{4} \sum_{j = 0}^3 (\omega^j - 1) ^ k $$

从小到大枚举 $k$,动态维护 $(\omega ^ j - 1) ^ k$ 即可 $\mathcal{O}(n)$ 计算。

代码

// ====================================
//   author: M_sea
//   website: https://m-sea-blog.com/
// ====================================
#include <bits/stdc++.h>
#define file(x) freopen(#x".in", "r", stdin), freopen(#x".out", "w", stdout)
#define debug(...) fprintf(stderr, __VA_ARGS__)
using namespace std;
typedef long long ll;

int read() {
    int X = 0, w = 1; char c = getchar();
    while (c < '0' || c > '9') { if (c == '-') w = -1; c = getchar(); }
    while (c >= '0' && c <= '9') X = X * 10 + c - '0', c = getchar();
    return X * w;
}

const int N = 10000000 + 10;
const int mod = 998244353, inv4 = 748683265;
const int w[4] = {1, 911660635, 998244352, 86583718};
int upd(int x) { return x + ((x >> 31) & mod); }
int qpow(int a, int b) {
    int c = 1;
    for (; b; b >>= 1, a = 1ll * a * a % mod)
        if (b & 1) c = 1ll * c * a % mod;
    return c;
}

int n;
int fac[N], ifac[N], pw[N];
int mul[4];

int C(int n, int m) {
    if (n < m) return 0;
    return 1ll * fac[n] * ifac[m] % mod * ifac[n - m] % mod;
}

int calc(int i) {
    int c = 0;
    c = upd(c + mul[0] - mod), mul[0] = 1ll * mul[0] * (w[0] - 1) % mod; // 其实这个贡献等于 [i == 0]
    c = upd(c + mul[1] - mod), mul[1] = 1ll * mul[1] * (w[1] - 1) % mod;
    c = upd(c + mul[2] - mod), mul[2] = 1ll * mul[2] * (w[2] - 1) % mod;
    c = upd(c + mul[3] - mod), mul[3] = 1ll * mul[3] * (w[3] - 1) % mod;
    return 1ll * C(n, i) * (pw[n - i] - 1 + mod) % mod * c % mod;
}

int main() {
    n = read();
    fac[0] = 1;
    for (int i = 1; i <= n; ++i) fac[i] = 1ll * fac[i - 1] * i % mod;
    ifac[n] = qpow(fac[n], mod - 2);
    for (int i = n; i; --i) ifac[i - 1] = 1ll * ifac[i] * i % mod;
    pw[0] = 2;
    for (int i = 1; i <= n; ++i) pw[i] = 1ll * pw[i - 1] * pw[i - 1] % mod;
    mul[0] = mul[1] = mul[2] = mul[3] = 1; int ans = 0;
    for (int i = 0; i <= n; ++i) ans = upd(ans + calc(i) - mod);
    ans = (1ll * ans * inv4 + 1) % mod;
    printf("%d\n", ans);
    return 0;
}
最后修改:2021 年 02 月 22 日 09 : 36 AM