51nod

分析

考虑根号分治。令 $B=\left\lceil\sqrt{n}\right\rceil$,则 $[1,B]$ 内的物品能取完,$[B+1,n]$ 内的物品无法取完。

先考虑 $[1,B]$ 内物品。设 $f_{i,j}$ 表示前 $i$ 种物品和为 $j$ 的方案数,容易写出转移
$$
f_{i,j}=\sum_{k=0}^if_{i-1,j-ik}
$$
注意到 $(i,j)$ 的转移点 $(i,k)$ 满足 $k\leq j\land j\equiv k\pmod i$,因此可以开个桶记一下所有满足条件的 $k$ 的 $dp_{i,k}$ 之和。就可以做到 $\mathcal{O}(1)$ 转移了。

这一部分中,物品数不超过 $B$,和不超过 $n$,因此时间复杂度为 $\mathcal{O}(n\sqrt{n})$。

再考虑 $[B+1,n]$ 中物品。直接完全背包复杂度过高,需要考虑一些别的做法。

考虑我们要求的东西的本质是 $n$ 拆成一些 $[B+1,n]$ 中的数的方案数,因此可以考虑整数拆分模型。

设 $g_{i,j}$ 表示选了 $i$ 个数和为 $j$ 的方案数,转移有两种:

  • 加一个 $B+1$ 进来:$dp_{i+1,j+B+1}\leftarrow dp_{i,j}$。
  • 选出的所有数加上 $1$:$dp_{i,j+i}\leftarrow dp_{i,j}$。

这一部分中,拆出的数的个数不超过 $\frac{n}{B}$,数的个数不超过 $n-B$,因此时间复杂度为 $\mathcal{O}(n\sqrt{n})$。

最后只需要将两种情况合并即可。总时间复杂度 $\mathcal{O}(n\sqrt{n})$。

注意 $f$ 和 $g$ 需要滚动掉一个,不然空间开不下。

代码

// ===================================
//   author: M_sea
//   website: http://m-sea-blog.com/
// ===================================
#include <algorithm>
#include <iostream>
#include <cstdlib>
#include <cstring>
#include <cstdio>
#include <vector>
#include <cmath>
#define re register
using namespace std;

inline int read() {
    int X=0,w=1; char c=getchar();
    while (c<'0'||c>'9') { if (c=='-') w=-1; c=getchar(); }
    while (c>='0'&&c<='9') X=X*10+c-'0',c=getchar();
    return X*w;
}

const int mod=23333333;
const int S=317+10,N=100000+10;

int n,B;
int f[2][N],sum[N],g[S][N],sf[N],sg[N];

int main() {
    n=read(),B=ceil(sqrt(n))+0.5;
    f[0][0]=1;
    for (re int i=1;i<=B;++i) {
        int now=i&1,pre=now^1;
        memset(sum,0,sizeof(sum));
        for (re int j=0;j<=n;++j) {
            sum[j%i]=(sum[j%i]+f[pre][j])%mod;
            f[now][j]=sum[j%i];
            if (j>=i*i)
                sum[j%i]=(sum[j%i]+mod-f[pre][j-i*i])%mod;
        }
    }
    for (re int i=0;i<=n;++i) sf[i]=f[B&1][i];
    g[0][0]=1;
    for (re int i=0;i<=B;++i)
        for (re int j=0;j<=n;++j) {
            if (i&&i+j<=n) g[i][i+j]=(g[i][i+j]+g[i][j])%mod;
            if (j+B+1<=n) g[i+1][j+B+1]=(g[i+1][j+B+1]+g[i][j])%mod;
        }
    for (re int i=0;i<=B;++i)
        for (re int j=0;j<=n;++j)
            sg[j]=(sg[j]+g[i][j])%mod;
    int ans=0;
    for (re int i=0;i<=n;++i) ans=(ans+1ll*sf[i]*sg[n-i])%mod;
    printf("%d\n",ans);
    return 0;
}
最后修改:2021 年 03 月 24 日 02 : 52 PM