Luogu

LOJ

分析

先考虑什么样的排列是满足条件的。冷静分析一下,达到下界当且仅当排列中没有长度 $\geq 3$ 的下降子序列,因为如果有的话一定有一个元素需要移过目标位置再移回来,从而不能达到下界。

先考虑忽略字典序的限制怎么做。设 $dp_{i,j}$ 表示长度为 $i$、首项为 $j$ 的合法排列数,转移有三种情况

  • 第二个元素 $k$ 比 $j$ 大:那么如果 $k$ 不能与后面的元素构成长度为 $3$ 的下降子序列,则 $j$ 也不能。所以我们可以把第二个元素往后所有比 $j$ 大的数全部减去 $1$,变成一个长 $i-1$、首项为 $k-1$ 的排列,所以这部分的方案数即为 $\sum_{k=j+1}^idp_{i-1,k-1}$。
  • 第二个元素 $k$ 比 $j$ 小且不等于 $1$:那么 $j,k,1$ 一定会构成长度为 $3$ 的下降子序列,所以方案数为 $0$。
  • 第二个元素 $k$ 等于 $1$:那么 $[1,j-1]$ 的位置必须是单增的。考虑把 $1$ 变成 $2$、$2$ 变成 $3$、……、$j-1$ 变成 $1$,那么变化前不满足条件的两个数在变化后一定会和 $j-1$ 构成长度为 $3$ 的下降子序列,所以变化后合法排列数即为方案数,即方案数为 $dp_{i-1,j-1}$。

综上可以得到状态转移方程
$$
dp_{i,j}=\sum_{k=j-1}^{i-1}dp_{i-1,k}
$$
我们先不管这个 DP 是 $\mathcal{O}(n^2)$ 的问题,来考虑一下字典序怎么处理。

考虑枚举 LCP 的长度 $i-1$,那么 $[i,n]$ 对应一个首项比原排列的第 $i$ 位大的排列。设原排列中 $[i,n]$ 中有 $j$ 个 $\leq[1,i]$ 中最大值的数(可以通过树状数组预处理得到),那么 $i$ 填的数必须比前缀最大值大,从而方案数为
$$
\sum_{k=j+1}^{n-i+1}dp_{n-i+1,k}
$$
可以发现是一个后缀和的形式。不妨设 $S(i,j)=\sum_{k=j}^i dp_{i,k}$,观察之前那个状态转移方程可以得到
$$
S(i,j)=S(i-1,j-1)+S(i,j+1)
$$
考虑 $S(n,m)$ 的实际意义,可以看成从 $(0,0)$ 出发,每次可以向右上或者向下走一步,不能走到直线 $y=-1$ 上,走到 $(n,m)$ 的方案数。如果我们在往下走时也往右走一步,那么就变成走到 $(2n-m,m)$ 的方案数。我们首先在 $2n-m$ 步中选出 $n-m$ 步向下的,即 $2n-m\choose n-m$;然后还要减掉经过 $y=-1$ 的方案数,可以考虑把第一次碰到 $y=-1$ 前的直线对称一下,相当于从 $(0,-2)$ 走到 $(2n-m,m)$,所以方案数为 $2n-m\choose n-m-1$。综上有
$$
S(n,m)={2n-m\choose n-m}-{2n-m\choose n-m-1}
$$
枚举 LCP 时还需要考虑一定不合法的情况,即前缀最大值、前缀最大值到当前位置的次大值和后缀最小值构成长度为 $3$ 的下降子序列的情况,直接维护一下即可。

代码

// ====================================
//   author: M_sea
//   website: https://m-sea-blog.com/
// ====================================
#include <bits/stdc++.h>
#define file(x) freopen(#x".in","r",stdin); freopen(#x".out","w",stdout)
using namespace std;
typedef long long ll;

int read() {
    int X=0,w=1; char c=getchar();
    while (c<'0'||c>'9') { if (c=='-') w=-1; c=getchar(); }
    while (c>='0'&&c<='9') X=X*10+c-'0',c=getchar();
    return X*w;
}

const int N=1200000+10;
const int mod=998244353;
int qpow(int a,int b) { int c=1;
    for (;b;b>>=1,a=1ll*a*a%mod) if (b&1) c=1ll*c*a%mod;
    return c;
}

int fac[N],ifac[N];
void init(int n) {
    fac[0]=1;
    for (int i=1;i<=n;++i) fac[i]=1ll*fac[i-1]*i%mod;
    ifac[n]=qpow(fac[n],mod-2);
    for (int i=n;i;--i) ifac[i-1]=1ll*ifac[i]*i%mod;
}
int C(int n,int m) {
    if (n<m) return 0;
    return 1ll*fac[n]*ifac[m]%mod*ifac[n-m]%mod;
}
int S(int n,int m) {
    if (n<m) return 0;
    return (C(n+n-m,n-m)-C(n+n-m,n-m-1)+mod)%mod;
}

int n,p[N],pre[N],suf[N],cnt[N];

int c[N];
void add(int x,int y) { for (;x<=n;x+=x&-x) c[x]+=y; }
int sum(int x) { int s=0; for (;x;x-=x&-x) s+=c[x]; return s; }

int main() {
    init(1200000);
    int T=read();
    while (T--) {
        memset(c,0,sizeof(c));
        n=read();
        for (int i=1;i<=n;++i) p[i]=read();
        pre[1]=p[1],suf[n]=p[n];
        for (int i=2;i<=n;++i) pre[i]=max(pre[i-1],p[i]);
        for (int i=n-1;i>0;--i) suf[i]=min(suf[i+1],p[i]);
        for (int i=n;i;--i) add(p[i],1),cnt[i]=sum(pre[i]);
        int ans=0;
        for (int i=1,mx2=0;i<n;++i) {
            if (mx2>suf[i]) break;
            ans=(ans+S(n-i+1,cnt[i]+1))%mod;
            if (p[i]<pre[i]) mx2=max(mx2,p[i]);
        }
        printf("%d\n",ans);
    }
    return 0;
}
最后修改:2020 年 08 月 27 日 11 : 19 AM