分析
贪心+九条可怜。
设数 $i$ 的个数为 $cnt[i]$ ,于是期望轮数为 $\frac{n!}{\prod\limits_{i=1}^kcnt[i]!}$。
显然 $cnt[i]$ 尽量平均会使得这个东西最大。
考虑将所有 $[l,r]$ 内的 $cnt$ 排序,然后每次将前面的抬高到当前的高度。显然这样子可以达到尽量的平均。
然后 $cnt$ 中可能会有很多$0$,把这些$0$单独拿出来放在前面。
代码
//It is made by M_sea
#include <algorithm>
#include <iostream>
#include <cstdlib>
#include <cstring>
#include <cstdio>
#include <cmath>
#define re register
using namespace std;
inline int read() {
int X=0,w=1; char c=getchar();
while (c<'0'||c>'9') { if (c=='-') w=-1; c=getchar(); }
while (c>='0'&&c<='9') X=X*10+c-'0',c=getchar();
return X*w;
}
const int N=200000+10;
const int M=10200000+10;
const int mod=998244353;
int fact[M],inv[M];
int a[N],b[N],cnt[N],sum[N];
inline int fastpow(int a,int b,int ans=1) {
for (;b;b>>=1,a=1ll*a*a%mod)
if (b&1) ans=1ll*ans*a%mod;
return ans;
}
inline void init(int n) {
fact[0]=1;
for (re int i=1;i<=n;++i) fact[i]=1ll*fact[i-1]*i%mod;
inv[n]=fastpow(fact[n],mod-2);
for (re int i=n-1;i>=0;--i) inv[i]=1ll*inv[i+1]*(i+1)%mod;
}
int main() {
init(10200000);
int T=read();
while (T--) {
//输入+排序+去重
int n=read(),m=read(),l=read(),r=read(),ans=1,all=fact[n+m];
for (re int i=1;i<=n;++i) a[i]=b[i]=read();
sort(b+1,b+n+1); int top=unique(b+1,b+n+1)-b-1;
//初始化
memset(cnt,0,sizeof(int)*(n+1));
memset(sum,0,sizeof(int)*(n+1));
//求s
int s=r-l+1;
for (re int i=1;i<=top;++i)
if (b[i]>=l&&b[i]<=r) --s;
//求cnt
for (re int i=1;i<=n;++i) {
a[i]=lower_bound(b+1,b+top+1,a[i])-b;
if (b[a[i]]>=l&&b[a[i]]<=r) ++cnt[a[i]];
++sum[a[i]];
}
//初始化ans
for (re int i=1;i<=top;++i) ans=1ll*ans*fact[sum[i]]%mod;
//移位
sort(cnt+1,cnt+top+1); int pos=0;
for (re int i=1;i<=top;++i) if (!cnt[i]) pos=i;
for (re int i=1;i<=top-pos;++i) cnt[i]=cnt[i+pos];
n=top-pos;
//求解
for (re int i=0;i<n;++i) {
if (cnt[i+1]==cnt[i]) continue;
if (1ll*(cnt[i+1]-cnt[i])*(i+s)<=1ll*m) {
m-=(cnt[i+1]-cnt[i])*(i+s);
int tmp=1ll*fact[cnt[i+1]]*inv[cnt[i]]%mod;
ans=1ll*ans*fastpow(tmp,i+s)%mod;
} else {
int x=m/(i+s),tmp=1ll*fact[cnt[i]+x]*inv[cnt[i]]%mod;
ans=1ll*ans*fastpow(tmp,i+s)%mod;
ans=1ll*ans*fastpow(cnt[i]+x+1,m%(i+s))%mod;
m=0; break;
}
}
if (m>0) {
int x=m/(r-l+1),tmp=1ll*fact[cnt[n]+x]*inv[cnt[n]]%mod;
ans=1ll*ans*fastpow(tmp,r-l+1)%mod;
ans=1ll*ans*fastpow(cnt[n]+x+1,m%(r-l+1))%mod;
}
//输出
printf("%d\n",1ll*all*fastpow(ans,mod-2)%mod);
}
return 0;
}