Luogu

LOJ

分析

设 $f_i$ 为从 $i$ 点生命值变为 $0$ 点生命值的期望步数,$p_i$ 为生命值 $\geq i$ 时一回合扣掉 $i$ 点生命值的概率,那么有

$$ f_i=\begin{cases}\frac{1}{m+1}\sum_{j=0}^{i+1}p_jf_{i-j+1}+\frac{m}{m+1}\sum_{j=0}^ip_jf_{i-j}+1,&i<n\\\sum_{j=0}^np_jf_{i-j}+1,&i=n\end{cases} $$

$p_i$ 是基础文化课知识

$$ p_i={k\choose i}\left(\frac{1}{m+1}\right)^i\left(\frac{m}{m+1}\right)^{k-i} $$

直接高斯消元是 $\mathcal{O}(n^3)$ 的,只能获得 70 分。

但是这个矩阵很有性质。我们从上往下消,消到每一行时只会有 $3$ 个位置(包括常数项)有值,也就是只要消三列。最后一行时只会剩下一个数和常数项,可以解出 $f_n$,然后回代即可。这样子就是 $\mathcal{O}(n^2)$ 的了。

注意特判无解的情况和 $m=0$ 的情况。

代码

// ====================================
//   author: M_sea
//   website: https://m-sea-blog.com/
// ====================================
#include <bits/stdc++.h>
#define file(x) freopen(#x".in","r",stdin),freopen(#x".out","w",stdout)
#define debug(...) fprintf(stderr,__VA_ARGS__)
using namespace std;
typedef long long ll;

int read() {
    int X=0,w=1; char c=getchar();
    while (c<'0'||c>'9') { if (c=='-') w=-1; c=getchar(); }
    while (c>='0'&&c<='9') X=X*10+c-'0',c=getchar();
    return X*w;
}

const int N=1500+10;
const int mod=1e9+7;
int qpow(int a,int b) { int c=1;
    for (;b;b>>=1,a=1ll*a*a%mod) if (b&1) c=1ll*c*a%mod;
    return c;
}

int n,s,m,k,p[N],a[N][N];

void Gauss() {
    for (int i=1;i<n;++i) {
        int inv=qpow(a[i][i],mod-2);
        for (int j=i+1;j<=n;++j) {
            int t=1ll*a[j][i]*inv%mod;
            a[j][i]=0;
            a[j][i+1]=(a[j][i+1]+1ll*a[i][i+1]*(mod-t))%mod;
            a[j][n+1]=(a[j][n+1]+1ll*a[i][n+1]*(mod-t))%mod;
        }
    }
    for (int i=n;i;--i) {
        a[i][n+1]=1ll*a[i][n+1]*qpow(a[i][i],mod-2)%mod;
        a[i-1][n+1]=(a[i-1][n+1]+1ll*a[i-1][i]*(mod-a[i][n+1]))%mod;
    }
}

int main() {
    int T=read();
    while (T--) {
        n=read(),s=read(),m=read(),k=read();
        if (!k||(!m&&k<=1)) { puts("-1"); continue; }
        if (!m) {
            int ans=0;
            while (s>0) s=min(s+1,n),s-=k,++ans;
            printf("%d\n",ans);
            continue;
        }
        memset(p,0,sizeof(p)),memset(a,0,sizeof(a));
        int x=qpow(m+1,mod-2),y=1ll*m*x%mod; p[0]=qpow(y,k);
        for (int i=1;i<=min(n,k);++i)
            p[i]=1ll*p[i-1]*(k-i+1)%mod*qpow(m,mod-2)%mod*qpow(i,mod-2)%mod;
        for (int i=1;i<n;++i) {
            a[i][i]=a[i][n+1]=mod-1;
            for (int j=0;j<=i+1;++j) a[i][i-j+1]=(a[i][i-j+1]+1ll*x*p[j])%mod;
            for (int j=0;j<=i;++j) a[i][i-j]=(a[i][i-j]+1ll*y*p[j])%mod;
        }
        a[n][n]=a[n][n+1]=mod-1;
        for (int j=0;j<=n;++j) a[n][n-j]=(a[n][n-j]+p[j])%mod;
        Gauss();
        printf("%d\n",a[s][n+1]);
    }
    return 0;
}
最后修改:2021 年 01 月 14 日 10 : 02 PM