【题解】BZOJ 4911 [SDOI2007]切树游戏

Problem

给定一棵 $n$ 个节点的树,每个节点有权值 $a[i]$

现在进行 $m$ 次操作,一共有两种情况:

  1. Change x y 将编号为 $x$ 的节点权值修改为 $y$
  2. Query k 询问异或和为 $k$ 的非空子图的数量

Thought

Subtask1

$n\leq 2000,m=64,q\leq64$

HDU 5909 Tree Cutting

时间复杂度 $O(m\log mn)$
可以拿到 $20pt$

Subtask2,3

$n\leq 30000,m=128,q\leq30000$ 树的形态为一条链

显然链的非空子图必然是一条链
那么如果求出所有前缀异或值,丢到桶里,只需要计算能够异或出 $k$ 的二元组个数即可

如何进行修改操作

考虑对序列进行分块,每个块 $O(m)$ 修改
时间复杂度 $O(q\sqrt nm+qm)$

把块分大点呗,对于散块来说,时间是 $O(1)$ 的,对于整块来说,时间是 $O(m)$ 的,那么将块的大小分为 $\sqrt{nm}$ ,那么时间复杂度降低为 $O(q\sqrt{nm}+qm)$ 应该能跑过了吧...

正解

$n\leq 30000,m=128,q\leq30000$

因为有修改,那么每次 $O(n)$ 的暴力修改复杂度是不可接受的
考虑如何才能在有修改的条件下快速统计答案

一个思路是利用链分治
每次修改后只需要重新考虑修改的 $\log n$ 条链的答案
这样的方法就比较优秀

设 $f(i,k)$ 表示在 $i$ 的子树中有多少个过 $i$ 的异或和为 $k$ 的联通块
设 $g(i,k)$ 表示在 $i$ 的子树中有多少个异或和为 $k$ 的联通块

转移显然:

$$ f(i,k)=\frac{1}{2}\sum_{b=0}^{m}\sum_{u\in son(i), v\in son(i),u<v}f(u,b)f(v,k\oplus b\oplus a[i]) $$

上面的式子显然是可以用 FWT 进行优化,不另行列出

如果分别用 $F_i(z)$ 表示 $f(i,*)$的生成函数,即有:

$$ F_i(z)=z^{a[i]}\times\prod_{u\in son(i)}(F_u(z)+z^0) $$

结合链分治,可以令 $LF_i(z)=\prod_{u\in lightson(i)}(F_u(z)+z^0)$ ,即有转移:

$$ F_i(z)=z^{a[i]}\times (F_{son[i]}(z)+z^0)\times LF_i(z)\\ $$

令 $G_i(z)$ 为重链上 $F_i(z)$ 的和,即:

$$ G_i(z)=G_{son[i]}(z)+F_i(z) $$

那么最终答案为所有重链顶的答案之和

对于一条重链上相邻的两个点,其 $F(z),G(z)$ 实际上是可以通过矩阵进行转移的:

$$ \begin{pmatrix} z^{a[fa]}\times LF_{fa}(z) & 0 & z^{a[fa]}\times LF_{fa}(z)\\ z^{a[fa]}\times LF_{fa}(z) & z^0 & z^{a[fa]}\times LF_{fa}(z)\\ 0 & 0 & z^0\\ \end{pmatrix} \begin{pmatrix} F_{son}\\ G_{son}\\ z^0\\ \end{pmatrix} = \begin{pmatrix} F_{fa}\\ G_{fa}\\ z^0\\ \end{pmatrix} $$

所以令:

$$ H_{x}(z)=z^{a[x]}\times LF_{x}(z)\\ M_i= \begin{pmatrix} H_i(z) & 0 & H_i(z)\\ H_i(z) & z^0 & H_i(z)\\ 0 & 0 & z^0\\ \end{pmatrix} $$

通过线段树维护矩阵就能很好的处理修改和求值操作

进一步分析发现,两个矩阵相乘后的结果也是有规律的:

令:

$$ M_x= \begin{pmatrix} A_x & 0 & B_x\\ C_x & z^0 & D_x\\ 0 & 0 & z^0\\ \end{pmatrix} $$

$$ M_{fa}M_{son}= \begin{pmatrix} A_{fa} & 0 & B_{fa}\\ C_{fa} & z^0 & D_{fa}\\ 0 & 0 & z^0\\ \end{pmatrix} \begin{pmatrix} A_{son} & 0 & B_{son}\\ C_{son} & z^0 & D_{son}\\ 0 & 0 & z^0\\ \end{pmatrix} =\\ \begin{pmatrix} A_{fa}A_{son} & 0 & A_{fa}B_{son}+B_{fa}\\ C_{fa}A_{son} + C_{son} & z^0 & C_{fa}B_{son} + D_{son} + D_{fa}\\ 0 & 0 & z^0 \end{pmatrix} $$

