Luogu

BZOJ

分析

下文中默认字符串的下标从 $0$ 开始。

可以发现答案等于「位置和字符都关于某条对称轴对称的子序列」的数量减去「回文子串」的数量。

后面的就是 manacher 板子,考虑怎么求前面的。

设 $c_i$ 表示「位置和字符都关于对称轴 $i$ 对称的子序列」的数量,则有
$$
c_i=\sum_{j=0}^i[s_j=s_{2\times i-j}]
$$
把它拆一下得到
$$
c_i=\left(\sum_{j=0}^i[s_j=a][s_{2\times i-j}=a]\right)+\left(\sum_{j=0}^i[s_j=b][s_{2\times i-j}=b]\right)
$$
显然是卷积的形式。因此构造多项式
$$
F_i=[s_i=a],G_i=[s_i=b]
$$
后就可以用 FFT 求出 $c_i$ 了。每条对称轴 $i$ 的贡献则为 $2^{c_i}-1$,累加起来即可。

注意到 FFT 的运算过程中数不会太大,因此可以写 NTT,不会出锅。

代码

// ===================================
//   author: M_sea
//   website: http://m-sea-blog.com/
// ===================================
#include <algorithm>
#include <iostream>
#include <cstdlib>
#include <cstring>
#include <cstdio>
#include <vector>
#include <cmath>
#define re register
using namespace std;

inline int read() {
    int X=0,w=1; char c=getchar();
    while (c<'0'||c>'9') { if (c=='-') w=-1; c=getchar(); }
    while (c>='0'&&c<='9') X=X*10+c-'0',c=getchar();
    return X*w;
}

const int N=300000+10;
const int Nmod=998244353,mod=1000000007;
inline int qpow(int a,int b,int m) { int c=1;
    for (;b;b>>=1,a=1ll*a*a%m) if (b&1) c=1ll*c*a%m;
    return c;
}

int n; char s[N];
int F[N],G[N],H[N];

int r[N];
inline void NTT(int* A,int n,int op) {
    for (re int i=0;i<n;++i) if (i<r[i]) swap(A[i],A[r[i]]);
    for (re int i=1;i<n;i<<=1) {
        int rot=qpow(op==1?3:332748118,(Nmod-1)/(i<<1),Nmod);
        for (re int j=0;j<n;j+=i<<1)
            for (re int k=0,w=1;k<i;++k,w=1ll*w*rot%Nmod) {
                int x=A[j+k],y=1ll*w*A[j+k+i]%Nmod;
                A[j+k]=(x+y)%Nmod,A[j+k+i]=(x-y+Nmod)%Nmod;
            }
    }
    if (op==-1) { int inv=qpow(n,Nmod-2,Nmod);
        for (re int i=0;i<n;++i) A[i]=1ll*A[i]*inv%Nmod;
    }
}

int f[N]; char a[N];
inline int manacher() {
    a[0]='?',a[1]='#';
    for (re int i=1;i<=n;++i) a[i<<1]=s[i-1],a[i<<1|1]='#';
    n=n<<1|1,a[n+1]='!';
    for (re int i=1,mr=1,mid;i<=n;++i) {
        f[i]=i<mr?min(f[(mid<<1)-i],f[mid]+mid-i):1;
        while (a[i-f[i]]==a[i+f[i]]) ++f[i];
        if (i+f[i]>mr) mr=i+f[i],mid=i;
    }
    int res=0;
    for (re int i=1;i<=n;++i) res=(res+(f[i]>>1))%mod;
    return res;
}

int main() {
    scanf("%s",s); n=strlen(s);
    int lim=1,l=0;
    for (;lim<=n+n;lim<<=1,++l);
    for (re int i=0;i<lim;++i) r[i]=(r[i>>1]>>1)|((i&1)<<(l-1));
    for (re int i=0;i<n;++i) F[i]=(s[i]=='a'),G[i]=(s[i]=='b');
    NTT(F,lim,1),NTT(G,lim,1);
    for (re int i=0;i<lim;++i)
        F[i]=1ll*F[i]*F[i]%Nmod,G[i]=1ll*G[i]*G[i]%Nmod;
    NTT(F,lim,-1),NTT(G,lim,-1);
    for (re int i=0;i<lim;++i) H[i]=(F[i]+G[i])%mod;
    for (re int i=0;i<lim;++i) H[i]=(H[i]+1)>>1;
    int ans=0;
    for (re int i=0;i<lim;++i) ans=(ans+qpow(2,H[i],mod)-1)%mod;
    printf("%d\n",(ans-manacher()+mod)%mod);
    return 0;
}
最后修改:2021 年 03 月 24 日 02 : 56 PM