Luogu

LOJ


orz laofu!!1

分析

$op = 0$

只需要统计两棵树有多少条重边即可。假设有 $c$ 条,答案即为 $y^{n - c}$。

$op = 1$

先特判掉 $y = 1$ 的情况,答案为 $n^{n - 2}$。

我们要求的是
$$
\sum_{T_2} y^{n - |T_1 \cap T_2|}
$$
交集不好处理,我们考虑这样一个式子
$$
f(S) = \sum_{T \subseteq S} \sum_{P \subseteq T} (-1)^{|T| - |P|} f(P) \tag{1} \label{eq1}
$$

证明(可能是伪的)

规定 $0 ^ 0 = 1$,可以证明在下面的推导中这是成立的。
$$
\begin{align}
\text{右边} = & \sum_{P \subseteq S} f(P) \sum_{j = 0}^{|S| - |P|} (-1)^j {|S| - |P| \choose j} \\
= & \sum_{P \subseteq S} f(P) (1 - 1)^{|S| - |P|} \\
= & f(S) = \text{左边}
\end{align}
$$

把 $\eqref{eq1}$ 套进去得到
$$
\sum_{T_2} \sum_{S \subseteq T_1 \cap T_2} \sum_{T \subseteq S} (-1) ^ {|S| - |T|} y^{n - |T|}
$$
设 $g(S)$ 为包含 $S$ 中边的生成树个数,上式可以变为
$$
\begin{align}
& \sum_{S \subseteq T_1} \left(\sum_{T \subseteq S} (-1)^{|S| - |T|} y^{n - |T|}\right) g(S) \\
= & \sum_{S \subseteq T_1} y^{n - |S|} \left(\sum_{T \subseteq S} (-y)^{n - |T|}\right) g(S) \\
= & \sum_{S \subseteq T_1} y^{n - |S|} \left(\sum_{i = 0}^{|S|} {|S| \choose i} (-y)^{n - i}\right) g(S) \\
= & \sum_{S \subseteq T_1} y^{n - |S|} (1 - y)^{|S|} g(S) \\
\end{align}
$$
考虑 $g(S)$ 等于什么。我们把 $S$ 中的边加入,那么原图会构成 $k$ 个大小为 $a_1, a_2, \cdots, a_k$ 的连通块,将它们连成一棵树的方案数为
$$
g(S) = n^{k - 2} \prod_{i = 1}^k a_i
$$

证明

于是上式变为
$$
\begin{align}
& \sum_{S \subseteq T_1} y^k (1 - y)^{n - k} n^{k - 2} \prod_{i = 1}^k a_i \\
= & \frac{(1 - y)^n}{n^2} \sum_{S \subseteq T_1} \left(\frac{ny}{1 - y}\right)^k \prod_{i = 1}^k a_i \\
= & \frac{(1 - y)^n}{n^2} \sum_{S \subseteq T_1} \prod_{i = 1}^k \frac{ny}{1 - y} a_i\\
\end{align}
$$
不妨设 $k = \frac{ny}{1 - y}$,上式的实际意义相当于每个大小为 $a$ 的连通块产生 $ka$ 的乘积贡献。

这样子就可以想到一个 DP:设 $f_{i, j}$ 表示以 $i$ 为根的子树、$i$ 所在的连通块大小为 $j$ 的答案,转移为树形背包。然而这样子是 $\mathcal{O}(n^2)$ 的。

考虑把 $a_i$ 写成 $a_i \choose 1$,相当于在每个连通块内选一个点产生 $ka_i$ 的乘积贡献。这样子就有一个更优的 DP:设 $f_{i, 0/1}$ 表示以 $i$ 为根的子树、$i$ 所在连通块是否已经选了一个点产生贡献,转移讨论一下即可。这样子即可做到 $\mathcal{O}(n)$。

$op = 2$

先判掉 $y = 1$ 的情况,答案为 $n^{2(n - 2)}$。

我们要求的是
$$
\sum_{T_1} \sum_{T_2} y^{n - |T_1 \cap T_2|}
$$
同样把 $\eqref{eq1}$ 套进去得到
$$
\sum_{T_1} \sum_{T_2} \sum_{S\subseteq T_1 \cap T_2} \sum_{T\subseteq S} (-1)^{|S| - |T|} y^{n - |T|}
$$
设 $g(S)$ 表示包含 $S$ 中边的生成树个数,上式可以变为(过程和上面类似就不写了)
$$
\sum_{S} y^{n - |S|} (1 - y)^{|S|} g(S)^2
$$

