【题解】LOJ 2303 「NOI2017」蚯蚓排队

Problem

有 $n$ 只蚯蚓,其长度 $l_i\in[1,6]$ ,刚开始时每只蚯蚓各自排成只有一只蚯蚓的队伍
有 $3$ 种操作:

  1. 1 i j :将 $j$ 号蚯蚓所在的队伍接在 $i$ 号蚯蚓队伍的后面
  2. 2 i :从 $i$ 号蚯蚓的后面将队伍分成两个
  3. 3 s k :对 $s$ 的每个长度为 $k$ 的子串,求它在所有队伍中(作为连续的一段)出现的次数,输出这些次数的乘积(对 $998244353$ 取模)

其中:$k\leq 50$;操作 2 的次数 $c\leq 10^3$;所有操作 3 的 $\sum|s|\leq10^7$

Thought

需要维护一个数据结构:支持串的断开与合并,并能在 $O(|s|)$ 的时间内查询 $s$ 的所有长度为 $k$ 的子串的出现次数

20pt

注意到 $k\leq 50$ 很小,断开操作只会影响跨过断点、长度为 $2,3,\cdots,50$ 的串;长度为 $L$ 的这种串有 $L-1$ 个,共 $1+2+\cdots+49=1225$ 个
合并操作同理,至多新增 $1225$ 个串

对每个串求 hash,再用 unordered_map 记录每个 hash 值的出现次数即可

时间复杂度 $O(m\times 1225\times\log(50n))$

100pt

发现合并与断开操作并没有那么多:每次合并使队伍数减一,每次断开使队伍数加一,因此合并次数不超过 $n+c$($c$ 为操作 2 的次数);每次断开的复杂度为 $1225$,总合并的复杂度为 $O(nk)$
可以通过此题

时间复杂度 $O(c\times 2500+nk+\sum|s|)$

Code

手写 hash 100pt

#include <bits/stdc++.h>
using namespace std;
using ui = unsigned long long;  // hash values; arithmetic on them wraps modulo 2^64
constexpr int N = 2000010;      // capacity of the worm array p[]
constexpr int MOD = 998244353;  // modulus applied to the answer of operation 3
constexpr int mod = 11248631;   // bucket count of the hand-written hash table

// Reads one (optionally negative) decimal integer from stdin via getchar().
// Non-digit characters before the number are skipped; each '-' among them flips
// the sign. If EOF is hit before any digit, returns 0 instead of spinning forever.
int read() {
    int ans = 0, flag = 1;
    int ch = getchar();  // int, not char, so EOF stays distinguishable from data bytes
    while (ch != EOF && (ch < '0' || ch > '9')) {
        if (ch == '-') flag = -flag;
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9') {
        ans = ans * 10 + (ch - '0');
        ch = getchar();
    }
    return ans * flag;
}
//----------
// Hand-written chained hash table mapping (hash value, substring length) -> count.
// head[b] is the index of the newest entry in bucket b = Key % mod (0 = empty bucket;
// entries are numbered from 1). nxt[] chains entries of the same bucket.
// Entries are never removed here, so cnt must stay below 20000000.
struct Hashtable {
    int head[mod];
    int nxt[20000000];
    int len[20000000];
    int num[20000000];
    int cnt;
    ui key[20000000];

    // Returns the index of the entry for (Key, Len), or 0 if there is none.
    int get(ui Key, int Len) {
        for (int e = head[Key % mod]; e; e = nxt[e])
            if (key[e] == Key && len[e] == Len)
                return e;
        return 0;
    }

    // Pushes a new entry for (Key, Len) with count 1 onto the front of its bucket.
    // Does not check for an existing entry; callers call get() first.
    void add(ui Key, int Len) {
        const int bucket = Key % mod;
        const int e = ++cnt;
        key[e] = Key;
        len[e] = Len;
        num[e] = 1;
        nxt[e] = head[bucket];
        head[bucket] = e;
    }
} hs;
//----------

ui bin[60];  // bin[i] = 7^i (wrapping mod 2^64); main() fills indices 0..50

// One entry per worm: pre/nxt are its neighbours in the queue (0 = none) and
// num is its length, used as a base-7 digit when hashing.
struct Node {
    int pre;
    int nxt;
    int num;
} p[N];

char s[10000010];  // query string buffer for operation 3

