LOJ

分析

考虑一个区间怎样才是满足条件的。

设 $A_{l,r}$ 为端点都在 $[l,r]$ 内的 exciting 的路径数,$B_{l,r}$ 为端点都不在 $[l,r]$ 内的 exciting 的路径数。则 $[l,r]$ 满足条件即 $A>B$。

现在我们要比大小,根据文化课那一套理论可以想到作差。设 $C_{l,r}$ 为一个端点在 $[l,r]$ 内、一个端点不在 $[l,r]$ 内的 exciting 的路径数,$cnt_i$ 为以 $i$ 为端点的 exciting 的路径数,那么可以得到以下两个式子
$$
\begin{aligned}2A+C&=\sum_{i\in[l,r]}cnt_i\\2B+C&=\sum_{i\notin[l,r]}cnt_i\end{aligned}
$$
上下相减得到
$$
A-B=\sum_{i\in[l,r]}cnt_i-\sum_{i\notin[l,r]}cnt_i
$$
那么问题变为求出满足 $\sum_{i\in[l,r]}cnt_i-\sum_{i\notin[l,r]}cnt_i>0$ 的区间的数量。从小到大枚举右端点,可以发现最小的合法左端点是单增的,直接用一个指针维护一下即可。

剩下的问题就只有求 $cnt_i$ 了,这个直接点分治就好了。

代码

// ===================================
//   author: M_sea
//   website: http://m-sea-blog.com/
// ===================================
#include <algorithm>
#include <iostream>
#include <cstdlib>
#include <cstring>
#include <cstdio>
#include <vector>
#include <cmath>
#define re register
using namespace std;
typedef long long ll;

inline int read() {
    int X=0,w=1; char c=getchar();
    while (c<'0'||c>'9') { if (c=='-') w=-1; c=getchar(); }
    while (c>='0'&&c<='9') X=X*10+c-'0',c=getchar();
    return X*w;
}

const int N=100000+10;

int n,a[N],cnt[N];

struct edge { int v,nxt; } e[N<<1];
int head[N];
inline void addEdge(int u,int v) {
    static int c=0;
    e[++c]=(edge){v,head[u]},head[u]=c;
}

int rt,sumsz,minsz,sz[N];
int vis[N],sum[N],o[N<<1],sta[N],top=0;
inline void getroot(int u,int fa) {
    sz[u]=1; int res=0;
    for (re int i=head[u];i;i=e[i].nxt) {
        int v=e[i].v; if (v==fa||vis[v]) continue;
        getroot(v,u); sz[u]+=sz[v],res=max(res,sz[v]);
    }
    res=max(res,sumsz-sz[u]);
    if (res<minsz) rt=u,minsz=res;
}
inline void dfs(int u,int fa) {
    sta[++top]=u,++o[sum[u]+N];
    for (re int i=head[u];i;i=e[i].nxt) {
        int v=e[i].v; if (v==fa||vis[v]) continue;
        sum[v]=sum[u]+a[v],dfs(v,u);
    }
}
inline void calc(int u,int sum0,int ta,int w) {
    top=0,sum[u]=sum0+a[u],dfs(u,0);
    for (re int i=1;i<=top;++i)
        cnt[sta[i]]+=w*o[ta-sum[sta[i]]+N];
    for (re int i=1;i<=top;++i) --o[sum[sta[i]]+N];
}
inline void solve(int u) {
    vis[u]=1,calc(u,0,a[u],1);
    for (re int i=head[u];i;i=e[i].nxt) {
        int v=e[i].v; if (vis[v]) continue;
        calc(v,a[u],a[u],-1);
        rt=0,minsz=n,sumsz=sz[u],getroot(v,0); solve(rt);
    }
}

int main() {
    n=read();
    for (re int i=1;i<=n;++i) a[i]=read()*2-1;
    for (re int i=1;i<n;++i) {
        int u=read(),v=read();
        addEdge(u,v),addEdge(v,u);
    }
    rt=0,sumsz=minsz=n,getroot(1,0); solve(rt);
    ll scnt=0,sin=0,ans=0;
    for (re int i=1;i<=n;++i) scnt+=cnt[i];
    for (re int r=1,l=1;r<=n;++r) {
        sin+=cnt[r];
        while (l<r&&sin>scnt-sin) sin-=cnt[l],++l;
        ans+=l-1;
    }
    printf("%lld\n",ans);
    return 0;
}
最后修改:2021 年 03 月 24 日 02 : 59 PM