【算法】最小乘积生成树

在 $OI$ 中,一般我们见到的都是二维最小乘积生成树

具体来说,每条边有两种权值 $a_i,b_i$
需要求得一个生成树的方案 $p$ ,使得 $(\sum_{i=1}^{n-1}a[p_i])(\sum_{i=1}^{n-1}b[p_i])$ 尽可能的小

如果我们将任意一个生成树的方案丢到以 $\sum_{i=1}^{n-1}a[p_i]$ 为 $x$ 轴, $\sum_{i=1}^{n-1}b[p_i]$ 为 $y$ 轴的平面直角坐标系中

显然, 令 $val=xy$ 最小的点必然在下凸包上

如何寻找在下凸包上面的点

第一步:找到最靠近 $y/x$ 轴的两个点 $a,b$
第二步:找到距离直线(线段) $ab$ 最远的点 $c$
第三步:递归处理 $ac,bc$

如何找到距离 $ab$ 最远的 $c$

平面上距离直线最远的点可以通过在直线上任取两点,计算其叉积最大/小值来得到

对于任意在 $ab$ 左下方的点 $c$ 和点 $a,b$ 形成的三角形的面积用叉积表示为:

$$ S=-\frac{1}{2}(x_b-x_a,y_b-y_a)\times(x_c-x_a,y_c-y_a) $$

最大化 $S$ 即最小化

$$ (x_b-x_a,y_b-y_a)\times(x_c-x_a,y_c-y_a)\\ =\\ (x_b-x_a)(y_c-y_a)-(y_b-y_a)(x_c-x_a)\\ =\\ (x_by_c-x_by_a-x_ay_c+x_ay_a)-(x_cy_b-x_ay_b-x_cy_a+x_ay_a)\\ =\\ (y_a-y_b)x_c+(x_b-x_a)y_c+x_ay_b-x_by_a $$

为了使这个值尽可能的小,那么我们令横坐标 $+1$ 的贡献为 $y_a-y_b$ ,令纵坐标 $+1$ 的贡献为 $x_b-x_a$ ,在这个条件下求最小生成树即可

结束条件为 $S\leq0$ 即 $(y_a-y_b)x_c+(x_b-x_a)y_c+x_ay_b-x_by_a\ge0$

Code

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
int read() {
    int ans = 0, flag = 1;
    char ch = getchar();
    while(ch > '9' || ch < '0') {if(ch == '-') flag = - flag; ch = getchar();}
    while(ch >= '0' && ch <= '9') {ans = ans * 10 + ch - '0'; ch = getchar();}
    return ans * flag;
}

const int N = 210;
const int M = 10010;

int n, m, Ans = INT_MAX;

struct Vec {
    int x, y;
}ANS;

struct Edge {
    int u, v, x, y;
    ll key;
    bool operator < (const Edge a) const {
        return key < a.key;
    }
}e[M];

struct Dsu {
    int fa[N], siz[N];
    int find(int x) {
        return fa[x] == x ? fa[x] : fa[x] = find(fa[x]);
    }
    bool merge(int a, int b) {
        a = find(a); b = find(b);
        if(a == b) return 0;
        if(siz[a] > siz[b])
            swap(a, b);
        fa[a] = b;
        siz[b] += siz[a];
        return 1;
    }
    void clear() {
        for(int i = 1; i <= n; ++ i) {
            fa[i] = i;
            siz[i] = 1;
        }
    }
}d;

Vec MST() {
    sort(e + 1, e + m + 1);
    d.clear();
    Vec ans = Vec{0, 0};
    for(int i = 1; i <= m; ++ i) {
        if(d.merge(e[i].u, e[i].v)) {
            ans.x += e[i].x;
            ans.y += e[i].y;
        }
    }
    return ans;
}

void init() {
    n = read(); m = read();
    for(int i = 1; i <= m; ++ i) {
        e[i].u = read() + 1;
        e[i].v = read() + 1;
        e[i].x = read();
        e[i].y = read();
    }
}

void solve(Vec a, Vec b) {
    for(int i = 1; i <= m; ++ i)
        e[i].key = 1ll * e[i].x * (a.y - b.y) + 1ll * e[i].y * (b.x - a.x);
    Vec p = MST();
    if(p.x * p.y < Ans) {
        ANS = p;
        Ans = p.x * p.y;
    }
    if(p.x * p.y == Ans)
        if(p.x < ANS.x)
            ANS = p;
    ll val = 1ll * (a.y - b.y) * p.x + 1ll * (b.x - a.x) * p.y + 1ll * a.x * b.y - 1ll * b.x * a.y;
    if(val >= 0) return;
    solve(a, p);
    solve(p, b);
}

int main() {
    init();
    for(int i = 1; i <= m; ++ i)
        e[i].key = e[i].x;
    Vec a = MST();
    for(int i = 1; i <= m; ++ i)
        e[i].key = e[i].y;
    Vec b = MST();
    if(a.x * a.y < Ans) {
        ANS = a;
        Ans = a.x * a.y;
    }
    if(a.x * a.y == Ans)
        if(a.x < ANS.x)
            ANS = a;

    if(b.x * b.y < Ans) {
        ANS = b;
        Ans = b.x * b.y;
    }
    if(b.x * b.y == Ans)
        if(b.x < ANS.x)
            ANS = b;
    solve(a, b);
    printf("%d %d\n", ANS.x , ANS.y);
    return 0;
}
Comments

添加新评论