int n, m;  // number of worms, number of operations
void add(int beg, int pos) {
    ui num = 0;
    bool flag = 0;
    int len = 0;
    for(int i = 1; i <= 50 && beg; ++ i) {
        num = num * 7 + p[beg].num;
        ++ len;
        if(beg == pos) flag = 1;
        if(flag) {
            int pos = hs.get(num, len);
            if(!pos) hs.add(num, len);
            else ++ hs.num[pos];
        }
        beg = p[beg].nxt;
    }
}
void del(int beg, int pos) {
    ui num = 0;
    bool flag = 0;
    int len = 0;
    for(int i = 1; i <= 50 && beg; ++ i) {
        num = num * 7 + p[beg].num;
        ++ len;
        if(beg == pos) flag = 1;
        if(flag) {
            int pos = hs.get(num, len);
            -- hs.num[pos];
        }
        beg = p[beg].nxt;
    }
}

void solve(int k) {
    int n = strlen(s);
    ui num = 0;
    long long ans = 1;
    int len = 0;
    for(int i = 0; i < n&& ans; ++ i) {
        num = (num * 7 + (s[i] - '0'));
        ++ len;
        if(i - k >= 0) {
            num -= (s[i - k] - '0') * bin[k];
            -- len;
        }
        if(i + 1 >= k) {
            int pos = hs.get(num, len);
            if(pos) ans = 1ll * ans * hs.num[pos] % MOD;
            else ans = 0;
        }
    }
    printf("%lld\n", ans);
}

int main() {
    // bin[e] = 7^e, wrapping mod 2^64 exactly like the hashes.
    bin[0] = 1;
    for (int e = 1; e <= 50; ++e) bin[e] = bin[e - 1] * 7;

    n = read();
    m = read();
    // Every worm starts alone: count its one-character string.
    for (int w = 1; w <= n; ++w) {
        p[w].num = read();
        const int id = hs.get(p[w].num, 1);
        if (id) ++hs.num[id];
        else hs.add(p[w].num, 1);
    }

    while (m--) {
        const int op = read();
        if (op == 1) {
            // Append b's queue after a, then register the new substrings crossing
            // the a|b boundary: starting at a and at up to 48 of its predecessors.
            const int a = read();
            const int b = read();
            p[a].nxt = b;
            p[b].pre = a;
            int cur = a;
            add(cur, b);
            for (int step = 1; step < 49 && p[cur].pre; ++step) {
                cur = p[cur].pre;
                add(cur, b);
            }
        } else if (op == 2) {
            // Cut after a: unregister the crossing substrings while still linked, then unlink.
            const int a = read();
            const int b = p[a].nxt;
            int cur = a;
            del(cur, b);
            for (int step = 1; step < 49 && p[cur].pre; ++step) {
                cur = p[cur].pre;
                del(cur, b);
            }
            p[a].nxt = 0;
            p[b].pre = 0;
        } else if (op == 3) {
            scanf("%s", s);
            const int k = read();
            solve(k);
        }
    }
    return 0;
}

使用 unordered_map + 三模数 hash 68pt

#include <bits/stdc++.h>
using namespace std;
using ui = unsigned int;        // type of each hash residue
constexpr int N = 2000010;      // capacity of the worm array p[]
constexpr int MOD = 998244353;  // modulus applied to the answer of operation 3

// Reads one (optionally negative) decimal integer from stdin via getchar().
// Non-digit characters before the number are skipped; each '-' among them flips
// the sign. If EOF is hit before any digit, returns 0 instead of spinning forever.
int read() {
    int ans = 0, flag = 1;
    int ch = getchar();  // int, not char, so EOF stays distinguishable from data bytes
    while (ch != EOF && (ch < '0' || ch > '9')) {
        if (ch == '-') flag = -flag;
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9') {
        ans = ans * 10 + (ch - '0');
        ch = getchar();
    }
    return ans * flag;
}
//----------
struct Key {
    ui first, second, third;
    Key(ui f, ui s, ui t) : first(f), second(s), third(t) {}
};
struct EqualKey {
    bool operator () (const Key &a, const Key &b) const {
        return a.first == b.first && a.second == b.second && a.third == b.third;
    }
};
struct HashFunc {
    ui operator()(const Key &key) const {
        using std :: hash;
        return ((hash<ui>()(key.first) ^ (hash<ui>()(key.second) << 1)) >> 1) ^ (hash<ui>()(key.third) << 1);
    }
};
//----------
unordered_map<Key, int, HashFunc, EqualKey> mp;

ui mod[3] = {19260817, 91248653, 141248629};
ui bin[3][60];

