Luogu

UOJ

分析

求出 $s$ 的所有周期 $l_1, l_2, \cdots, l_k$,那么我们相当于要求有多少个 $x \in [0, w - n]$ 满足存在一组 $\{x_i \geq 0\}$ 使得 $\sum_{i = 1}^k x_i l_i = x$。

发现如果 $x$ 可以,那么 $x + n$ 也一定可以,于是我们可以在模 $n$ 意义下跑同余类最短路(这里最短路的含义是最小的能被表出的模 $n$ 为 $r$ 的数),复杂度 $\mathcal{O}(Tn^2)$。

然后有这样一个性质:

性质 1 一个字符串 $s$ 的所有 border 可以划分为 $\mathcal{O}(\log |s|)$ 个等差数列。

证明见 https://www.luogu.com.cn/blog/sysjuruo/solution-p4156

构造也很简单,我们只需要按照长度排序,每次取出开头极长的一段等差数列即可。

我们对于每个等差数列单独考虑,假设它们为 $x, x + d, x + 2d, \cdots, x + ld$,那么我们可以在模 $x$ 意义下跑同余类最短路,相当于从 $y$ 向 $(y + d) \bmod x$ 连边权为 $d$ 的边,然后更新最短路。

这些边显然构成若干个环,我们单独考虑每个环。显然环上 $dis$ 最小的点是不会被更新的,因此我们直接从其出发绕环一周更新即可。但是还有 $l$ 的限制,即我们只能从前面的 $l$ 个位置转移,因此需要用一个单调队列维护。

在考虑另外一个等差数列的时候,我们需要把模 $x$ 的最短路变为模 $y$ 的最短路。不妨设原最短路为 $f$,新最短路为 $g$,那么我们首先用 $f_i$ 更新 $g_{f_i \bmod y}$;然后原来还有长度为 $x$ 的边,因此我们还需要连 $k \to (k + x) \bmod y$ 后更新最短路,同样从最小值出发绕环一周更新即可。

复杂度 $\mathcal{O}(Tn\log n)$。

代码

UOJ 上的 Extra Test T 掉了 /dk

// ====================================
//   author: M_sea
//   website: https://m-sea-blog.com/
// ====================================
#include <bits/stdc++.h>
#define file(x) freopen(#x".in", "r", stdin), freopen(#x".out", "w", stdout)
#define debug(...) fprintf(stderr, __VA_ARGS__)
using namespace std;
typedef long long ll;

ll read() {
    ll X = 0, w = 1; char c = getchar();
    while (c < '0' || c > '9') { if (c == '-') w = -1; c = getchar(); }
    while (c >= '0' && c <= '9') X = X * 10 + c - '0', c = getchar();
    return X * w;
}

const int N = 1000000 + 10;
const ll inf = 0x3f3f3f3f3f3f3f3f;

int n; ll W;
char s[N];

int nxt[N], len[N], clen;
void GetNext() {
    clen = 0;
    for (int i = 2, j = 0; i <= n; ++i) {
        while (j && s[j + 1] != s[i]) j = nxt[j];
        if (s[j + 1] == s[i]) ++j;
        nxt[i] = j;
    }
    for (int i = nxt[n]; i; i = nxt[i]) len[++clen] = n - i;
    len[++clen] = n;
}

int lp;
ll dis[N];
void ChangeMod(int np) {
    static ll tmp[N];
    static int sta[N], top;
    for (int i = 0; i < np; ++i)
        tmp[i] = dis[i], dis[i] = inf;
    for (int i = 0; i < np; ++i)
        dis[tmp[i] % np] = min(dis[tmp[i] % np], tmp[i]);
    int d = __gcd(lp, np);
    for (int i = 0; i < d; ++i) {
        sta[top = 1] = i;
        for (int j = (i + lp) % np; j != i; j = (j + lp) % np)
            sta[++top] = j;
        int mnp = 1;
        for (int j = 2; j <= top; ++j)
            if (dis[sta[j]] < dis[sta[mnp]]) mnp = j;
        rotate(sta + 1, sta + mnp, sta + top + 1);
        for (int j = 2; j <= top; ++j)
            dis[sta[j]] = min(dis[sta[j]], dis[sta[j - 1]] + lp);
    }
    lp = np;
}
void Solve(int x, int dlt, int l) {
    static int sta[N], top, h, t;
    static pair<ll, int> Q[N];
    ChangeMod(x);
    int d = __gcd(x, dlt);
    for (int i = 0; i < d; ++i) {
        sta[top = 1] = i;
        for (int j = (i + dlt) % x; j != i; j = (j + dlt) % x)
            sta[++top] = j;
        int mnp = 1;
        for (int j = 2; j <= top; ++j)
            if (dis[sta[j]] < dis[sta[mnp]]) mnp = j;
        rotate(sta + 1, sta + mnp, sta + top + 1);
        Q[h = t = 1] = make_pair(dis[sta[1]] - dlt, 1);
        for (int j = 2; j <= top; ++j) {
            while (h <= t && Q[h].second + l < j) ++h;
            if (h <= t)
                dis[sta[j]] = min(dis[sta[j]], x + Q[h].first + 1ll * j * dlt);
            while (h <= t && Q[t].first >= dis[sta[j]] - 1ll * j * dlt) --t;
            Q[++t] = make_pair(dis[sta[j]] - 1ll * j * dlt, j);
        }
    }
}

int main() {
    int T = read();
    while (T--) {
        n = read(), W = read() - n;
        scanf("%s", s + 1);
        GetNext();
        lp = n, dis[0] = 0;
        for (int i = 1; i < n; ++i) dis[i] = inf;
        for (int i = 1, j = 1; i < clen; i = j) {
            while (j < clen && len[j + 1] - len[j] == len[i + 1] - len[i]) ++j;
            Solve(len[i], len[i + 1] - len[i], j - i - 1);
        }
        ll ans = 0;
        for (int i = 0; i < lp; ++i)
            if (dis[i] <= W) ans += (W - dis[i]) / lp + 1;
        printf("%lld\n", ans);
    }
    return 0;
}
最后修改:2021 年 04 月 07 日 11 : 23 AM