【题解】BZOJ 2754 [SCOI2012]喵星球上的点名

Problem

传送门

我真是越来越懒了

Thought

如果用 $SAM$ 就很裸了是吧
但是一般来说都用的是 $AC$ 自动机吧

先解决姓名分开匹配的问题,考虑往里面添加一个特殊字符 $10001$ 将姓名隔开,这样保证不会出现在姓名分开匹配的情况

接下来的问题是求出对于每个姓名串,有多少个点名串是其子串

将点名串建出 $AC$ 自动机,直接用姓名串在上面匹配,每匹配到一个节点,相当于是匹配了在 $fail$ 树从根到该节点的所有点(都拥有同样的后缀)

这种东西可以在 $fail$ 树上按照建立虚树的方式统计答案,保证不重不漏
也就是对于匹配点假装建出虚树,统计路径上有多少个点名串的结束位置即可

同样,要求出对于每个点名串,是多少个姓名串的子串,可以和上面一起处理出来
只需要求出每个点名串被覆盖了多少次即可,这个可以通过建虚树时差分的方式解决
最后 $DFS$ 上传标记

注意:由于字符集大小达到了 $10000$ ,需要使用可持久化数组来维护转移,在这里不再赘述

Code

#include <bits/stdc++.h>
using namespace std;
const int N = 100010;

int read() {
    int ans = 0, flag = 1;
    char ch = getchar();
    while(ch > '9' || ch < '0') {if(ch == '-') flag = - flag; ch = getchar();}
    while(ch >= '0' && ch <= '9') {ans = ans * 10 + ch - '0'; ch = getchar();}
    return ans * flag;
}

struct Edge {
    int t, n;
}e[N];
int head[N], Etot;
void addedge(int u, int v) {
    e[++ Etot] = {v, head[u]};
    head[u] = Etot;
}

int dfn[N], DFN;
int fa[N][21], dep[N];

bool cmp(int a, int b) {
    return dfn[a] < dfn[b];
}

int lca(int a, int b) {
    if(dep[a] < dep[b])
        swap(a, b);
    for(int i = 20; i >= 0; -- i)
        if(dep[fa[a][i]] >= dep[b])
            a = fa[a][i];
    if(a == b) return a;
    for(int i = 20; i >= 0; -- i)
        if(fa[a][i] != fa[b][i]) {
            a = fa[a][i];
            b = fa[b][i];
        }
    return fa[a][0];
}

int n, m;
vector <int> name[N];
int str[N];

int Ans[N];

struct AC {
    #define L (ch[x][0])
    #define R (ch[x][1])
    #define mid ((l + r) >> 1)

    struct President {
        int rt[N], ch[20 * N][2], a[20 * N], cnt;
        void ins(int &x, int l, int r, int pos, int son) {
            ++ cnt;
            ch[cnt][0] = ch[x][0];
            ch[cnt][1] = ch[x][1];
            x = cnt;

            if(l == r) a[x] = son;
            else
                if(pos <= mid) ins(L, l, mid, pos, son);
                else ins(R, mid + 1, r, pos, son);
        }
        int query(int x, int l, int r, int pos) {
            if(l == r) return a[x];
            else {
                if(pos <= mid) return query(L, l, mid, pos);
                else return query(R, mid + 1, r, pos);
            }
        }
        int son(int a, int b) {
            return query(rt[a], 0, 10001, b);
        }
        void res(int a, int b, int c) {
            ins(rt[a], 0, 10001, b, c);
        }
    }tr;
    map <int, int> mp[N];
    int end[N];
    int id[N];

    int rt, nodeCnt;
    int fail[N];

    bool vis[N];

    void ins(int *str, int n, int tim) {
        int now = rt;
        for(int i = 1; i <= n; ++ i) {
            if(!mp[now].count(str[i]))
                mp[now][str[i]] = ++ nodeCnt;
            now = mp[now][str[i]];
        }
        id[tim] = now;
        ++ end[now];
    }
    
