【题解】BZOJ 4566 [HAOI2016]找相同字符

Problem

给定两个串,求其相同子串对数

Thoguht

由于需要求出 $A$ 的所有子串和 $B$ 的所有子串之间的匹配, $AC$ 自动机这种只能做前缀最大匹配后缀的算法就又有办法处理了

考虑使用 $SAM$
首先对 $A$ 串建出 $SAM$

发现如果在 $SAM$ 中在线插入 $B$ 串,那么 $B$ 串的新增的每一个子串都会是在 $parent$ 树上从根节点到当前节点的一段

那么只需要算出 $A$ 串在这一段中出现多少次,即有多少个子串何其匹配

$DFS$ 将 $A$ 串的所有子串以及子串数量在 $parent$ 树上更新好(传递 $siz$ ),然后插入 $B$ 串时相当于对 $rt-x$ 这条路径求和,也就是说只需要求出原 $parent$ 树每个节点到根的子串数量前缀和即可
可以在 $DFS$ 时一起处理出来

对于 $B$ 串的插入导致的新节点,若是新建节点,那么直接从父亲处继承;对于之前节点的复制,考虑重新计算前缀和,显然可以通过该点和其父亲 $O(1)$ 的计算

最终答案就是插入完 $B$ 串的答案之和

Code

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
int read() {
    int ans = 0, flag = 1;
    char ch = getchar();
    while(ch > '9' || ch < '0') {if(ch == '-') flag = - flag; ch = getchar();}
    while(ch >= '0' && ch <= '9') {ans = ans * 10 + ch - '0'; ch = getchar();}
    return ans * flag;
}

const int N = 200010;

struct SAM {
    int ch[N << 2][26], fa[N << 2], siz[N << 2], l[N << 2];
    ll sum[N << 2];
    int last = 1, cnt = 1;

    void ins(int c) {
        int x = last, nx = ++ cnt; l[nx] = l[x] + 1; siz[nx] = 1; last = nx;
        for(; x && !ch[x][c]; x = fa[x]) ch[x][c] = nx;
        if(!x) fa[nx] = 1;
        else {
            int y = ch[x][c];
            if(l[x] + 1 == l[y])
                fa[nx] = y;
            else {
                int ny = ++ cnt; l[ny] = l[x] + 1; fa[ny] = fa[y];
                memcpy(ch[ny], ch[y], sizeof(ch[ny]));
                for(; x && ch[x][c] == y; x = fa[x]) ch[x][c] = ny;
                fa[y] = fa[nx] = ny;
            }
        }
    }

    ll que(int c) {
        int x = last, nx = ++ cnt; l[nx] = l[x] + 1; last = nx;
        for(; x && !ch[x][c]; x = fa[x]) ch[x][c] = nx;
        if(!x)
            fa[nx] = 1;
        else {
            int y = ch[x][c];
            if(l[x] + 1 == l[y])
                fa[nx] = y;
            else {
                int ny = ++ cnt; l[ny] = l[x] + 1; fa[ny] = fa[y]; siz[ny] = siz[y];
                sum[ny] = sum[fa[ny]] + 1ll * (l[ny] - l[fa[ny]]) * siz[ny];
                memcpy(ch[ny], ch[y], sizeof(ch[ny]));
                for(; x && ch[x][c] == y; x = fa[x]) ch[x][c] = ny;
                fa[y] = fa[nx] = ny;
            }
        }
        sum[nx] += sum[fa[nx]];
        return sum[nx];
    }
    void insert(char *str) {
        int len = strlen(str); last = 1;
        for(int i = 0; i < len; ++ i)
            ins(str[i] - 'a');
    }
    void query(char *str) {
        int len = strlen(str); last = 1;
        ll ans = 0;
        for(int i = 0; i < len; ++ i)
            ans += que(str[i] - 'a');
        printf("%lld\n", ans);
    }
}sam;
int sor[N << 2];

bool cmp(int a, int b) {
    return sam.l[a] > sam.l[b];
}

void build(SAM *sam) {
    for(int i = 1; i <= sam -> cnt; ++ i)
        sor[i] = i;
    sort(sor + 1, sor + sam -> cnt + 1, cmp);
    for(int i = 1; i <= sam -> cnt; ++ i) {
        int s = sor[i], f = sam -> fa[s];
        sam -> siz[f] += sam -> siz[s];
    }
    for(int i = sam -> cnt; i >= 1; -- i) {
        int s = sor[i], f = sam -> fa[s];
        sam -> sum[s] += sam -> sum[f];
        sam -> sum[s] += 1ll * (sam -> l[s] - sam -> l[f]) * sam -> siz[s];
    }
}
char str[N];
int main() {
    scanf("%s", str);
    sam.insert(str);
    build(&sam);
    scanf("%s", str);
    sam.query(str);
    return 0;
}
Comments

添加新评论