Luogu

LOJ

分析

设 $P=\sum_{i=1}^n p_i$。

设 $F(x)$ 为按到目标状态的指数型概率生成函数,$G(x)$ 为按回自己的指数型概率生成函数。

$F(x)$ 直接把每个开关乘在一起即可,$G(x)$ 相当于是 $s_i=0$ 时的 $F(x)$。那么有

$$ \begin{aligned} F(x)&=\prod_{i=1}^n\frac{e^{\frac{p_i}{P}x}+(-1)^{s_i}e^{-\frac{p_i}{P}x}}{2}\\ G(x)&=\prod_{i=1}^n\frac{e^{\frac{p_i}{P}x}+e^{-\frac{p_i}{P}x}}{2} \end{aligned} $$

我们设 $F(x)$ 对应的普通型生成函数为 $f(x)$,$G(x)$ 对应的普通型生成函数为 $g(x)$,$h(x)$ 为第一次按到目标状态的普通型概率生成函数。那么有

$$ h(x)g(x)=f(x) $$

也就是说

$$ h(x)=\frac{f(x)}{g(x)} $$

根据一些概率生成函数的基础知识,我们需要求的就是 $h'(1)$。

上面是大体思路,然而还有很多小细节,我们来一一解决。

首先是怎么把 $F(x)$ 和 $G(x)$ 对应的 OGF 求出来。可以发现,$F(x)$ 是由若干个 $a_ie^{\frac{i}{P}x}$ 组成的,这里 $i\in[-P,P]$。

于是我们想办法把 $F(x)$ 表示成 $\sum_{i=-P}^Pa_ie^{\frac{i}{P}x}$ 的形式,这样就可以知道 $f(x)=\sum_{i=-P}^P\frac{a_i}{1-\frac{i}{P}x}$。

这个 $a_i$ 很好求出,只需要做一遍背包即可。

然后考虑怎么求 $h'(1)$。我们直接套公式

$$ h'(1)=\frac{f'(1)g(1)-f(1)g'(1)}{g(1)^2} $$

我们只需要求 $f(1),f'(1),g(1),g'(1)$ 即可。然而有一个很大的问题是因为存在 $\frac{a_P}{1-x}$ 这一项,所以这四个东西在 $x=1$ 处是不收敛的……

于是可以考虑乘上一个 $1-x$,那么可以得到

$$ \begin{aligned} f(1)&=a_P\\ f'(1)&=\sum_{i=-P}^{P-1}\frac{a_i}{\frac{i}{P}-1} \end{aligned} $$

$g(1)$ 和 $g'(1)$ 类似。

那么

$$ \begin{aligned} h'(1)&=\frac{\left(\sum_{i=-P}^{P-1}\frac{a_i}{\frac{i}{P}-1}\right)b_P-\left(\sum_{i=-P}^{P-1}\frac{b_i}{\frac{i}{P}-1}\right)a_P}{b_P\!^2}\\ &=\sum_{i=-P}^{P-1}\frac{a_ib_P-b_ia_P}{\left(\frac{i}{p}-1\right)b_P\!^2} \end{aligned} $$

直接计算即可。

代码

// ====================================
//   author: M_sea
//   website: https://m-sea-blog.com/
// ====================================
#include <bits/stdc++.h>
#define file(x) freopen(#x".in","r",stdin); freopen(#x".out","w",stdout)
#define debug(...) fprintf(stderr,__VA_ARGS__)
using namespace std;
typedef long long ll;

int read() {
    int X=0,w=1; char c=getchar();
    while (c<'0'||c>'9') { if (c=='-') w=-1; c=getchar(); }
    while (c>='0'&&c<='9') X=X*10+c-'0',c=getchar();
    return X*w;
}

const int N=100000+10,L=50000;
const int mod=998244353,inv2=499122177;
int qpow(int a,int b) { int c=1;
    for (;b;b>>=1,a=1ll*a*a%mod) if (b&1) c=1ll*c*a%mod;
    return c;
}

int n,s[N],p[N],sp=0;
int a[2][N],b[2][N];

int main() {
    n=read();
    for (int i=1;i<=n;++i) s[i]=read()?mod-1:1;
    a[0][L]=b[0][L]=1;
    for (int i=1;i<=n;++i) {
        int cur=i&1,pre=cur^1;
        memset(a[cur],0,sizeof(a[cur])),memset(b[cur],0,sizeof(b[cur]));
        p[i]=read();
        for (int j=-sp;j<=sp;++j) {
            a[cur][j+p[i]+L]=(a[cur][j+p[i]+L]+1ll*inv2*a[pre][j+L])%mod;
            a[cur][j-p[i]+L]=(a[cur][j-p[i]+L]+1ll*s[i]*inv2%mod*a[pre][j+L])%mod;
            b[cur][j+p[i]+L]=(b[cur][j+p[i]+L]+1ll*inv2*b[pre][j+L])%mod;
            b[cur][j-p[i]+L]=(b[cur][j-p[i]+L]+1ll*inv2*b[pre][j+L])%mod;
        }
        sp+=p[i];
    }
    int inv=qpow(sp,mod-2),ans=0;
    for (int i=-sp;i<sp;++i) {
        int x=(1ll*a[n&1][i+L]*b[n&1][sp+L]%mod-1ll*b[n&1][i+L]*a[n&1][sp+L]%mod+mod)%mod;
        int y=(1ll*(i+mod)*inv-1+mod)%mod;
        ans=(ans+1ll*x*qpow(y,mod-2))%mod;
    }
    ans=1ll*ans*qpow(b[n&1][sp+L],mod-3)%mod;
    printf("%d\n",ans);
    return 0;
}
最后修改:2020 年 10 月 26 日 07 : 47 PM