Luogu

LOJ

分析

根据拟阵的相关理论可以知道:$A$ 是权值最小的基当且仅当 $\forall u\in A,v\notin A$,只要 $(A\backslash\{u\})\cup\{v\}$ 是基,则 $v_u\leq v_v$;权值最大的基的情况同理。这样子就把原题中最小值、最大值的限制转化成了若干组偏序关系。

现在的问题即为 $p=2$ 时的保序回归问题。根据论文中的理论,考虑整体二分,因为要求解是整数所以可以只解决 $S=\{mid,mid+1\}$。

这个问题相当于:每个数可以等于 $mid$ 或 $mid+1$,另外选的数需要满足一些偏序关系,最小化回归代价。这可以建模成一个最小割模型来解决。

代码

// ====================================
//   author: M_sea
//   website: https://m-sea-blog.com/
// ====================================
#include <bits/stdc++.h>
#define file(x) freopen(#x".in","r",stdin),freopen(#x".out","w",stdout)
#define debug(...) fprintf(stderr,__VA_ARGS__)
using namespace std;
typedef long long ll;

ll sqr(int x) { return 1ll*x*x; }
ll read() {
    ll X=0,w=1; char c=getchar();
    while (c<'0'||c>'9') { if (c=='-') w=-1; c=getchar(); }
    while (c>='0'&&c<='9') X=X*10+c-'0',c=getchar();
    return X*w;
}

const int N=1000+10,M=64+10;
const ll inf=0x3f3f3f3f3f3f3f3f;

int n,m,v[N],a[M],b[M]; ll c[N];

namespace LB {
    ll b[64];
    void init() { memset(b,0,sizeof(b)); }
    void insert(ll x) {
        for (int i=63;~i;--i) {
            if (!(x&(1ll<<i))) continue;
            if (!b[i]) { b[i]=x; return; }
            x^=b[i];
        }
    }
    bool query(ll x) {
        for (int i=63;~i;--i) {
            if (!(x&(1ll<<i))) continue;
            if (!b[i]) return 1;
            x^=b[i];
        }
        return 0;
    }
}

namespace H {
    struct edge { int v; ll w; int nxt; } e[N*128];
    int head[N],ecnt;
    void addEdge(int u,int v,ll w) {
        e[++ecnt]=(edge){v,w,head[u]},head[u]=ecnt;
        e[++ecnt]=(edge){u,0,head[v]},head[v]=ecnt;
    }

    int S,T,lv[N];
    bool bfs() {
        queue<int> Q; Q.push(S);
        memset(lv,0,(T+1)<<2),lv[S]=1;
        while (!Q.empty()) {
            int u=Q.front(); Q.pop();
            for (int i=head[u];i;i=e[i].nxt) {
                int v=e[i].v; ll w=e[i].w;
                if (w&&!lv[v]) lv[v]=lv[u]+1,Q.push(v);
            }
        }
        return lv[T]!=0;
    }
    ll dfs(int u,ll r) {
        if (u==T||!r) return r;
        ll add=0;
        for (int i=head[u];i;i=e[i].nxt) {
            int v=e[i].v; ll w=e[i].w;
            if (w&&lv[v]==lv[u]+1) {
                ll t=dfs(v,min(r,w));
                add+=t,r-=t,e[i].w-=t,e[i^1].w+=t;
                if (!r) break;
            }
        }
        if (!add) lv[u]=0;
        return add;
    }
    ll dinic() {
        ll res=0;
        while (bfs()) res+=dfs(S,inf);
        return res;
    }

    void init(int len) {
        S=0,T=len+1,memset(head,0,(T+1)<<2),ecnt=1;
    }
}

vector<int> E[N];
int p[N],lp[N],rp[N],id[N],ans[N];
void solve(int l,int r,int L,int R) {
    if (l>r) return;
    if (L==R) {
        for (int i=l;i<=r;++i) ans[p[i]]=L;
        return;
    }
    int mid=(L+R)>>1; H::init(r-l+1);
    for (int i=l;i<=r;++i) id[p[i]]=i-l+1;
    for (int i=l;i<=r;++i) {
        ll d=sqr(v[p[i]]-mid)-sqr(v[p[i]]-mid-1);
        if (d>0) H::addEdge(H::S,i-l+1,d);
        else H::addEdge(i-l+1,H::T,-d);
        for (int v:E[p[i]]) if (id[v]) H::addEdge(i-l+1,id[v],inf);
    }
    H::dinic(); int lc=0,rc=0;
    for (int i=l;i<=r;++i) {
        if (!H::lv[i-l+1]) lp[++lc]=p[i];
        else rp[++rc]=p[i];
    }
    for (int i=1;i<=lc;++i) p[l+i-1]=lp[i];
    for (int i=1;i<=rc;++i) p[l+lc+i-1]=rp[i];
    for (int i=l;i<=r;++i) id[p[i]]=0;
    solve(l,l+lc-1,L,mid),solve(l+lc,r,mid+1,R);
}

int main() {
    n=read(),m=read();
    for (int i=1;i<=n;++i) c[i]=read();
    for (int i=1;i<=n;++i) v[i]=read();
    for (int i=1;i<=m;++i) a[i]=read();
    for (int i=1;i<=m;++i) b[i]=read();
    for (int i=1;i<=m;++i) {
        LB::init();
        for (int j=1;j<=m;++j)
            if (j!=i) LB::insert(c[a[j]]);
        for (int j=1;j<=n;++j)
            if (j!=a[i]&&LB::query(c[j])) E[a[i]].emplace_back(j);
    }
    for (int i=1;i<=m;++i) {
        LB::init();
        for (int j=1;j<=m;++j)
            if (j!=i) LB::insert(c[b[j]]);
        for (int j=1;j<=n;++j)
            if (j!=b[i]&&LB::query(c[j])) E[j].emplace_back(b[i]);
    }
    for (int i=1;i<=n;++i) p[i]=i;
    solve(1,n,0,1e6); ll s=0;
    for (int i=1;i<=n;++i) s+=sqr(v[i]-ans[i]);
    printf("%lld\n",s);
    return 0;
}
最后修改:2021 年 01 月 12 日 10 : 22 PM