分析
设 $P=\sum_{i=1}^n p_i$。
设 $F(x)$ 为按到目标状态的指数型概率生成函数,$G(x)$ 为按回自己的指数型概率生成函数。
$F(x)$ 直接把每个开关乘在一起即可,$G(x)$ 相当于是 $s_i=0$ 时的 $F(x)$。那么有
$$
\begin{aligned}
F(x)&=\prod_{i=1}^n\frac{e^{\frac{p_i}{P}x}+(-1)^{s_i}e^{-\frac{p_i}{P}x}}{2}\\
G(x)&=\prod_{i=1}^n\frac{e^{\frac{p_i}{P}x}+e^{-\frac{p_i}{P}x}}{2}
\end{aligned}
$$
我们设 $F(x)$ 对应的普通型生成函数为 $f(x)$,$G(x)$ 对应的普通型生成函数为 $g(x)$,$h(x)$ 为第一次按到目标状态的普通型概率生成函数。那么有
$$
h(x)g(x)=f(x)
$$
也就是说
$$
h(x)=\frac{f(x)}{g(x)}
$$
根据一些概率生成函数的基础知识,我们需要求的就是 $h'(1)$。
上面是大体思路,然而还有很多小细节,我们来一一解决。
首先是怎么把 $F(x)$ 和 $G(x)$ 对应的 OGF 求出来。可以发现,$F(x)$ 是由若干个 $a_ie^{\frac{i}{P}x}$ 组成的,这里 $i\in[-P,P]$。
于是我们想办法把 $F(x)$ 表示成 $\sum_{i=-P}^Pa_ie^{\frac{i}{P}x}$ 的形式,这样就可以知道 $f(x)=\sum_{i=-P}^P\frac{a_i}{1-\frac{i}{P}x}$。
这个 $a_i$ 很好求出,只需要做一遍背包即可。
然后考虑怎么求 $h'(1)$。我们直接套公式
$$
h'(1)=\frac{f'(1)g(1)-f(1)g'(1)}{g(1)^2}
$$
我们只需要求 $f(1),f'(1),g(1),g'(1)$ 即可。然而有一个很大的问题是因为存在 $\frac{a_P}{1-x}$ 这一项,所以这四个东西在 $x=1$ 处是不收敛的……
于是可以考虑乘上一个 $1-x$,那么可以得到
$$
\begin{aligned}
f(1)&=a_P\\
f'(1)&=\sum_{i=-P}^{P-1}\frac{a_i}{\frac{i}{P}-1}
\end{aligned}
$$
$g(1)$ 和 $g'(1)$ 类似。
那么
$$
\begin{aligned}
h'(1)&=\frac{\left(\sum_{i=-P}^{P-1}\frac{a_i}{\frac{i}{P}-1}\right)b_P-\left(\sum_{i=-P}^{P-1}\frac{b_i}{\frac{i}{P}-1}\right)a_P}{b_P\!^2}\\
&=\sum_{i=-P}^{P-1}\frac{a_ib_P-b_ia_P}{\left(\frac{i}{p}-1\right)b_P\!^2}
\end{aligned}
$$
直接计算即可。
代码
// ====================================
// author: M_sea
// website: https://m-sea-blog.com/
// ====================================
#include <bits/stdc++.h>
#define file(x) freopen(#x".in","r",stdin); freopen(#x".out","w",stdout)
#define debug(...) fprintf(stderr,__VA_ARGS__)
using namespace std;
typedef long long ll;
int read() {
int X=0,w=1; char c=getchar();
while (c<'0'||c>'9') { if (c=='-') w=-1; c=getchar(); }
while (c>='0'&&c<='9') X=X*10+c-'0',c=getchar();
return X*w;
}
const int N=100000+10,L=50000;
const int mod=998244353,inv2=499122177;
int qpow(int a,int b) { int c=1;
for (;b;b>>=1,a=1ll*a*a%mod) if (b&1) c=1ll*c*a%mod;
return c;
}
int n,s[N],p[N],sp=0;
int a[2][N],b[2][N];
int main() {
n=read();
for (int i=1;i<=n;++i) s[i]=read()?mod-1:1;
a[0][L]=b[0][L]=1;
for (int i=1;i<=n;++i) {
int cur=i&1,pre=cur^1;
memset(a[cur],0,sizeof(a[cur])),memset(b[cur],0,sizeof(b[cur]));
p[i]=read();
for (int j=-sp;j<=sp;++j) {
a[cur][j+p[i]+L]=(a[cur][j+p[i]+L]+1ll*inv2*a[pre][j+L])%mod;
a[cur][j-p[i]+L]=(a[cur][j-p[i]+L]+1ll*s[i]*inv2%mod*a[pre][j+L])%mod;
b[cur][j+p[i]+L]=(b[cur][j+p[i]+L]+1ll*inv2*b[pre][j+L])%mod;
b[cur][j-p[i]+L]=(b[cur][j-p[i]+L]+1ll*inv2*b[pre][j+L])%mod;
}
sp+=p[i];
}
int inv=qpow(sp,mod-2),ans=0;
for (int i=-sp;i<sp;++i) {
int x=(1ll*a[n&1][i+L]*b[n&1][sp+L]%mod-1ll*b[n&1][i+L]*a[n&1][sp+L]%mod+mod)%mod;
int y=(1ll*(i+mod)*inv-1+mod)%mod;
ans=(ans+1ll*x*qpow(y,mod-2))%mod;
}
ans=1ll*ans*qpow(b[n&1][sp+L],mod-3)%mod;
printf("%d\n",ans);
return 0;
}