分析
求出 $s$ 的所有周期 $l_1, l_2, \cdots, l_k$,那么我们相当于要求有多少个 $x \in [0, w - n]$ 满足存在一组 $\{x_i \geq 0\}$ 使得 $\sum_{i = 1}^k x_i l_i = x$。
发现如果 $x$ 可以,那么 $x + n$ 也一定可以,于是我们可以在模 $n$ 意义下跑同余类最短路(这里最短路的含义是最小的能被表出的模 $n$ 为 $r$ 的数),复杂度 $\mathcal{O}(Tn^2)$。
然后有这样一个性质:
性质 1 一个字符串 $s$ 的所有 border 可以划分为 $\mathcal{O}(\log |s|)$ 个等差数列。
证明见 https://www.luogu.com.cn/blog/sysjuruo/solution-p4156。
构造也很简单,我们只需要按照长度排序,每次取出开头极长的一段等差数列即可。
我们对于每个等差数列单独考虑,假设它们为 $x, x + d, x + 2d, \cdots, x + ld$,那么我们可以在模 $x$ 意义下跑同余类最短路,相当于从 $y$ 向 $(y + d) \bmod x$ 连边权为 $d$ 的边,然后更新最短路。
这些边显然构成若干个环,我们单独考虑每个环。显然环上 $dis$ 最小的点是不会被更新的,因此我们直接从其出发绕环一周更新即可。但是还有 $l$ 的限制,即我们只能从前面的 $l$ 个位置转移,因此需要用一个单调队列维护。
在考虑另外一个等差数列的时候,我们需要把模 $x$ 的最短路变为模 $y$ 的最短路。不妨设原最短路为 $f$,新最短路为 $g$,那么我们首先用 $f_i$ 更新 $g_{f_i \bmod y}$;然后原来还有长度为 $x$ 的边,因此我们还需要连 $k \to (k + x) \bmod y$ 后更新最短路,同样从最小值出发绕环一周更新即可。
复杂度 $\mathcal{O}(Tn\log n)$。
代码
UOJ 上的 Extra Test T 掉了 /dk
// ====================================
// author: M_sea
// website: https://m-sea-blog.com/
// ====================================
#include <bits/stdc++.h>
#define file(x) freopen(#x".in", "r", stdin), freopen(#x".out", "w", stdout)
#define debug(...) fprintf(stderr, __VA_ARGS__)
using namespace std;
typedef long long ll;
ll read() {
ll X = 0, w = 1; char c = getchar();
while (c < '0' || c > '9') { if (c == '-') w = -1; c = getchar(); }
while (c >= '0' && c <= '9') X = X * 10 + c - '0', c = getchar();
return X * w;
}
const int N = 1000000 + 10;
const ll inf = 0x3f3f3f3f3f3f3f3f;
int n; ll W;
char s[N];
int nxt[N], len[N], clen;
void GetNext() {
clen = 0;
for (int i = 2, j = 0; i <= n; ++i) {
while (j && s[j + 1] != s[i]) j = nxt[j];
if (s[j + 1] == s[i]) ++j;
nxt[i] = j;
}
for (int i = nxt[n]; i; i = nxt[i]) len[++clen] = n - i;
len[++clen] = n;
}
int lp;
ll dis[N];
void ChangeMod(int np) {
static ll tmp[N];
static int sta[N], top;
for (int i = 0; i < np; ++i)
tmp[i] = dis[i], dis[i] = inf;
for (int i = 0; i < np; ++i)
dis[tmp[i] % np] = min(dis[tmp[i] % np], tmp[i]);
int d = __gcd(lp, np);
for (int i = 0; i < d; ++i) {
sta[top = 1] = i;
for (int j = (i + lp) % np; j != i; j = (j + lp) % np)
sta[++top] = j;
int mnp = 1;
for (int j = 2; j <= top; ++j)
if (dis[sta[j]] < dis[sta[mnp]]) mnp = j;
rotate(sta + 1, sta + mnp, sta + top + 1);
for (int j = 2; j <= top; ++j)
dis[sta[j]] = min(dis[sta[j]], dis[sta[j - 1]] + lp);
}
lp = np;
}
void Solve(int x, int dlt, int l) {
static int sta[N], top, h, t;
static pair<ll, int> Q[N];
ChangeMod(x);
int d = __gcd(x, dlt);
for (int i = 0; i < d; ++i) {
sta[top = 1] = i;
for (int j = (i + dlt) % x; j != i; j = (j + dlt) % x)
sta[++top] = j;
int mnp = 1;
for (int j = 2; j <= top; ++j)
if (dis[sta[j]] < dis[sta[mnp]]) mnp = j;
rotate(sta + 1, sta + mnp, sta + top + 1);
Q[h = t = 1] = make_pair(dis[sta[1]] - dlt, 1);
for (int j = 2; j <= top; ++j) {
while (h <= t && Q[h].second + l < j) ++h;
if (h <= t)
dis[sta[j]] = min(dis[sta[j]], x + Q[h].first + 1ll * j * dlt);
while (h <= t && Q[t].first >= dis[sta[j]] - 1ll * j * dlt) --t;
Q[++t] = make_pair(dis[sta[j]] - 1ll * j * dlt, j);
}
}
}
int main() {
int T = read();
while (T--) {
n = read(), W = read() - n;
scanf("%s", s + 1);
GetNext();
lp = n, dis[0] = 0;
for (int i = 1; i < n; ++i) dis[i] = inf;
for (int i = 1, j = 1; i < clen; i = j) {
while (j < clen && len[j + 1] - len[j] == len[i + 1] - len[i]) ++j;
Solve(len[i], len[i + 1] - len[i], j - i - 1);
}
ll ans = 0;
for (int i = 0; i < lp; ++i)
if (dis[i] <= W) ans += (W - dis[i]) / lp + 1;
printf("%lld\n", ans);
}
return 0;
}