发现矩阵非常数部分为封闭的,所以只需要维护四个位置的转移,同样使用上述两矩阵相乘的方法进行计算,可以达到单次合并 $O(m)$ 的复杂度

如果要查询一条链的答案,即计算:

$$ M_{1}M_{2}M_{3}M_{\cdots} \begin{pmatrix} 0\\ 0\\ z^{0}\\ \end{pmatrix} $$

更新就跳上去更新数组 $G,F,H$ 即可

Code

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
int read() {
    int ans = 0, flag = 1;
    char ch = getchar();
    while(ch < '0' || ch > '9') {if(ch == '-') flag = - flag; ch = getchar();}
    while(ch >= '0' && ch <= '9') {ans = ans * 10 + ch - '0'; ch = getchar();}
    return ans * flag;
}
const int N = 30010;
const int mod = 10007;

int n, m, q;
int v[N];
bool One[128];
int Ans[128];

int siz[N], top[N], low[N], fa[N], son[N];
int dfn[N], DFN;

int F[128];
int G[128];

int HZ[N][128];
int H[N][128];
int f[N][128];
int g[N][128];

int inv[mod];

struct Graph {
    int t[N << 1], n[N << 1], head[N], tot;
    void add(int u, int v) {
        ++ tot;
        t[tot] = v;
        n[tot] = head[u];
        head[u] = tot;
    }
}Gph;

int lowbit(int x) {
    return x & (- x);
}
void bef() {
    for(int i = 1; i < 128; ++ i)
        One[i] = One[i - lowbit(i)] ^ 1;
    inv[1] = 1;
    for(int i = 2; i < mod; ++ i)
        inv[i] = (mod - mod / i) * inv[mod % i] % mod;
}
int M(int x) {
    while(x >= mod) x -= mod;
    while(x < 0) x += mod;
    return x;
}
void init() {
    n = read(); m = read();
    for(int i = 1; i <= n; ++ i)
        v[i] = read();
    for(int i = 1; i < n; ++ i) {
        int x = read(), y = read();
        Gph.add(x, y);
        Gph.add(y, x);
    }
}

int A[128];
int B[128];
int C[128];
int D[128];
int a[128];
int b[128];
int c[128];
int d[128];

#define L (x << 1)
#define R (x << 1 | 1)
#define mid ((l + r) >> 1)
struct Seg {
    int a[N << 2][128];
    int b[N << 2][128];
    int c[N << 2][128];
    int d[N << 2][128];
    void upd(int x) {
        for(int i = 0; i < m; ++ i) {
            a[x][i] = a[L][i] * a[R][i] % mod;
            b[x][i] = (a[L][i] * b[R][i] + b[L][i]) % mod;
            c[x][i] = (c[L][i] * a[R][i] + c[R][i]) % mod;
            d[x][i] = (c[L][i] * b[R][i] + d[R][i] + d[L][i]) % mod;
        }
    }
    void ins(int x, int l, int r, int pos) {
        if(l == r) {
            for(int i = 0; i < m; ++ i)
                if(HZ[pos][i])
                    a[x][i] = b[x][i] = c[x][i] = d[x][i] = 0;
                else
                    a[x][i] = b[x][i] = c[x][i] = d[x][i] = H[pos][i];
        }
        else {
            if(dfn[pos] <= mid)
                ins(L, l, mid, pos);
            else
                ins(R, mid + 1, r, pos);
            upd(x);
        }
    }
    void query(int x, int l, int r, int ll, int rr) {
        if(l >= ll && r <= rr) {
            for(int i = 0; i < m; ++ i) {
                A[i] = :: a[i] * a[x][i] % mod;
                B[i] = (:: a[i] * b[x][i] + :: b[i]) % mod;
                C[i] = (:: c[i] * a[x][i] + c[x][i]) % mod;
                D[i] = (:: c[i] * b[x][i] + d[x][i] + :: d[i]) % mod;
            }
            memcpy(:: a, A, sizeof(:: a));
            memcpy(:: b, B, sizeof(:: b));
            memcpy(:: c, C, sizeof(:: c));
            memcpy(:: d, D, sizeof(:: d));
        }
        else {
            if(mid >= ll)
                query(L, l, mid, ll, rr);
            if(mid < rr)
                query(R, mid + 1, r, ll, rr);
        }
    }
}tr;

