Codeforces

分析

设 $E(x)$ 为游戏结束时所有饼干都在 $x$ 手中的期望次数,$E'(x)$ 为游戏结束当且仅当所有饼干都在 $x$ 手中时的期望次数,$P(x)$ 为游戏结束时所有饼干都在 $x$ 手中的概率,$C$ 为所有饼干都在某个人手中时全部到另一个人手中的期望次数,则有

$$ E(x)=E'(x)-\sum_{1\leq i\leq n,i\neq x}\left(P(i)\times C+E(i)\right) $$

$$ \sum_{i=1}^nE(i)=E'(x)-C\sum_{1\leq i\leq n,i\neq x}P(i) $$

对所有 $n$ 个这样的式子求和,有

$$ n\times ans=\sum_{i=1}^n E'(i)-C(n-1) $$

那么我们只要求出 $E'(i)$ 和 $C$。

设 $f_i$ 为一个人手上有 $i$ 块饼干时拿到所有饼干的期望次数,则 $E'(i)=f_{a_i},C=f_0$。

设 $s=\sum_{i=1}^n a_i$,则:

$$ f_i=\begin{cases}\frac{s-i}{s}\left(\frac{1}{n-1}f_{i+1}+\frac{n-2}{n-1}f_i\right)+\frac{i}{s}f_{i-1}+1,&0<i<s\\\frac{n-2}{n-1}f_i+\frac{1}{n-1}f_{i+1}+1,&i=0\\0,&i=s\end{cases} $$

因为有

$$ f_0=\frac{n-2}{n-1}f_0+\frac{1}{n-1}f_1+1 $$

可以解出

$$ f_0=f_1+n-1 $$

接下来求剩下的项。直接消元可能会有除以 $0$ 的情况出现,我们考虑一些其它的方法。

设 $g_i=f_i-f_{i+1}$,则 $f_i=\sum_{j=i}^sg_j$。

那么

$$ \sum_{j=i}^sg_j=\frac{s-i}{s}\left(\frac{1}{n-1}\sum_{j=i+1}^sg_j+\frac{n-2}{n-1}\sum_{j=i}^sg_j\right)+\frac{i}{s}\sum_{j=i-1}^sf_j+1 $$

注意到对于 $j>i$ 的项,左右两边系数都为 $1$,所以可以消去,得到

$$ g_i=\frac{(s-i)(n-2)}{s(n-1)}g_i+\frac{i}{s}(g_{i-1}+g_i)+1 $$

$$ g_i=\frac{s(n-1)+i(n-1)g_{i-1}}{s-i} $$

又因为 $g_0=f_0-f_1=n-1$,所以我们可以递推求出所有 $g_i$,从而求出所有 $f_i$,然后算答案即可。

代码

// ====================================
//   author: M_sea
//   website: https://m-sea-blog.com/
// ====================================
#include <bits/stdc++.h>
#define file(x) freopen(#x".in","r",stdin); freopen(#x".out","w",stdout)
#define debug(...) fprintf(stderr,__VA_ARGS__)
using namespace std;
typedef long long ll;

int read() {
    int X=0,w=1; char c=getchar();
    while (c<'0'||c>'9') { if (c=='-') w=-1; c=getchar(); }
    while (c>='0'&&c<='9') X=X*10+c-'0',c=getchar();
    return X*w;
}

const int N=300000+10;
const int mod=998244353;
int qpow(int a,int b) { int c=1;
    for (;b;b>>=1,a=1ll*a*a%mod) if (b&1) c=1ll*c*a%mod;
    return c;
}

int n,s,a[N];
int f[N],g[N];

int main() {
    n=read();
    for (int i=1;i<=n;++i) a[i]=read(),s+=a[i];
    g[0]=n-1;
    for (int i=1;i<s;++i)
        g[i]=(1ll*s*(n-1)+1ll*i*(n-1)%mod*g[i-1])%mod*qpow(s-i,mod-2)%mod;
    for (int i=s;~i;--i) f[i]=(f[i+1]+g[i])%mod;
    int ans=0;
    for (int i=1;i<=n;++i) ans=(ans+f[a[i]])%mod;
    ans=(ans-1ll*f[0]*(n-1)%mod+mod)%mod;
    ans=1ll*ans*qpow(n,mod-2)%mod;
    printf("%d\n",ans);
    return 0;
}
最后修改:2020 年 09 月 28 日 09 : 18 AM