Luogu

BZOJ

分析

广义 $\mathrm{SAM}$ 。

首先,树上路径一定是某叶子节点到另一叶子节点的路径的子段。所以

然后这题叶子数很少,直接对每个叶子 $\mathrm{dfs}$ 后建广义 $\mathrm{SAM}$ 即可

如果只有一个串要求不同子串个数的话,那么答案就是 $\sum len[i]-len[fa[i]]$ 。

然后现在可能有多个串,那么把 $\mathrm{SAM}$ 换成 广义 $\mathrm{SAM}$ ,答案还是这个东西。

建广义 $\mathrm{SAM}$ 的话,每个串重新从根节点开始插入就行了。

代码

// =================================
//   author: M_sea
//   website: http://m-sea-blog.com/
// =================================
#include <algorithm>
#include <iostream>
#include <cstdlib>
#include <cstring>
#include <cstdio>
#include <cmath>
#define re register
typedef long long ll;
using namespace std;

inline int read() {
    int X=0,w=1; char c=getchar();
    while (c<'0'||c>'9') { if (c=='-') w=-1; c=getchar(); }
    while (c>='0'&&c<='9') X=X*10+c-'0',c=getchar();
    return X*w;
}

const int N=1000000+10;

struct Suffix_Automaton {
    int last,tot;
    int ch[N<<1][10],fa[N<<1],len[N<<1];

    Suffix_Automaton() { last=tot=1; }
    inline int extend(int p,int c) {
        int np=++tot; last=np,len[np]=len[p]+1;
        for (;p&&!ch[p][c];p=fa[p]) ch[p][c]=np;
        if (!p) fa[np]=1;
        else {
            int q=ch[p][c];
            if (len[p]+1==len[q]) fa[np]=q;
            else {
                int nq=++tot; len[nq]=len[p]+1;
                memcpy(ch[nq],ch[q],sizeof(ch[q]));
                fa[nq]=fa[q],fa[q]=fa[np]=nq;
                for (;ch[p][c]==q;p=fa[p]) ch[p][c]=nq;
            }
        }
        return np;
    }
    inline ll query() {
        ll ans=0;
        for (re int i=1;i<=tot;++i) ans+=len[i]-len[fa[i]];
        return ans;
    }
} S;

int n,c,a[N];
struct Edge { int v,nxt; } e[N<<1];
int head[N],deg[N];

inline void addEdge(int u,int v) {
    static int cnt=0;
    e[++cnt]=(Edge){v,head[u]},head[u]=cnt;
}

inline void dfs(int u,int fa,int p) {
    p=S.extend(p,a[u]);
    for (re int i=head[u];i;i=e[i].nxt)
        if (e[i].v!=fa) dfs(e[i].v,u,p);
}

int main() {
    n=read(),c=read();
    for (re int i=1;i<=n;++i) a[i]=read();
    for (re int i=1;i<n;++i) {
        int u=read(),v=read();
        addEdge(u,v),addEdge(v,u);
        ++deg[u],++deg[v];
    }
    for (re int i=1;i<=n;++i)
        if (deg[i]==1) S.last=1,dfs(i,0,1);
    printf("%lld\n",S.query());
    return 0;
}
最后修改:2019 年 09 月 26 日 01 : 04 PM