分析
考虑一个区间怎样才是满足条件的。
设 $A_{l,r}$ 为端点都在 $[l,r]$ 内的 exciting 的路径数,$B_{l,r}$ 为端点都不在 $[l,r]$ 内的 exciting 的路径数。则 $[l,r]$ 满足条件即 $A>B$。
现在我们要比大小,根据文化课那一套理论可以想到作差。设 $C_{l,r}$ 为一个端点在 $[l,r]$ 内、一个端点不在 $[l,r]$ 内的 exciting 的路径数,$cnt_i$ 为以 $i$ 为端点的 exciting 的路径数,那么可以得到以下两个式子
$$
\begin{aligned}2A+C&=\sum_{i\in[l,r]}cnt_i\\2B+C&=\sum_{i\notin[l,r]}cnt_i\end{aligned}
$$
上下相减得到
$$
A-B=\sum_{i\in[l,r]}cnt_i-\sum_{i\notin[l,r]}cnt_i
$$
那么问题变为求出满足 $\sum_{i\in[l,r]}cnt_i-\sum_{i\notin[l,r]}cnt_i>0$ 的区间的数量。从小到大枚举右端点,可以发现最小的合法左端点是单增的,直接用一个指针维护一下即可。
剩下的问题就只有求 $cnt_i$ 了,这个直接点分治就好了。
代码
// ===================================
// author: M_sea
// website: http://m-sea-blog.com/
// ===================================
#include <algorithm>
#include <iostream>
#include <cstdlib>
#include <cstring>
#include <cstdio>
#include <vector>
#include <cmath>
#define re register
using namespace std;
typedef long long ll;
inline int read() {
int X=0,w=1; char c=getchar();
while (c<'0'||c>'9') { if (c=='-') w=-1; c=getchar(); }
while (c>='0'&&c<='9') X=X*10+c-'0',c=getchar();
return X*w;
}
const int N=100000+10;
int n,a[N],cnt[N];
struct edge { int v,nxt; } e[N<<1];
int head[N];
inline void addEdge(int u,int v) {
static int c=0;
e[++c]=(edge){v,head[u]},head[u]=c;
}
int rt,sumsz,minsz,sz[N];
int vis[N],sum[N],o[N<<1],sta[N],top=0;
inline void getroot(int u,int fa) {
sz[u]=1; int res=0;
for (re int i=head[u];i;i=e[i].nxt) {
int v=e[i].v; if (v==fa||vis[v]) continue;
getroot(v,u); sz[u]+=sz[v],res=max(res,sz[v]);
}
res=max(res,sumsz-sz[u]);
if (res<minsz) rt=u,minsz=res;
}
inline void dfs(int u,int fa) {
sta[++top]=u,++o[sum[u]+N];
for (re int i=head[u];i;i=e[i].nxt) {
int v=e[i].v; if (v==fa||vis[v]) continue;
sum[v]=sum[u]+a[v],dfs(v,u);
}
}
inline void calc(int u,int sum0,int ta,int w) {
top=0,sum[u]=sum0+a[u],dfs(u,0);
for (re int i=1;i<=top;++i)
cnt[sta[i]]+=w*o[ta-sum[sta[i]]+N];
for (re int i=1;i<=top;++i) --o[sum[sta[i]]+N];
}
inline void solve(int u) {
vis[u]=1,calc(u,0,a[u],1);
for (re int i=head[u];i;i=e[i].nxt) {
int v=e[i].v; if (vis[v]) continue;
calc(v,a[u],a[u],-1);
rt=0,minsz=n,sumsz=sz[u],getroot(v,0); solve(rt);
}
}
int main() {
n=read();
for (re int i=1;i<=n;++i) a[i]=read()*2-1;
for (re int i=1;i<n;++i) {
int u=read(),v=read();
addEdge(u,v),addEdge(v,u);
}
rt=0,sumsz=minsz=n,getroot(1,0); solve(rt);
ll scnt=0,sin=0,ans=0;
for (re int i=1;i<=n;++i) scnt+=cnt[i];
for (re int r=1,l=1;r<=n;++r) {
sin+=cnt[r];
while (l<r&&sin>scnt-sin) sin-=cnt[l],++l;
ans+=l-1;
}
printf("%lld\n",ans);
return 0;
}