【算法】扩展KMP

请注意,本文编写于 205 天前,最后修改于 122 天前,其中某些信息可能已经过时。

扩展KMP通常用来解决这样的问题:

给定字符串 $S , T$ ,求 $S$ 中每一个后缀与 $T$ 的最长公共前缀

看起来和KMP风马牛不相及的算法为何叫做扩展KMP呢?

发现如果 $S$ 中某个后缀的匹配长度恰好等于 $T$ 的长度,这就是KMP问题,所以一般将其称为扩展KMP算法

现在有字符串 $S , T$
其中 $strlen(S) = n , strlen(T) = m$
定义 $extend[i]$ 为 $s[i,n-1]$ 和 $T$ 的最长公共前缀

所以我们要做的就是求出所有的 $extend$

为了计算 $extend$ 我们引入辅助数组 $nxt$

i> 注意:这个 $nxt$ 和 KMP算法中的 $nxt$ 是不相同的

扩展KMP中 $nxt[i]$ 指 $T[i,m-1]$ 和 $T$ 的最长公共前缀长度
而KMP中 $nxt[i]$ 指 $T$ 中以 $i$ 位置结尾的子串与 $T$ 的最长匹配长度

举一个例子:

1
1

首先在计算 $extend[0]$ 时,显然需要匹配 $3$ 次,得到 $extend[0] = 3$
接下来计算 $extend[1]$ ,由于我们已知 $extend[0] = 3$ ,即 $S[0 , 2] = T[0 , 2]$
那么从 $S[1]$ 开始匹配,$S[1,2]$ 和 $T[0,1]$ 的匹配能否快速计算呢?
现在已知 $S[1,2] = T[1,2]$ ,回忆一下 $nxt[i]$ 的作用:由 $nxt[1] = 2$ 可知 $T[1, 2] = T[0 , 1]$
那么就可以以 $T[1 , 2]$ 为跳板加速与 $T$ 前缀的匹配,即 $S[1,2] = T[0 , 1]$ ,接下来匹配 $S[3]$ 和 $T[2]$ 即可
在这时发生失配,所以 $extend[1] = 2$

这样就节约了很多时间

一般步骤

首先求出 $nxt$ ,即 $T$ 数组自己的每个后缀和自己的最长公共前缀长度
显然, $nxt[0] = m$
接下来,暴力计算 $nxt[1]$
那么后面的 $nxt$ 就可以利用前面已经求得的进行匹配

假设能够匹配到的最右位置为 $r$ ,并且匹配的开始位置为 $pos$ ,现在匹配的位置为 $i$
换句话说 $pos + nxt[pos] - 1 = r$
显然能够得到匹配 $T[0 , r - pos] = T[pos , r]$

2
2

既然这样,那么 $T[i,r]$ 也应该和 $T[i-pos , r - pos]$ 相等,所以说其 $nxt[i]$ 就能够从 $nxt[i - pos]$ 处继承

如果 $i + nxt[i - pos]$ 超过了 $r$ ,那么后面的部分显然是不相等的(否则 $nxt[pos]$ 或者说 $r$ 应该继续向右扩展)

而如果 $i + nxt[i - pos]$ 比 $r$ 更短,那么后面的部分也不会相等

但是如果 $i + nxt[i - pos]$ 恰好等于 $r$ ,后面的匹配情况就是未知的,需要手动向后扩展

所以

nxt[i] = min(nxt[i - pos] , r - i + 1);
while(T[i + nxt[i]] ==  T[nxt[i]]) ++ nxt[i];

显然的,每一个位置只会被额外覆盖一遍,所以说时间复杂度是线性的

现在我们解决了 $nxt$ ,考虑如何求解 $extend$

既然我们是要求 $S$ 的每个后缀和 $T$ 的前缀的最长匹配,而现在我们求出了 $T$ 的每个后缀和自己的前缀的最长匹配
直接按照上面 $T$ 自己匹配的方法套就好了

3
3

$S[i , r] = T[i - pos , r - pos]$
可以从 $nxt[i - pos]$ 处继承
同理,超出部分自行匹配

extend[i] = min(nxt[i - pos] , r - i + 1);
while(S[i + extend[i]] ==  T[extend[i]]) ++ extend[i];

这样就可以在线性时间内求出 $S$ 的每个后缀和 $T$ 前缀的匹配情况了

Too easy, right? -boshi

代码

#include <bits/stdc++.h>
using namespace std;
const int N = 1000010;
char S[N] , T[N];
int nxt[N];
int extend[N];
void get_nxt(char *s) {
    int n = strlen(s) , r = 0 , pos = 1;
    nxt[0] = n;
    while(r + 1 < n && s[r] == s[r + 1]) ++ r;
    nxt[1] = r;
    for(int i = 2 ; i < n ; ++ i) {
        nxt[i] = max(0 , min(nxt[i - pos] , r - i + 1));
        while(i + nxt[i] < n && s[i + nxt[i]] == s[nxt[i]]) ++ nxt[i];
        if(i + nxt[i] - 1 > r) {
            r = i + nxt[i] - 1;
            pos = i;
        }
    }
}
void get_extend(char *s , char *t) {
    int n = strlen(s) , m = strlen(t) , pos = 0 , r;
    while(s[extend[0]] == t[extend[0]]) ++ extend[0];
    r = extend[0] - 1;
    for(int i = 1 ; i < n ; ++ i) {
        extend[i] = max(0 , min(nxt[i - pos] , r - i + 1));
        while(i + extend[i] < n && extend[i] < m && s[i + extend[i]] ==  t[extend[i]]) ++ extend[i];
        if(i + extend[i] - 1 > r) {
            r = i + extend[i] - 1;
            pos = i;
        }
    }
    for(int i = 0 ; i < n ; ++ i)
        printf("%d " , extend[i]);
}
int main() {
    scanf("%s" , T);
    scanf("%s" , S);
    get_nxt(T);
    get_extend(S , T);
    return 0;
}
Comments

添加新评论