$$
g(S)^2 = n^{2k - 4} \prod_{i = 1}^k a_i^2
$$
代回去得到
$$
\begin{align}
& \sum_{S} y^k (1 - y)^{n - k} n^{2k - 4} \prod_{i = 1}^k a_i^2 \\
= & \frac{(1 - y)^n}{n^4}\sum_S\prod_{i = 1}^k \frac{n^2 y}{1 - y} a_i^2
\end{align}
$$
考察每个连通分量。一个大小为 $a$ 的连通分量会产生 $\frac{n^2 y}{1 - y} a^2$ 的乘积贡献,而 $a$ 个点的生成树有 $a^{a-2}$ 棵,所以其总贡献为 $\frac{n^2 y}{1 - y} a^a$。

而整张图是由若干个连通分量组合而成的,这让我们想到 EGF 中的 $\exp$。

于是构造 EGF $F(x) = \sum_{i \geq 1} \frac{n^2 y}{(1-y) i!} i^i x^i$,然后求 $G(x) = \exp(F(x))$,答案即为 $\frac{(1 - y)^n n!}{n^4}[x^n] G(x)$。


总结:这是一道好题,可是我什么都不会。

代码

// ====================================
//   author: M_sea
//   website: https://m-sea-blog.com/
// ====================================
#include <bits/stdc++.h>
#define file(x) freopen(#x".in", "r", stdin), freopen(#x".out", "w", stdout)
#define debug(...) fprintf(stderr, __VA_ARGS__)
typedef long long ll;

int read() {
    int X = 0, w = 1; char c = getchar();
    while (c < '0' || c > '9') { if (c == '-') w = -1; c = getchar(); }
    while (c >= '0' && c <= '9') X = X * 10 + c - '0', c = getchar();
    return X * w;
}

const int mod = 998244353;
int qpow(int a, int b) {
    int c = 1;
    for (; b; b >>= 1, a = 1ll * a * a % mod)
        if (b & 1) c = 1ll * c * a % mod;
    return c;
}

int n, y;

namespace Subtask0 {
    std::map<std::pair<int,int>, bool> M;

    void main() {
        for (int i = 1; i < n; ++i) {
            int u = read(), v = read();
            M[std::make_pair(u, v)] = M[std::make_pair(v, u)] = 1;
        }
        int c = n;
        for (int i = 1; i < n; ++i) {
            int u = read(), v = read();
            if (M.count(std::make_pair(u, v))) --c;
        }
        printf("%d\n", qpow(y, c));
    }
}

namespace Subtask1 {
    const int N = 100000 + 10;

    std::vector<int> E[N];
    int k, f[N][2];

    void dfs(int u, int fa) {
        f[u][0] = 1, f[u][1] = k;
        for (int v : E[u]) {
            if (v == fa) continue;
            dfs(v, u);
            int f0 = 1ll * f[u][0] * (f[v][0] + f[v][1]) % mod;
            int f1 = (1ll * f[u][0] * f[v][1] + 1ll * f[u][1] * f[v][0] + 1ll * f[u][1] * f[v][1]) % mod;
            f[u][0] = f0, f[u][1] = f1;
        }
    }

    void main() {
        if (y == 1) { printf("%d\n", qpow(n, n - 2)); return; }
        for (int i = 1; i < n; ++i) {
            int u = read(), v = read();
            E[u].emplace_back(v), E[v].emplace_back(u);
        }
        k = 1ll * n * y % mod * qpow(1 - y + mod, mod - 2) % mod;
        dfs(1, 0);
        int ans = 1ll * f[1][1] * qpow(1 - y + mod, n) % mod * qpow(n, mod - 3) % mod;
        printf("%d\n", ans);
    }
}

namespace Subtask2 {
    const int N = 262144 + 10;

