【题解】LOJ 2542 「PKUWC2018」随机游走

请注意,本文编写于 96 天前,最后修改于 96 天前,其中某些信息可能已经过时。

Problem

给定一棵 $n$ 个节点的 ,你从点 $x$ 出发,每次等概率随机选择一条边走
求将点集 $S$ 中所有点至少走过一次的期望游走距离

多次询问

Thought

结束状态显然是将 $S$ 中所有的点走过至少一遍,换句话说,就是求 $S$ 中最后一个到达的节点的期望,即 $MAX\{f(x),x\in S\}$

最大值不好处理,不如使用 $MIN-MAX$ 容斥将其转化为求最小值,即

$$ MAX(S)=\sum_{S'\subseteq S}(-1)^{|S'|+1}MIN(S') $$

计算 $MIN(S)$ 可以在树上高斯消元,时间复杂度 $O(n2^n\log mod)$

然后高维前缀和一下 ,时间复杂度 $O(n2^n)$

那么对于每个询问可以直接 $O(1)$ 回答,总复杂度 $O(n2^n+q)$

#include <bits/stdc++.h>
using namespace std;
int read() {
    int ans = 0, flag = 1;
    char ch = getchar();
    while(ch > '9' || ch < '0') {if(ch == '-') flag = - flag; ch = getchar();}
    while(ch >= '0' && ch <= '9') {ans = ans * 10 + ch - '0'; ch = getchar();}
    return ans * flag;
}
const int N = 19;
const int mod = 998244353;

int n, m, rt;
int S;

int INV[N];

int var[N];
int con[N];
bool vis[N];

int deg[N];

int g[1 << 18];
int num[1 << 18];

queue <int> que;

int fa[N];

struct Edge {
    int t, n;
}e[N << 1];
int head[N], Etot;

void prepare() {
    INV[1] = 1;
    for(int i = 2; i < N; ++ i)
        INV[i] = 1ll * (mod - mod / i) * INV[mod % i] % mod;
}

void addedge(int u, int v) {
    e[++ Etot] = {v, head[u]};
    head[u] = Etot;
}

void dfs(int x) {
    for(int i = head[x]; i; i = e[i].n) {
        if(e[i].t == fa[x]) continue;
        fa[e[i].t] = x;
        dfs(e[i].t);
    }
}

void clear() {
    for(int i = 1; i <= n; ++ i) {
        vis[i] = 0;
        var[i] = 0;
        con[i] = 0;
    }
}

int M(int x) {
    while(x < 0) x += mod;
    while(x >= mod) x -= mod;
    return x;
}

int qpow(int a, int b) {
    int ans = 1;
    while(b) {
        if(b & 1) ans = 1ll * ans * a % mod;
        a = 1ll * a * a % mod;
        b >>= 1;
    }
    return ans;
}

int inv(int x) {
    if(x < N) return INV[x];
    return qpow(x, mod - 2);
}

void dfs1(int x) {
    if(vis[x]) return;
    for(int i = head[x]; i; i = e[i].n) {
        if(e[i].t == fa[x]) continue;
        dfs1(e[i].t);
    }
    for(int i = head[x]; i; i = e[i].n) {
        if(e[i].t == fa[x]) continue;
        con[x] += 1ll * con[e[i].t] * inv(deg[x]) % mod;
        if(con[x] >= mod) con[x] -= mod;
        var[x] += 1ll * var[e[i].t] * inv(deg[x]) % mod;
        if(var[x] >= mod) var[x] -= mod;
    }
    con[x] += 1;
    if(con[x] >= mod)
        con[x] -= mod;
    
    con[x] = 1ll * con[x] * inv(M(1 - var[x])) % mod;
    if(fa[x]) {
        int f = inv(deg[x]);
        f = 1ll * f * inv(M(1 - var[x])) % mod;
        var[x] = f;
    }
    else var[x] = 0;
}

void cal(int s) {
    clear();
    for(int i = 1; i <= n; ++ i)
        if(s & (1 << (i - 1))) {
            vis[i] = 1;
            ++ num[s];
        }
    dfs1(rt);
    g[s] = 1ll * qpow(mod - 1, num[s] + 1) * con[rt] % mod;
}

void FWT_OR(int *a, int n, int f) {
    for(int i = 1; i < n; i <<= 1)
        for(int j = 0; j < n; j += (i << 1))
            for(int k = 0; k < i; ++ k) {
                int x = a[j + k], y = a[i + j + k];
                a[i + j + k] = M(f * x + y);
            }
}

int main() {
    prepare();
    n = read(); m = read(); rt = read();
    for(int i = 1; i < n; ++ i) {
        int u = read(), v = read();
        addedge(u, v);
        addedge(v, u);
        ++ deg[u];
        ++ deg[v];
    }

    dfs(rt);
    S = 1 << n;
    for(int i = 1; i < S; ++ i)
        cal(i);
    FWT_OR(g, S, 1);
    while(m --) {
        int k = read(), x = 0;
        while(k --)
            x |= 1 << (read() - 1);
        printf("%d\n", g[x]);
    }
    return 0;
}
Comments

添加新评论