Luogu

LOJ

分析

首先把第 $k$ 大转化掉:
$$
\begin{aligned}
&\sum_{S\subseteq U}S\text{ 的第 }k\text{ 大点权}\\
=&\sum_{i=1}^W\sum_{S\subseteq U}[S\text{ 的第 }k\text{ 大点权 }\geq i]\\
=&\sum_{i=1}^W\sum_{S\subseteq U}[(S\text{ 中点权 }\geq i\text{ 的点数 })\geq i]\\
\end{aligned}
$$
考虑一个 DP。设 $f_{i,j,k}$ 表示以 $i$ 为根的子树、权值 $\geq j$ 的点共有 $k$ 个的方案数。答案即为 $\sum_{i=1}^n\sum_{j=1}^W\sum_{p=k}^n f_{i,j,p}$。

转移是树形背包:
$$
f_{i,j,k}=\sum_{[d_i\geq j]+\sum_va_v=k}\prod_v\left(f_{v,j,a_v}+[a_v=0]\right)
$$
直接做是 $\mathcal{O}(n^2W)$ 的,大力卡可以卡过去。

既然是背包,考虑写成生成函数的形式。设 $F_{i,j}(x)=\sum_k f_{i,j,k}x^k$,转移变为
$$
F_{i,j}=x^{[d_i\geq j]}\prod_v\left(F_{v,j}+1\right)
$$
不妨再设一个 $g_{i,j,k}$ 表示子树 $f_{i,j,k}$ 的和,并设 $G_{i,j}(x)=\sum_k g_{i,j,k}x^k$。答案即为 $\sum_{i=k}^n[x^i](\sum_j G_{1,j})$。

看起来这个转移并没有什么可以优化的地方。但是 $F_{i,j}$ 和 $G_{i,j}$ 的最高次项都是 $x^n$,所以我们可以想办法求出 $n+1$ 个点值,然后再拉格朗日插值求出系数。

点值有一个好处,就是原来的卷积变成了对应位置相乘。那么我们相当于要支持:

  • 把所有 $F_{i,j}(x)$ 加上 $1$;
  • 把所有 $F_{i,j}(x)$ 乘上一个数;
  • 把所有 $G_{i,j}(x)$ 加上 $F_{i,j}(x)$。

使用整体 DP 的思想,对每个节点开一棵以 $j$ 为下标的线段树,维护 $F_{i,j}(x)$ 和 $G_{i,j}(x)$ 的值。具体的,我们构造一个变换 $(a,b,c,d)$,表示将 $(f,g)$ 变为 $(af+b,cf+d+g)$。手推可以知道两个变换 $(a_1,b_1,c_1,d_1)$ 和 $(a_2,b_2,c_2,d_2)$ 合并后得到 $(a_1a_2,a_2b_1+b_2,c_1+c_2a_1,c_2b_1+d_1+d_2)$,且这个合并是满足结合律的。这样子就可以求出点值,进一步得到答案了。

代码

// ====================================
//   author: M_sea
//   website: https://m-sea-blog.com/
// ====================================
#include <bits/stdc++.h>
#define file(x) freopen(#x".in","r",stdin),freopen(#x".out","w",stdout)
#define debug(...) fprintf(stderr,__VA_ARGS__)
using namespace std;
typedef long long ll;

int read() {
    int X=0,w=1; char c=getchar();
    while (c<'0'||c>'9') { if (c=='-') w=-1; c=getchar(); }
    while (c>='0'&&c<='9') X=X*10+c-'0',c=getchar();
    return X*w;
}

const int N=1666+10;
const int mod=64123;
int qpow(int a,int b) { int c=1;
    for (;b;b>>=1,a=1ll*a*a%mod) if (b&1) c=1ll*c*a%mod;
    return c;
}

int n,k,W,d[N];
vector<int> E[N];

struct alice {
    int a,b,c,d;
    alice(int a_=1,int b_=0,int c_=0,int d_=0): a(a_),b(b_),c(c_),d(d_) {}
};
alice operator +(alice x,alice y) {
    return alice(1ll*x.a*y.a%mod,(1ll*y.a*x.b+y.b)%mod,
                 (x.c+1ll*y.c*x.a)%mod,(1ll*y.c*x.b+x.d+y.d)%mod);
}
bool operator !=(alice x,alice y) {
    return x.a!=y.a||x.b!=y.b||x.c!=y.c||x.d!=y.d;
}

