分析
问题相当于选出恰好 $k+1$ 条点不相交的路径使得权值和最大,这里 $1$ 个点也算一条路径。
先考虑一个朴素 DP。设 $dp_{i,j,0/1/2}$ 表示以 $i$ 为根的子树、选了 $j$ 条路径、$i$ 的度数为 $0/1/2$ 的最大权值和,转移讨论各种情况把子树合并进来即可。
设 $f(k)$ 为选恰好 $k$ 条点不交路径时的最大权值和,通过打表或者感性理解可以知道 $f(k)$ 是凸的,于是可以考虑 WQS 二分。
具体的,二分斜率 $m$,则我们需要求 $b=f(x)-mx$ 的最大值。
这个东西相当于每条链有 $-m$ 的额外价值,仍然可以 DP。设 $dp_{i,0/1/2}$ 表示以 $i$ 为根的子树、$i$ 的度数为 $0/1/2$ 的最大权值和以及此时最小选的链数,转移同样讨论各种情况把子树合并进来即可。
代码
// ====================================
// author: M_sea
// website: https://m-sea-blog.com/
// ====================================
#include <bits/stdc++.h>
#define file(x) freopen(#x".in","r",stdin),freopen(#x".out","w",stdout)
#define debug(...) fprintf(stderr,__VA_ARGS__)
using namespace std;
typedef long long ll;
int read() {
int X=0,w=1; char c=getchar();
while (c<'0'||c>'9') { if (c=='-') w=-1; c=getchar(); }
while (c>='0'&&c<='9') X=X*10+c-'0',c=getchar();
return X*w;
}
const int N=300000+10;
const int inf=0x3f3f3f3f;
const ll infll=0x3f3f3f3f3f3f3f3f;
int n,k;
struct edge { int v,w,nxt; } e[N<<1];
int head[N];
void addEdge(int u,int v,int w) {
static int cnt=0;
e[++cnt]=(edge){v,w,head[u]},head[u]=cnt;
}
struct alice {
ll w; int c;
alice(ll w_=0,int c_=0): w(w_),c(c_) {}
};
bool operator <(alice a,alice b) { return a.w<b.w||(a.w==b.w&&a.c>b.c); }
bool operator >(alice a,alice b) { return b<a; }
alice operator +(alice a,alice b) { return alice(a.w+b.w,a.c+b.c); }
alice f[N][3],g[3];
void dfs(int u,int fa,const ll mid) {
f[u][0]=alice(0,0),f[u][1]=alice(-infll,inf),f[u][2]=alice(-mid,1);
for (int i=head[u];i;i=e[i].nxt) {
int v=e[i].v,w=e[i].w; if (v==fa) continue;
dfs(v,u,mid);
for (int j=0;j<3;++j) g[j]=alice(-infll,inf);
for (int j=0;j<3;++j) g[0]=max(g[0],f[u][0]+f[v][j]);
g[1]=max(g[1],f[u][0]+f[v][0]+alice(w-mid,1));
g[1]=max(g[1],f[u][0]+f[v][1]+alice(w,0));
for (int j=0;j<3;++j) g[1]=max(g[1],f[u][1]+f[v][j]);
g[2]=max(g[2],f[u][1]+f[v][0]+alice(w,0));
g[2]=max(g[2],f[u][1]+f[v][1]+alice(w+mid,-1));
for (int j=0;j<3;++j) g[2]=max(g[2],f[u][2]+f[v][j]);
for (int j=0;j<3;++j) f[u][j]=g[j];
}
}
int main() {
n=read(),k=read()+1;
for (int i=1;i<n;++i) {
int u=read(),v=read(),w=read();
addEdge(u,v,w),addEdge(v,u,w);
}
ll L=-1e12,R=1e12;
while (L<R) {
ll mid=(L+R)>>1; dfs(1,0,mid);
if (max({f[1][0],f[1][1],f[1][2]}).c<=k) R=mid;
else L=mid+1;
}
dfs(1,0,L); printf("%lld\n",max({f[1][0],f[1][1],f[1][2]}).w+k*L);
return 0;
}