分析
先考虑第一个条件。设 $c_i$ 为经过 $i$ 的 $s\to t$ 的最短路数,那么应该有 $c_a+c_b=c_t$。
再考虑第二个条件。我们可以对每个点求出最短路图上不能到达它的且它不能到达的点集,然后存到一个 bitset
里。这可以直接拓扑排序求出。
枚举 $a$,合法的 $b$ 应该满足 $c_t-c_a=c_b$。开一个 map
存下每个 $c$ 对应的点集,再和之前预处理的东西求交就可以求出合法的 $b$ 的数量了。
这个 $c$ 可能很大,需要哈希一下(直接自然溢出啥事没有)。
代码
// ====================================
// author: M_sea
// website: https://m-sea-blog.com/
// ====================================
#include <bits/stdc++.h>
#define file(x) freopen(#x".in","r",stdin),freopen(#x".out","w",stdout)
#define debug(...) fprintf(stderr,__VA_ARGS__)
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
int read() {
int X=0,w=1; char c=getchar();
while (c<'0'||c>'9') { if (c=='-') w=-1; c=getchar(); }
while (c>='0'&&c<='9') X=X*10+c-'0',c=getchar();
return X*w;
}
const int N=50000+10,M=50000+10;
const ll inf=0x3f3f3f3f3f3f3f3f;
int n,m,S,T;
struct edge { int v,w,nxt; } e[M<<1];
int head[N];
void addEdge(int u,int v,int w) {
static int cnt=0;
e[++cnt]=(edge){v,w,head[u]},head[u]=cnt;
}
struct node { int u; ll d; };
bool operator <(node a,node b) { return a.d>b.d; }
ll disS[N],disT[N];
ull cntS[N],cntT[N],cnt[N];
void dijkstra(int s,ll *dis,ull *cnt) {
memset(dis,0x3f,(n+1)<<3); dis[s]=0,cnt[s]=1;
priority_queue<node> Q; Q.push((node){s,0});
while (!Q.empty()) {
int u=Q.top().u; ll d=Q.top().d; Q.pop();
if (d!=dis[u]) continue;
for (int i=head[u];i;i=e[i].nxt) {
int v=e[i].v,w=e[i].w;
if (dis[u]+w<dis[v])
dis[v]=dis[u]+w,cnt[v]=cnt[u],Q.push((node){v,dis[v]});
else if (dis[u]+w==dis[v]) cnt[v]+=cnt[u];
}
}
}
bitset<N> rS[N],rT[N];
int deg[N];
bool check(int s,int u,int v,int w) {
if (s==S) return disS[u]+disT[v]+w==disS[T];
else return disT[u]+disS[v]+w==disT[S];
}
void topsort(int s,bitset<N> *r) {
for (int i=1;i<=n;++i) r[i].set(),r[i].reset(i);
for (int u=1;u<=n;++u)
for (int i=head[u];i;i=e[i].nxt) {
int v=e[i].v,w=e[i].w;
if (check(s,u,v,w)) ++deg[v];
}
queue<int> Q;
for (int i=1;i<=n;++i) if (!deg[i]) Q.push(i);
while (!Q.empty()) {
int u=Q.front(); Q.pop();
debug("%d\n",u);
for (int i=head[u];i;i=e[i].nxt) {
int v=e[i].v,w=e[i].w;
if (check(s,u,v,w)) {
r[v]&=r[u];
if (!--deg[v]) Q.push(v);
}
}
}
}
map<ull,bitset<N>> H;
int main() {
n=read(),m=read(),S=read(),T=read();
for (int i=1;i<=m;++i) {
int u=read(),v=read(),w=read();
addEdge(u,v,w),addEdge(v,u,w);
}
dijkstra(S,disS,cntS),dijkstra(T,disT,cntT);
for (int i=1;i<=n;++i)
if (disS[i]+disT[i]==disS[T]) cnt[i]=cntS[i]*cntT[i];
topsort(S,rS),topsort(T,rT);
for (int i=1;i<=n;++i) H[cnt[i]].set(i);
ll ans=0;
for (int i=1;i<=n;++i)
ans+=(H[cnt[T]-cnt[i]]&rS[i]&rT[i]).count();
printf("%lld\n",ans>>1);
return 0;
}