【题解】BZOJ 3926 [ZJOI2015]诸神眷顾的幻想乡

BZOJ 3926 [ZJOI2015]诸神眷顾的幻想乡

Problem

给定一棵树,每个点有颜色 $c_i$

求树上本质不同的颜色路径

Thought

由于颜色不多,叶子节点不多

直接将叶子节点提为根节点,建立广义 $SAM$ ,在 $SAM$ 中即可获得所有路径

然后直接算 $SAM$ 中有多少本质不同的子串就好了

Code

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
int read() {
    int ans = 0, flag = 1;
    char ch = getchar();
    while(ch > '9' || ch < '0') {if(ch == '-') flag = - flag; ch = getchar();}
    while(ch >= '0' && ch <= '9') {ans = ans * 10 + ch - '0'; ch = getchar();}
    return ans * flag;
}

const int N = 100010;
struct Graph {
    int t[N << 1], n[N << 1], head[N], tot;
    void add(int u, int v) {
        ++ tot;
        t[tot] = v;
        n[tot] = head[u];
        head[u] = tot;
    }
}G;
int n, m;
int deg[N];
int col[N];

struct SAM {
    int ch[40 * N][10], cnt = 1;
    int fa[40 * N], l[40 * N];
    ll ans[40 * N], d[40 * N];
    int ins(int c, int last) {
        int x = last, nx = ++ cnt; l[nx] = l[x] + 1;
        for(; x && !ch[x][c]; x = fa[x]) ch[x][c] = nx;
        if(!x) fa[nx] = 1;
        else {
            int y = ch[x][c];
            if(l[x] + 1 == l[y]) fa[nx] = y;
            else {
                int ny = ++ cnt; fa[ny] = fa[y]; l[ny] = l[x] + 1;
                memcpy(ch[ny], ch[y], sizeof(ch[ny]));
                for(; ch[x][c] == y; x = fa[x]) ch[x][c] = ny;
                fa[y] = fa[nx] = ny;
            }
        }
        return nx;
    }
    void dfs(int x) {
        d[x] = 1;
        for(int i = 0; i < m; ++ i)
            if(ch[x][i]) {
                dfs(ch[x][i]);
                d[x] += d[ch[x][i]];
                ans[x] += ans[ch[x][i]] + d[ch[x][i]];
            }
    }
    void out() {
        ll Ans = 0;
        for(int i = 1; i <= cnt; ++ i)
            Ans += (l[i] - l[fa[i]]);
        printf("%lld\n", Ans);
    }
}sam;

void dfs(int x, int last, int _f) {
    last = sam.ins(col[x], last);
    for(int i = G.head[x]; i; i = G.n[i])
        if(G.t[i] != _f)
            dfs(G.t[i], last, x);
}
int main() {
    n = read(); m = read();
    for(int i = 1; i <= n; ++ i)
        col[i] = read();
    for(int i = 1; i < n; ++ i) {
        int u = read(), v = read();
        G.add(u, v); G.add(v, u);
        ++ deg[u]; ++ deg[v];
    }
    for(int i = 1; i <= n; ++ i)
        if(deg[i] == 1)
            dfs(i, 1, 0);
    sam.out();
    return 0;
}
Comments

添加新评论