【算法】虚树入门

请注意,本文编写于 99 天前,最后修改于 52 天前,其中某些信息可能已经过时。

让我们先看一道例题:
[SDOI2011]消耗战 (太经典了基本上都知道这道题了吧...

大致题意:
给定 $n$ 个点的一棵树,每条边有边权
每次询问给出一个点集 $H$ ,可以选择一些边截断,求出使 $H$ 中的点和 $1$ 号节点不连通的最小边权和

显然,树形 DP 是可行的

以 $1$ 号节点为根建树
用 $f[i]$ 表示使 $i$ 点为根的子树中所有关键点和 $1$ 号节点不连通的最小费用
设该点连向其父亲的边权为 $c[i]$

若 $i$ 为关键点, $f[i]=c[i]$
若 $i$ 不为关键点, $f[i]=\min(f[i],\sum f[j])|j\in son(i)$

答案为 $f[1]$

这样的 DP 总时间复杂度是 $O(nm)$ 的,显然在这道题 $n\leq250000,m\leq500000$ 的数据范围下无法通过

但是我们发现, $\sum|H|\leq500000$ ,可以尝试用虚树解决

介绍

简单来讲,虚树将每次询问的复杂度压缩到与关键点个数同阶
虚树中将会包含所有的关键点,任意两个关键点之间的 $LCA$

同时将原来的非关键点的信息保存在关键点之间的连边上
这样就保证在不损失信息的情况下尽可能的压缩树的大小

构造

预备知识:求 $LCA$ (倍增树剖RMQ)

常见的做法是用栈来维护当前链

仍然以 $1$ 号节点为根建树,加入一个 SuperRoot 作为 $1$ 号节点的父亲
将所有关键点的 DFS 序跑出来

在栈中,一开始只有一个 SuperRoot
然后我们按照关键点的 DFS 序加入点,设当前点为 $p$

根据 $p$ 和当前栈顶的节点 $q$ 的关系,分为两种情况

  1. $q$ 是 $p$ 的祖先:直接将 $p$ 入栈
  2. $q$ 不是 $p$ 的祖先:需要加入 $LCA(q,p)$

是否为祖先一般来说直接用 $LCA$ 判断下即可
第一种操作比较简单,考虑如何维护第二种情况

显然,如果 $q$ 不是 $p$ 的祖先,那么 $q$ 的子树中所有关键点已经被处理完了
那么,设 $LCA(q,p)=c$

那么在栈中所有深度在 $c$ 以下的关键点都没有意义了,将其相邻两点的边建出来,然后在栈中弹出它们

最后插入 $c$ 和 $p$ ,作为新链的底端

重复以上过程,直到无新加入的关键点,退栈即可

作用

显然,在建出来虚树之后,每一个询问直接在虚树上 DP 即可,每次复杂度 $O(siz)$ ,其中 $siz$ 和 $|H|$ 同阶

这样就可以完美的解决之前的题目了

代码

[scode type="lblue"]代码很久以前的了,有点丑.jpg 找时间重写下[/scode]

#include <bits/stdc++.h>
using namespace std;
const int N = 300000;
const int INF = 2100000000;
typedef long long ll;
typedef pair<int, ll> p;
int dfn[N], fa[N][23], sta[N], top, n, m, deep[N], DFN, maxdeep, maxn, h[N], key[N];
ll f[N], sum[N][23];
struct edge {
    int head[N];
    int t[N << 1];
    int n[N << 1];
    ll c[N << 1];
    int tot;
    void addedge(int u, int v, ll w) {
        ++tot;
        t[tot] = v;
        c[tot] = w;
        n[tot] = head[u];
        head[u] = tot;
    }
} e, ne;
int read();
void dfs(int);
void prepare();
bool cmp(int, int);
p lca(int, int);
void dp(int);
int main() {
    n = read();
    for (int i = 1; i < n; i++) {
        int x = read(), y = read();
        ll z = read();
        e.addedge(x, y, z);
        e.addedge(y, x, z);
    }
    deep[1] = 1;
    dfs(1);
    prepare();
    m = read();
    while (m--) {
        int k = read();
        for (int i = 1; i <= k; i++) h[i] = read(), key[h[i]] = 1;
        sort(h + 1, h + k + 1, cmp);
        ne.tot = 0;
        ne.head[1] = 0;
        sta[top = 1] = 1;
        for (int i = 1; i <= k; i++) {
            if (h[i] == 1)
                continue;
            p l = lca(h[i], sta[top]);
            if (l.first != sta[top]) {
                while (dfn[l.first] < dfn[sta[top - 1]]) {
                    p nl = lca(sta[top - 1], sta[top]);
                    ne.addedge(sta[top - 1], sta[top], nl.second);
                    --top;
                }
                if (dfn[l.first] > dfn[sta[top - 1]]) {
                    ne.head[l.first] = 0;
                    p nl = lca(l.first, sta[top]);
                    ne.addedge(l.first, sta[top], nl.second);
                    sta[top] = l.first;
                } else {
                    p nl = lca(l.first, sta[top]);
                    ne.addedge(l.first, sta[top--], nl.second);
                }
            }
            ne.head[h[i]] = 0;
            sta[++top] = h[i];
        }
        for (int i = 1; i < top; i++) {
            p l = lca(sta[i], sta[i + 1]);
            ne.addedge(sta[i], sta[i + 1], l.second);
        }
        dp(1);
        for (int i = 1; i <= k; i++) key[h[i]] = 0;
    }
    return 0;
}
int read() {
    int ans = 0, flag = 1;
    char ch = getchar();
    while ((ch > '9' || ch < '0') && ch != '-') ch = getchar();
    if (ch == '-')
        flag = -1, ch = getchar();
    while (ch >= '0' && ch <= '9') ans = ans * 10 + ch - '0', ch = getchar();
    return ans * flag;
}
void dfs(int x) {
    for (int i = e.head[x]; i; i = e.n[i]) {
        if (e.t[i] == fa[x][0])
            continue;
        dfn[e.t[i]] = ++DFN;
        fa[e.t[i]][0] = x;
        sum[e.t[i]][0] = e.c[i];
        deep[e.t[i]] = deep[x] + 1;
        maxdeep = max(maxdeep, deep[e.t[i]]);
        dfs(e.t[i]);
    }
}
void prepare() {
    for (maxn = 0; (1 << maxn) <= maxdeep; maxn++)
        ;
    for (int l = 1; l <= maxn; l++)
        for (int i = 1; i <= n; i++) {
            fa[i][l] = fa[fa[i][l - 1]][l - 1];
            sum[i][l] = min(sum[i][l - 1], sum[fa[i][l - 1]][l - 1]);
        }
}
bool cmp(int a, int b) { return dfn[a] < dfn[b]; }
p lca(int x, int y) {
    ll ans = INF;
    if (deep[x] < deep[y])
        swap(x, y);
    for (int i = maxn; i >= 0; i--)
        if (deep[fa[x][i]] >= deep[y]) {
            ans = min(ans, sum[x][i]);
            x = fa[x][i];
        }
    if (x == y)
        return (p){ x, ans };
    for (int i = maxn; i >= 0; i--)
        if (fa[x][i] != fa[y][i]) {
            ans = min(ans, sum[x][i]);
            ans = min(ans, sum[y][i]);
            x = fa[x][i];
            y = fa[y][i];
        }
    ans = min(ans, sum[x][0]);
    ans = min(ans, sum[y][0]);
    return (p){ fa[x][0], ans };
}
void dp(int x) {
    f[x] = 0;
    for (int i = ne.head[x]; i; i = ne.n[i]) {
        dp(ne.t[i]);
        if (key[ne.t[i]])
            f[x] += ne.c[i];
        else
            f[x] += min(ne.c[i], f[ne.t[i]]);
    }
    if (x == 1)
        printf("%lld\n", f[1]);
}
Comments

添加新评论