void FWT_XOR(int *a, int n, int f) {
    for(int i = 1; i < n; i <<= 1)
        for(int j = 0; j < n; j += (i << 1))
            for(int k = 0; k < i; ++ k) {
                int x = a[j + k], y = a[i + j + k];
                a[j + k] = M(x + y);
                a[i + j + k] = M(x - y);
                if(f == - 1) {
                    a[j + k] = 1ll * a[j + k] * inv[2] % mod;
                    a[i + j + k] = 1ll  * a[i + j + k] * inv[2] % mod;
                }
            }
}
void FWT_XOR_HAND(int *a, int n, int pos) {
    for(int i = 0; i < n; ++ i)
        if(One[i & pos])
            a[i] = mod - 1;
        else
            a[i] = 1;
}

void getMatrix(int x) {
    FWT_XOR_HAND(a, m, 0);
    memset(b, 0, sizeof(b));
    memset(c, 0, sizeof(c));
    memset(d, 0, sizeof(d));
    tr.query(1, 1, n, dfn[top[x]], dfn[low[x]]);
}
void getF() {
    for(int i = 0; i < m; ++ i)
        F[i] = b[i];
}
void getG() {
    for(int i = 0; i < m; ++ i)
        G[i] = d[i];
}

//链剖

void dfs1(int x) {
    siz[x] = 1;
    for(int i = Gph.head[x]; i; i = Gph.n[i]) {
        int t = Gph.t[i];
        if(t == fa[x]) continue;
        fa[t] = x;
        dfs1(t);
        siz[x] += siz[t];
        if(siz[t] > siz[son[x]])
            son[x] = t;
    }
}
void dfs2(int x, int _top) {
    dfn[x] = ++ DFN;
    top[x] = _top;
    FWT_XOR_HAND(H[x], m, v[x]);
    if(son[x]) {
        dfs2(son[x], _top);
        for(int i = Gph.head[x]; i; i = Gph.n[i]) {
            int t = Gph.t[i];
            if(t == fa[x] || t == son[x]) continue;
            dfs2(t, t);
            for(int j = 0; j < m; ++ j)
                if(f[t][j] + 1 == mod)
                    ++ HZ[x][j];
                else
                    H[x][j] = (H[x][j] * (f[t][j] + 1)) % mod;
        }
    }
    else
        low[top[x]] = x;
    tr.ins(1, 1, n, x);
    if(x == top[x]) {
        getMatrix(x); getF(); getG();
        for(int i = 0; i < m; ++ i) {
            g[x][i] = G[i];
            f[x][i] = F[i];
            Ans[i] = M(Ans[i] + G[i]);
            //可以加速...
        }
    }
}

int sta[N];
int TOP;

void change(int x, int key) {
    FWT_XOR_HAND(a, m, v[x]);
    for(int i = 0; i < m; ++ i)
        H[x][i] = H[x][i] * inv[a[i]] % mod;
    v[x] = key;
    FWT_XOR_HAND(a, m, v[x]);
    for(int i = 0; i < m; ++ i)
        H[x][i] = H[x][i] * a[i] % mod;
    tr.ins(1, 1, n, x);
    x = top[x];

    while(x) {
        for(int i = 0; i < m; ++ i)
            Ans[i] = M(Ans[i] - g[x][i]);
        if(fa[x]) {
            int old = fa[x];
            for(int i = 0; i < m; ++ i)
                if(f[x][i] + 1 == mod)
                    -- HZ[old][i];
                else
                    H[old][i] = (H[old][i] * inv[f[x][i] + 1]) % mod;
        }

        getMatrix(x); getF(); getG();
        for(int i = 0; i < m; ++ i) {
            g[x][i] = G[i];
            f[x][i] = F[i];
            Ans[i] = M(Ans[i] + G[i]);
        }
        if(fa[x]) {
            int old = fa[x];
            for(int i = 0; i < m; ++ i)
                if(f[x][i] + 1 == mod)
                    ++ HZ[old][i];
                else
                    H[old][i] = (H[old][i] * (f[x][i] + 1)) % mod;
            tr.ins(1, 1, n, old);
        }
        x = top[fa[x]];
    }
}

char s[10]cpp;

void solve() {
    q = read();
    while(q --) {
        scanf("%s", s);
        if(s[0] == 'Q') {
            FWT_XOR(Ans, m, - 1);
            printf("%d\n", Ans[read()]);
            FWT_XOR(Ans, m, 1);
        }
        else {
            int x = read(), key = read();
            change(x, key);
        }
    }
}

int main() {
    bef(); init();
    dfs1(1); dfs2(1, 1);
    solve();
    return 0;
}
Comments

添加新评论

已有 2 条评论

这外面为什么写的3-29...

Chlience Chlience 回复 @兜兜里有糖

好像是因为VOID自动将时间往后推了一天??