洛谷3233 [HNOI2014]世界树

Luogu

BZOJ

分析

有个m[1]+m[2]+...+m[q]<=300000的条件,显然是虚树。

问题在于怎么DP。

首先来两次Dfs,求出虚树上每个点被哪个点控制。

然后对于虚树上每一条边$(u,v)$,

  • 如果$u$和$v$被同一个点控制,直接把$u$,$v$所属的点的贡献加上这两个点不在虚树中的儿子的大小。
  • 否则,倍增求出一个分界点使得上面的点被$u$控制,下面的点被$v$控制,然后分别贡献。

新的问题在于怎么求一个点不在虚树中的儿子的大小。这里我真的不想写了qwq,请看这位大佬的题解

还有,memset极其慢,可以用别的方法清空,详见代码。

代码

//It is made by M_sea
#include <algorithm>
#include <iostream>
#include <cstdlib>
#include <cstring>
#include <cstdio>
#include <cmath>
#define re register
using namespace std;

inline int read() {
    int X=0,w=1; char c=getchar();
    while (c<'0'||c>'9') { if (c=='-') w=-1; c=getchar(); }
    while (c>='0'&&c<='9') X=X*10+c-'0',c=getchar();
    return X*w;
}

const int MAXN=300000+10;
const int INF=1e9;

struct Edge { int v,nxt; };
Edge e[MAXN<<1];
int head[MAXN],cnt=0;

inline void addEdge(int u,int v) {
    e[++cnt].v=v,e[cnt].nxt=head[u],head[u]=cnt;
}

int n,q;
int a[MAXN],b[MAXN],fg[MAXN];
int f[20][MAXN],dep[MAXN],dfn[MAXN],sz[MAXN];
int lg[MAXN];
int belong[MAXN],sur[MAXN];
int sta[MAXN],top=1;
int ans[MAXN];
int dfs_clock=0;

inline int cmp(int a,int b) { return dfn[a]<dfn[b]; }

inline void dfs(int u,int fa) {
    dep[u]=dep[fa]+1,f[0][u]=fa,dfn[u]=++dfs_clock,sz[u]=1;
    for (re int i=1;(1<<i)<=dep[u];++i) f[i][u]=f[i-1][f[i-1][u]];
    for (re int i=head[u];i;i=e[i].nxt) {
        int v=e[i].v; if (v==fa) continue;
        dfs(v,u); sz[u]+=sz[v];
    }
}

inline int LCA(int a,int b) {
    if (dep[a]<dep[b]) swap(a,b);
    for (re int i=19;~i;--i)
        if (dep[f[i][a]]>dep[b]) a=f[i][a];
    if (dep[a]!=dep[b]) a=f[0][a];
    for (re int i=19;~i;--i)
        if (f[i][a]!=f[i][b]) a=f[i][a],b=f[i][b];
    if (a!=b) a=f[0][a],b=f[0][b];
    return a;
}

inline int dis(int a,int b) { return dep[a]+dep[b]-(dep[LCA(a,b)]<<1); }

inline void clearGraph() {
    memset(head,0,sizeof(head));
    cnt=0;
}

inline void dfs1(int u) {
    belong[u]=fg[u]?u:0; sur[u]=sz[u];
    for (re int i=head[u];i;i=e[i].nxt) {
        int v=e[i].v; dfs1(v);
        int d1=dep[belong[v]]-dep[u];
        int d2=belong[u]?dep[belong[u]]-dep[u]:INF;
        if (d1<d2||(d1==d2&&belong[v]<belong[u])) belong[u]=belong[v];
    }
}

inline void dfs2(int u) {
    for (re int i=head[u];i;i=e[i].nxt) {
        int v=e[i].v;
        int d1=dis(belong[u],v),d2=dis(belong[v],v);
        if (d1<d2||(d1==d2&&belong[u]<belong[v])) belong[v]=belong[u];
        dfs2(v);
    }
}

inline void dp(int u) {
    for (re int i=head[u];i;i=e[i].nxt) {
        int v=e[i].v; dp(v);
        int s=v,mid=v;
        for (re int i=lg[dep[v]];i>=0;--i)
            if (dep[f[i][s]]>dep[u]) s=f[i][s];
        sur[u]-=sz[s];
        if (belong[u]==belong[v]) ans[belong[u]]+=sz[s]-sz[v];
        else { /* 倍增求u-v链上的分界点 */ 
            for (re int i=lg[dep[v]];i>=0;--i) {
                int now=f[i][mid];
                if (dep[now]<=dep[u]) continue;
                int d1=dis(now,belong[v]),d2=dis(now,belong[u]);
                if (d1<d2||(d1==d2&&belong[v]<belong[u])) mid=now;
            }
            ans[belong[u]]+=sz[s]-sz[mid],ans[belong[v]]+=sz[mid]-sz[v];
        }
    }
    ans[belong[u]]+=sur[u]; head[u]=0; /* memset太慢了,在这里顺便清掉 */
}

int main() {
    n=read();
    for (re int i=2;i<=n;++i) lg[i]=lg[i>>1]+1;
    for (re int i=1,u,v;i<n;++i) {
        u=read(),v=read();
        addEdge(u,v),addEdge(v,u);
    }
    dfs(1,0);
    memset(head,0,sizeof(head));
    
    q=read();
    while (q--) {
        int k=read();
        for (re int i=1;i<=k;++i) fg[b[i]=a[i]=read()]=1;
        sort(a+1,a+k+1,cmp); sta[top=1]=1; cnt=0;
        for (re int i=1;i<=k;++i) {
            int x=a[i],p=LCA(sta[top],x);
            while (dep[p]<dep[sta[top]]) {
                if (dep[p]>=dep[sta[top-1]]) {
                    addEdge(p,sta[top--]);
                    if (sta[top]!=p) sta[++top]=p;
                    break;
                }
                addEdge(sta[top-1],sta[top]); --top;
            }
            if (sta[top]!=x) sta[++top]=x;
        }
        while (top>1) { addEdge(sta[top-1],sta[top]); --top; }
        dfs1(1),dfs2(1),dp(1);
        
        for (re int i=1;i<=k;++i) {
            printf("%d ",ans[b[i]]);
            /* memset太慢了... */
            fg[b[i]]=ans[b[i]]=0;
        }
        putchar('\n');
    }
    return 0;
}
最后修改:2019 年 09 月 24 日 08 : 52 PM

4 条评论

  1. wsm000

    memset的复杂度不是o(n)的么?不是应该已经到了理论下界了吗?

    1. M_sea
      @wsm000

      但是你每次 memset(ans+1,ans+n+1,0) 这样子每次询问都是 $\mathcal{O}(n)$ 的,然后 $Q$ 次询问复杂度就变成 $\mathcal{O}(Qn)$

  2. smy

    偷懒差评

    1. M_sea
      @smy

      真的不想写啊qwq

发表评论