【题解】Atcoder AGC 005F Many Easy Problems

Problem

给定一棵 $n$ 个节点的树
设某个顶点集合为 $S$ ,定义函数 $f(S)$ 为树上包含这些顶点的最小联通块大小

求 $|S|$ 为 $1,2,\cdots,n$ 时 $f(S)$ 的和,即:

$$ Ans_i=\sum_{|S|=i}f(S) $$

Thought

假设当前顶点集合大小为 $|S|$ ,那么一个节点不被包含在该联通块中当且仅当将该点看做根节点时,所有集合中的顶点在该节点的某个子树中;反之必然被包含

也就是说我们需要求出:

$$ {n\choose|S|}-\sum_{y\in son_x}{siz[y]\choose|S|} $$

显然,不管是哪个节点的子树,最终对答案的贡献都之和其大小有关

那么设 $cnt[i]$ 为大小 $=i$ 的子树个数,则最终答案为:

$$ n{n\choose{|S|}}-\sum_{i=1}cnt[i]*{i\choose|S|} $$

考虑如何维护后面部分的贡献,即:

$$ \sum_{i=1}^{n-1}cnt[i]*{i\choose|S|}\\ \sum_{i=1}^{n-1}cnt[i]*\frac{i!}{|S|!(i-|S|)!}\\ \frac{1}{|S|!}\sum_{i=1}^{n-1}cnt[i]i!*\frac{1}{(i-|S|)!} $$

设 $A[i]=cnt[i]i!,B[i]=\frac{1}{i!}$

$$ \frac{1}{|S|!}\sum_{i=1}^{n-1}A[i]*B[i-|S|] $$

翻转 $B$ 得到:

$$ \frac{1}{|S|!}\sum_{i=1}^{n-1}A[i]*B[|S|-i+n] $$

这样使得下标之和为 $|S|+n$ ,同时我们令 $B[n+1]-B[2n]$ 均为 $0$,防止出现 $|S|>i$ 的情况

最后用 $NTT$ 计算出该式,其中模数为 $924844033$ ,原根为 $5$

Code

#include <bits/stdc++.h>
using namespace std;
int read() {
    int ans = 0, flag = 1;
    char ch = getchar();
    while(ch > '9' || ch < '0') {if(ch == '-') flag = - flag; ch = getchar();}
    while(ch >= '0' && ch <= '9') {ans = ans * 10 + ch - '0'; ch = getchar();}
    return ans * flag;
}
const int N = 200010;
const int mod = 924844033;
const int G = 5;
struct Edge {
    int t, n;
}e[N << 1];
int head[N], Etot;

int siz[N];
int cnt[N << 2], invProRev[N << 2], rev[N << 2];
int inv[N], invPro[N], pro[N];
int n;

void addedge(int u, int v) {
    e[++ Etot] = {v, head[u]};
    head[u] = Etot;
}
void dfs(int x, int f) {
    siz[x] = 1;
    for(int i = head[x]; i; i = e[i].n) {
        if(e[i].t == f) continue;
        dfs(e[i].t, x);
        ++ cnt[siz[e[i].t]];
        siz[x] += siz[e[i].t];
    }
    if(n - siz[x]) ++ cnt[n - siz[x]];
}
int qpow(int a, int b) {
    int ans = 1;
    while(b) {
        if(b & 1) ans = 1ll * ans * a % mod;
        a = 1ll * a * a % mod;
        b >>= 1;
    }
    return ans;
}

void dft(int *a, int n, int f) {
    for(int i = 0; i < n; ++ i)
        if(i < rev[i])
            swap(a[i], a[rev[i]]);
    for(int i = 1; i < n; i <<= 1) {
        int gn = qpow(G, (mod - 1) / (i << 1));
        if(f == - 1) gn = qpow(gn, mod - 2);
        for(int j = 0; j < n; j += (i << 1)) {
            int g = 1, x, y;
            for(int k = 0; k < i; ++ k, g = 1ll * g * gn % mod) {
                x = a[j + k]; y = 1ll * a[i + j + k] * g % mod;
                a[j + k] = (x + y) % mod;
                a[i + j + k] = (x - y + mod) % mod;
            }
        }
    }
    if(f == - 1) {
        int ny = qpow(n, mod - 2);
        for(int i = 0; i < n; ++ i)
            a[i] = 1ll * a[i] * ny % mod;
    }
}
int main() {
    n = read();
    for(int i = 1; i < n; ++ i) {
        int u = read(), v = read();
        addedge(u, v);
        addedge(v, u);
    }
    dfs(1, 0);
    pro[0] = invPro[0] = 1;
    pro[1] = inv[1] = invPro[1] = 1;
    for(int i = 2; i <= n; ++ i) {
        pro[i] = 1ll * pro[i - 1] * i % mod;
        inv[i] = 1ll * (mod - mod / i) * inv[mod % i] % mod;
        invPro[i] = 1ll * invPro[i - 1] * inv[i] % mod;
    }

    for(int i = 1; i <= n; ++ i) {
        cnt[i] = 1ll * cnt[i] * pro[i] % mod;
        invProRev[i] = invPro[n - i];
    }
    int nn, l = 0;
    for(nn = 1; nn < 2 * n; nn <<= 1) ++ l;
    for(int i = 1; i < nn; ++ i)
        rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (l - 1));
    dft(cnt, nn, 1); dft(invProRev, nn, 1);
    for(int i = 0; i < nn; ++ i)
        cnt[i] = 1ll * cnt[i] * invProRev[i] % mod;
    dft(cnt, nn, - 1); 
    for(int i = 1; i <= n; ++ i) {
        int sum = 1ll * n * pro[n] % mod * invPro[i] % mod * invPro[n - i] % mod;
        sum = (sum - 1ll * invPro[i] * cnt[n + i] % mod + mod) % mod;
        printf("%d\n", sum);
    }
    return 0;
}
Comments

添加新评论