分析
考虑点分治,那么需要考虑如何拼合两条路径。
设根到 $x$ 的路径上的数是 $d_1$,根到 $y$ 的路径上的数是 $d_2$,根到 $y$ 的距离为 $dep$。那么若 $x\to y$ 这条路径合法,则需要满足:
$$
d_1\cdot10^{dep}+d_2\equiv0\pmod m
$$
因为 $\gcd(10,m)=1$,所以可以把两边同时除以 $10^{dep}$,得到:
$$
d_1+d_2\cdot10^{-dep}\equiv0\pmod m
$$
这样子两条路径就独立了,分别存下来后枚举 $y$ 统计即可。
代码
//It is made by M_sea
#include <algorithm>
#include <iostream>
#include <cstdlib>
#include <cstring>
#include <cstdio>
#include <cmath>
#include <map>
#define re register
#define mp make_pair
typedef int mainint;
#define int long long
using namespace std;
inline int read() {
int X=0,w=1; char c=getchar();
while (c<'0'||c>'9') { if (c=='-') w=-1; c=getchar(); }
while (c>='0'&&c<='9') X=X*10+c-'0',c=getchar();
return X*w;
}
const int MAXN=100000+10;
int n,m;
struct Edge { int v,w,nxt; };
Edge e[MAXN<<1];
int head[MAXN],cnt=0;
inline void addEdge(int u,int v,int w) {
e[++cnt].v=v,e[cnt].w=w,e[cnt].nxt=head[u],head[u]=cnt;
}
int vis[MAXN];
int p[MAXN];
int sz[MAXN],f[MAXN],dep[MAXN];
int root,sum,ans,num;
map<int,int> s; // sum[d1] 为 d1 的数量
pair<int,int> dig[MAXN<<1]; // <d2, dep>
inline void getroot(int u,int fa) {
sz[u]=1,f[u]=0;
for (re int i=head[u];i;i=e[i].nxt) {
int v=e[i].v;
if (vis[v]||v==fa) continue;
getroot(v,u); sz[u]+=sz[v];
f[u]=max(f[u],sz[v]);
}
f[u]=max(f[u],sum-f[u]);
if (f[u]<f[root]) root=u;
}
inline void getdigit(int u,int fa,int d1,int d2,int d) {
if (d>=0) ++s[d1],dig[++num]=mp(d2,d);
for (re int i=head[u];i;i=e[i].nxt) {
int v=e[i].v,w=e[i].w;
if (v==fa||vis[v]) continue;
int d3=(d1+w*p[d+1])%m;
int d4=(d2*10+w)%m;
getdigit(v,u,d3,d4,d+1);
}
}
inline int exgcd(int a,int b,int& x,int& y) {
if (!b) { x=1,y=0; return a; }
int d=exgcd(b,a%b,x,y);
int z=x; x=y; y=z-a/b*y;
return d;
}
inline int inv(int a,int m) {
int x,y,d=exgcd(a,m,x,y);
return d==1?(x%m+m)%m:-1;
}
inline int calc(int u,int d) {
s.clear(); int rt=0; num=0;
if (d) getdigit(u,0,d%m,d%m,0);
else getdigit(u,0,0,0,-1);
for (re int i=1;i<=num;++i) {
int tmp=(-dig[i].first * inv(p[dig[i].second+1],m) % m + m) % m;
if (s.find(tmp)!=s.end()) rt+=s[tmp];
if (!d) rt+=!dig[i].first;
}
if (!d) rt+=s[0]; // 0 也要算进去
return rt;
}
inline void solve(int u) {
ans+=calc(u,0); vis[u]=1;
for (re int i=head[u];i;i=e[i].nxt) {
int v=e[i].v,w=e[i].w;
if (vis[v]) continue;
ans-=calc(v,w);
sum=sz[v],f[0]=n,root=0;
getroot(v,u);
solve(root);
}
}
mainint main() {
n=read(),m=read();
for (re int i=1,u,v,w;i<n;++i) {
u=read()+1,v=read()+1,w=read();
addEdge(u,v,w);
addEdge(v,u,w);
}
p[0]=1; for (re int i=1;i<=n;++i) p[i]=p[i-1]*10%m;
f[0]=sum=n; getroot(1,0); solve(root);
printf("%lld\n",ans);
return 0;
}
3 条评论
请问为什么要判0呢,题目不是说是非零嘛。还有if (!d) rt+=!dig[i].first;这条语句的作用是什么啊。谢谢!
其实判的不是 $0$ 而是 $m$ 的倍数,因为
getdigit
里面的d1
是取了模的。后面那个是统计单链的情况,下面那句也是统计单链。抱歉过了这么久才回复 ::QQ:Y.qq70::
好的蟹蟹啦!