    int r[N];
    void NTT(int *A, int n, int op) {
        for (int i = 0; i < n; ++i) if (i < r[i]) std::swap(A[i], A[r[i]]);
        for (int i = 1; i < n; i <<= 1) {
            int rot = qpow(op == 1 ? 3 : 332748118, (mod - 1) / (i << 1));
            for (int j = 0; j < n; j += i << 1)
                for (int k = 0, w = 1; k < i; ++k, w = 1ll * w * rot % mod) {
                    int x = A[j + k], y = 1ll * w * A[j + k + i] % mod;
                    A[j + k] = (x + y) % mod, A[j + k + i] = (x - y + mod) % mod;
                }
        }
        if (op == -1) {
            int inv = qpow(n, mod - 2);
            for (int i = 0; i < n; ++i) A[i] = 1ll * A[i] * inv % mod;
        }
    }
    int NTT_init(int n) {
        int lim = 1, l = 0;
        for (; lim < n; lim <<= 1, ++l);
        for (int i = 0; i < lim; ++i) r[i] = (r[i >> 1] >> 1) | ((i & 1) << (l - 1));
        return lim;
    }
    void PolyInv(int *F, int *G, int n) {
        static int A[N], B[N];
        if (n == 1) { G[0] = qpow(F[0], mod - 2); return; }
        PolyInv(F, G, n >> 1);
        for (int i = 0; i < n; ++i) A[i] = F[i], B[i] = G[i];
        int lim = NTT_init(n << 1);
        NTT(A, lim, 1), NTT(B, lim, 1);
        for (int i = 0; i < lim; ++i) A[i] = 1ll * A[i] * B[i] % mod * B[i] % mod;
        NTT(A, lim, -1);
        for (int i = 0; i < n; ++i) G[i] = (2ll * G[i] - A[i] + mod) % mod;
        for (int i = 0; i < lim; ++i) A[i] = B[i] = 0;
    }
    void PolyDeri(int *F, int *G, int n) {
        for (int i = 1; i < n; ++i) G[i - 1] = 1ll * F[i] * i %mod;
        G[n - 1] = 0;
    }
    void PolyInte(int *F, int *G, int n) {
        for (int i = 1; i < n; ++i) G[i] = 1ll * F[i-1] * qpow(i, mod - 2) % mod;
        G[0] = 0;
    }
    void PolyLn(int *F, int *G, int n) {
        static int A[N], B[N];
        PolyDeri(F, A, n), PolyInv(F, B, n);
        int lim = NTT_init(n << 1);
        NTT(A, lim, 1), NTT(B, lim, 1);
        for (int i = 0; i < lim; ++i) A[i] = 1ll * A[i] * B[i] % mod;
        NTT(A, lim, -1);
        PolyInte(A, G, n);
        for (int i = 0; i < lim; ++i) A[i] = B[i] = 0;
    }
    void PolyExp(int *F, int *G, int n) {
        static int A[N], B[N];
        if (n == 1) { G[0] = 1; return; }
        PolyExp(F, G, n >> 1);
        for (int i = 0; i < n; ++i) A[i] = G[i];
        PolyLn(G, B, n);
        for (int i = 0; i < n; ++i) B[i] = (mod - B[i] + F[i]) % mod;
        B[0] = (B[0] + 1) % mod;
        int lim = NTT_init(n << 1);
        NTT(A, lim, 1), NTT(B, lim, 1);
        for (int i = 0; i < lim; ++i) A[i] = 1ll * A[i] * B[i] % mod;
        NTT(A, lim, -1);
        for (int i = 0; i < n; ++i) G[i] = A[i];
        for (int i = 0; i < lim; ++i) A[i] = B[i] = 0;
    }

    int fac[N], ifac[N];
    int F[N], G[N];

    void main() {
        if (y == 1) { printf("%d\n", qpow(n, 2 * (n - 2))); return; }
        fac[0] = 1;
        for (int i = 1; i <= n; ++i) fac[i] = 1ll * fac[i - 1] * i % mod;
        ifac[n] = qpow(fac[n], mod - 2);
        for (int i = n; i; --i) ifac[i - 1] = 1ll * ifac[i] * i % mod;
        int k = 1ll * n * n % mod * y % mod * qpow(1 - y + mod, mod - 2) % mod;
        for (int i = 1; i <= n; ++i) F[i] = 1ll * k * qpow(i, i) % mod * ifac[i] % mod;
        int lim = 1; for (; lim <= n; lim <<= 1);
        PolyExp(F, G, lim);
        int ans = 1ll * G[n] * fac[n] % mod * qpow(1 - y + mod, n) % mod * qpow(n, mod - 5) % mod;
        printf("%d\n", ans);
    }
}

int main() {
    n = read(), y = read(); int op = read();
    switch (op) {
        case 0: Subtask0::main(); break;
        case 1: Subtask1::main(); break;
        case 2: Subtask2::main(); break;
    }
    return 0;
}
最后修改:2021 年 03 月 13 日 08 : 06 PM