struct Node {
    ui pre, nxt, num;
}p[N];

char s[10000010];
ui num[3];

int n, m;
void add(int beg, int pos) {
    num[0] = num[1] = num[2] = 0;
    bool flag = 0;
    for(int i = 1; i <= 50 && beg; ++ i) {
        for(int j = 0; j < 3; ++ j)
            num[j] = (num[j] * 7 + p[beg].num) % mod[j];
        if(beg == pos) flag = 1;
        if(flag) {
            Key X = {num[0], num[1], num[2]};
            if(mp.count(X))
                mp[X] += 1;
            else
                mp[X] = 1;
        }
        beg = p[beg].nxt;
    }
}
void del(int beg, int pos) {
    num[0] = num[1] = num[2] = 0;
    bool flag = 0;
    for(int i = 1; i <= 50 && beg; ++ i) {
        for(int j = 0; j < 3; ++ j)
            num[j] = (num[j] * 7 + p[beg].num) % mod[j];
        if(beg == pos) flag = 1;
        if(flag) {
            Key X = {num[0], num[1], num[2]};
            if(mp[X] == 1)
                mp.erase(X);
            else
                mp[X] -= 1;
        }
        beg = p[beg].nxt;
    }
}

// Operation 3: prints the product (mod MOD) of the counts in mp of every length-k
// substring of the global buffer `s`, or 1 when s has no length-k substring.
// Rolls the same triple hash over s that add()/del() use. Does not modify mp.
void solve(int k) {
    const int len = static_cast<int>(strlen(s));
    // mp only ever receives substrings of length <= 50, and bin[][] only covers
    // exponents < 60, so a longer window cannot be present: the product is 0 if one exists.
    if (k > 50) {
        printf("%lld\n", len >= k ? 0LL : 1LL);
        return;
    }
    num[0] = num[1] = num[2] = 0;
    long long ans = 1;
    for (int i = 0; i < len && ans; ++i) {
        for (int j = 0; j < 3; ++j)
            num[j] = (num[j] * 7 + (s[i] - '0')) % mod[j];
        if (i >= k) {  // remove the digit that just left the window
            for (int j = 0; j < 3; ++j) {
                num[j] += mod[j] - (s[i - k] - '0') * bin[j][k] % mod[j];
                if (num[j] >= mod[j]) num[j] -= mod[j];
            }
        }
        if (i + 1 >= k) {
            // find() instead of operator[]: a miss must not insert a zero-count key.
            const auto it = mp.find(Key(num[0], num[1], num[2]));
            ans = (it == mp.end()) ? 0 : ans * it->second % MOD;
        }
    }
    printf("%lld\n", ans);
}

int main() {
    // bin[j][i] = 7^i mod mod[j]; solve() uses bin[j][k] to drop the digit leaving its window.
    for (int j = 0; j < 3; ++j) {
        bin[j][0] = 1;
        for (int i = 1; i <= 50; ++i)
            bin[j][i] = bin[j][i - 1] * 7 % mod[j];
    }

    n = read();
    m = read();
    // Every worm starts alone. A single length v (1..6) is smaller than every
    // modulus, so its hash is (v, v, v). A single operator[] does the
    // insert-or-increment (a missing count is value-initialised to 0), replacing
    // the former count() + operator[] double lookup.
    for (int i = 1; i <= n; ++i) {
        p[i].num = read();
        ++mp[Key(p[i].num, p[i].num, p[i].num)];
    }

    while (m--) {
        const int opt = read();
        if (opt == 1) {
            // Append b's queue after a, then register the new substrings crossing
            // the a|b boundary: starting at a and at up to 48 of its predecessors.
            const int a = read();
            const int b = read();
            p[a].nxt = b;
            p[b].pre = a;
            int cur = a;
            add(cur, b);
            for (int step = 1; step < 49 && p[cur].pre; ++step) {
                cur = p[cur].pre;
                add(cur, b);
            }
        } else if (opt == 2) {
            // Cut after a: unregister the crossing substrings while still linked, then unlink.
            const int a = read();
            const int b = p[a].nxt;
            int cur = a;
            del(cur, b);
            for (int step = 1; step < 49 && p[cur].pre; ++step) {
                cur = p[cur].pre;
                del(cur, b);
            }
            p[a].nxt = 0;
            p[b].pre = 0;
        } else if (opt == 3) {
            scanf("%s", s);
            const int k = read();
            solve(k);
        }
    }
    return 0;
}
Comments

添加新评论