#define ls(o) t[o].lc
#define rs(o) t[o].rc
struct node {
    int lc,rc; alice w;
    node(int lc_=0,int rc_=0,alice w_=alice()): lc(lc_),rc(rc_),w(w_) {}
} t[N*30];
int rt[N],tot=0;

int newnode() { t[++tot]=node(); return tot; }

void pushdown(int o) {
    if (!ls(o)) ls(o)=newnode();
    if (!rs(o)) rs(o)=newnode();
    if (t[o].w!=alice()) {
        t[ls(o)].w=t[ls(o)].w+t[o].w;
        t[rs(o)].w=t[rs(o)].w+t[o].w;
        t[o].w=alice();
    }
}

void modify(int &o,int l,int r,int ql,int qr,alice w) {
    if (!o) o=newnode();
    if (ql<=l&&r<=qr) { t[o].w=t[o].w+w; return; }
    int mid=(l+r)>>1; pushdown(o);
    if (ql<=mid) modify(ls(o),l,mid,ql,qr,w);
    if (qr>mid) modify(rs(o),mid+1,r,ql,qr,w);
}
int query(int o,int l,int r,int ql,int qr) {
    if (!o) return 0;
    if (l==r) return t[o].w.d;
    int mid=(l+r)>>1,res=0; pushdown(o);
    if (ql<=mid) res=query(ls(o),l,mid,ql,qr);
    if (qr>mid) res=(res+query(rs(o),mid+1,r,ql,qr))%mod;
    return res;
}

int merge(int x,int y) {
    if (!x||!y) return x+y;
    if (!ls(x)&&!rs(x)) {
        t[y].w=t[y].w+alice(t[x].w.b,0,0,0);
        t[y].w=t[y].w+alice(1,0,0,t[x].w.d);
        return y;
    }
    if (!ls(y)&&!rs(y)) {
        t[x].w=t[x].w+alice(t[y].w.b,0,0,0);
        t[x].w=t[x].w+alice(1,0,0,t[y].w.d);
        return x;
    }
    pushdown(x),pushdown(y);
    ls(x)=merge(ls(x),ls(y)),rs(x)=merge(rs(x),rs(y));
    return x;
}
#undef ls
#undef rs

void dfs(int u,int fa,int x) {
    modify(rt[u]=0,1,W,1,W,alice(0,1,0,0));
    for (int v:E[u]) {
        if (v==fa) continue;
        dfs(v,u,x),rt[u]=merge(rt[u],rt[v]);
    }
    modify(rt[u],1,W,1,d[u],alice(x,0,0,0));
    modify(rt[u],1,W,1,W,alice(1,0,1,0));
    modify(rt[u],1,W,1,W,alice(1,1,0,0));
}

int ans[N],f[N],g[N],h[N],inv[N];
void Lagrange() {
    for (int i=1;i<=n+1;++i) inv[i]=qpow(i,mod-2);
    f[0]=1;
    for (int i=1;i<=n+1;++i) {
        for (int j=n+1;j;--j)
            f[j]=(1ll*f[j]*(mod-i)+f[j-1])%mod;
        f[0]=1ll*f[0]*(mod-i)%mod;
    }
    for (int i=1;i<=n+1;++i) {
        memcpy(h,f,sizeof(h));
        for (int j=0;j<=n;++j)
            h[j]=1ll*(mod-inv[i])*h[j]%mod,h[j+1]=(h[j+1]-h[j]+mod)%mod;
        int c=1;
        for (int j=1;j<=n+1;++j) {
            if (i==j) continue;
            if (j<i) c=1ll*c*inv[i-j]%mod;
            else c=1ll*c*(mod-inv[j-i])%mod;
        }
        c=1ll*c*ans[i]%mod;
        for (int j=0;j<=n+1;++j) g[j]=(g[j]+1ll*h[j]*c)%mod;
    }
}

int main() {
    n=read(),k=read(),W=read();
    for (int i=1;i<=n;++i) d[i]=read();
    for (int i=1;i<n;++i) {
        int u=read(),v=read();
        E[u].emplace_back(v),E[v].emplace_back(u);
    }
    for (int i=1;i<=n+1;++i) {
        tot=0,dfs(1,0,i);
        ans[i]=query(rt[1],1,W,1,W);
    }
    Lagrange(); int ans=0;
    for (int i=k;i<=n;++i) ans=(ans+g[i])%mod;
    printf("%d\n",ans);
    return 0;
}
最后修改:2021 年 01 月 05 日 09 : 59 PM