    queue <int> q;

    void build() {
        for(auto it : mp[rt]) {
            tr.res(rt, it.first, it.second);
            fail[it.second] = rt;
            q.push(it.second);
        }
        while(!q.empty()) {
            int x = q.front(); q.pop();
            addedge(fail[x], x);
            tr.rt[x] = tr.rt[fail[x]];
            for(auto it : mp[x]) {
                tr.res(x, it.first, it.second);
                fail[it.second] = tr.son(fail[x], it.first);
                q.push(it.second);
            }
        }
    }


    int sta[N], top;
    int nod[N], nodCnt;

    int tag[N];

    void dfs1(int x) {
        dfn[x] = ++ DFN; dep[x] = dep[fa[x][0]] + 1;
        for(int i = 1; i <= 20; ++ i)
            fa[x][i] = fa[fa[x][i - 1]][i - 1];
        for(int i = head[x]; i; i = e[i].n)    {
            fa[e[i].t][0] = x;
            end[e[i].t] += end[x];
            dfs1(e[i].t);
        }
    }

    void solve(int tim) {
        int x = rt;
        vis[x] = 1;
        nod[++ nodCnt] = x;
        for(auto it : name[tim]) {
            x = tr.son(x, it);
            if(!vis[x]) {
                vis[x] = 1;
                nod[++ nodCnt] = x;
            }
        }
        sort(nod + 1, nod + nodCnt + 1, cmp);
        sta[top] = - 1;
        sta[++ top] = nod[1];
        for(int i = 2; i <= nodCnt; ++ i) {
            int l = lca(sta[top], nod[i]);
            if(l != sta[top]) {
                while(top >= 2 && dep[l] <= dep[sta[top - 1]]) {
                    tag[sta[top]] += 1;
                    tag[sta[top - 1]] -= 1;
                    Ans[tim] += end[sta[top]];
                    Ans[tim] -= end[sta[top - 1]];
                    sta[top --] = 0;
                }
                if(l != sta[top - 1]) {
                    tag[sta[top]] += 1;
                    tag[l] -= 1;

                    Ans[tim] += end[sta[top]];
                    Ans[tim] -= end[l];

                    sta[top] = l;
                }
                else {
                    tag[sta[top]] += 1;
                    tag[sta[top - 1]] -= 1;
                    Ans[tim] += end[sta[top]];
                    Ans[tim] -= end[sta[top - 1]];
                    sta[top --] = 0;
                }
            }
            sta[++ top] = nod[i];
        }
        tag[sta[top]] += 1;
        tag[sta[1]] -= 1;
        Ans[tim] += end[sta[top]];
        Ans[tim] -= end[sta[1]];
        while(top) sta[top --] = 0;
        while(nodCnt) {
            vis[nod[nodCnt]] = 0;
            nod[nodCnt --] = 0;
        }
    }
    void dfs2(int x) {
        for(int i = head[x]; i; i = e[i].n) {
            dfs2(e[i].t);
            tag[x] += tag[e[i].t];
        }
    }
    void print() {
        for(int i = 1; i <= m; ++ i)
            printf("%d\n", tag[id[i]]);
        for(int i = 1; i <= n; ++ i)
            printf("%d ", Ans[i]);
    }
}ac;

int main() {
    n = read(); m = read();
    for(int i = 1; i <= n; ++ i) {
        int len = read();
        while(len --) name[i].push_back(read());
        name[i].push_back(10001);
        len = read();
        while(len --) name[i].push_back(read());
    }
    for(int i = 1; i <= m; ++ i) {
        int len = read();
        for(int j = 1; j <= len; ++ j)
            str[j] = read();
        ac.ins(str, len, i);
    }
    ac.build();
    ac.dfs1(ac.rt);
    for(int i = 1; i <= n; ++ i)
        ac.solve(i);
    ac.dfs2(ac.rt);
    ac.print();
    return 0;
}
Comments

添加新评论