Chlience

【算法】快速沃尔什变换
美妙的集合变换What首先让我们来看看最熟悉的乘法卷积然后对于每个 $c$ 做一遍 $FMT$ 就能从 $A[0\...
扫描右侧二维码阅读全文
25
2018/12

【算法】快速沃尔什变换

美妙的集合变换

What

首先让我们来看看最熟悉的乘法卷积

$$ F[n]=\sum_{i+j=n}A[i]*B[j] $$

朴素的方法是进行 $n^2$ 次运算计算出 $F[0\to n]$
或者利用快速傅里叶变换将其优化到 $n\log n$ 级别

同样我们定义集合卷积

$$ F[n]=\sum_{i\bigoplus j=n}A[i]*B[j] $$

朴素的,可以用 $n^2$ 次运算来暴力计算

但是有什么方法能和 $FFT$ 一样快速算出 $F[0\to n]$ 么?
当然是有的,那就是 $FWT$

How

一般来说,我们会遇到这三种形式的集合卷积,分别是

$$ F[n]=\sum_{i|j=n}A[i]*B[j]\\ F[n]=\sum_{i\&j=n}A[i]*B[j]\\ F[n]=\sum_{i\ xor\ j=n}A[i]*B[j] $$

相应的,需要不同的方式来计算

或运算($OR$)

首先,类似于 $FFT$,我们需要将原式转化为某个特殊的形式
在这里,定义:

$$ A'[n]=\sum_{i\subseteq n}A[i] $$

那么可得

$$ A'[n]*B'[n]=\sum_{i\subseteq n}A[i]*\sum_{j\subseteq n}B[j]\\ =\sum_{(i|j)\subseteq n}A[i]*B[j]\\ =\sum_{k\subseteq n}\sum_{(i|j)=k}A[i]*B[j]\\ =\sum_{k\subseteq n}F[k]=F'[n] $$

发现这个过程能够 $O(n)$ 的时间内通过 $A',B'$ 求出 $F'$
所以算法时间瓶颈就在如何快速求出 $A'[n]=\sum_{i\subseteq n}A[i]$ 及其逆运算

和 $FFT$ 类似,同样可以通过分治的方法解决此问题,不同的是 $FFT$ 利用的更多是虚数单位复数根的性质,而 $FWT$ 中需要用到逻辑运算本身的性质

其实就是快速莫比乌斯变换($FMT$)和快速莫比乌斯反演($FMI$)啦

快速莫比乌斯变换($FMT$)

$$ F\to F' $$

考虑将当前的 $N$ 项式 $F'$ 划分为前 $N/2$ 项,记为 $L$,后 $N/2$ 项,记为 $R$

假设已经计算出了 $L,R$,如何将其合并呢?
先看看 $L,R$ 分别代表了啥:

$$ L=\sum_{i\subseteq N/2}F[i]\\ R=\sum_{i\subseteq N/2}F[i+N/2] $$

也就是说,除了求和的 $F$ 位置不相同之外,两者应该是等价的

大概像这样:

fwt1.png

对于每一部分 $L,R$ 都获取了正确的值,但是如果要将其直接合并起来,发现后面 $N/2$ 项应该加上前面 $N/2$ 项对其的贡献

fwt2.png

由于是计算每个位置 $X$ 的子集,那么属于 $R$ 的某个位置 $X$ 和属于 $L$ 的某个位置 $X-N/2$ 一一对应,有且仅有最高位 $N/2$ 不同
并且,对于位置 $X$ 来说,其包含有最高位 $N/2$ 的子集已经计算完毕,只需要计算不包含 $N/2$ 位置的子集即可,也就是 $X-N/2$

所以,合并后前 $N/2$ 位即 $L$ ,后 $N/2$ 位为 $L+R$
递归处理自底向上合并即可

快速莫比乌斯反演($FMI$)

$$ F'\to F $$

之前用了 $A'[n]=\sum_{i\subseteq n}A[i]$,其逆运算为 $A[n]=\sum_{i\subseteq n}(-1)^{|n|-|i|}A'[i]$

除了某些 $-1$ 项,这个东西和前面的 $FMT$ 是一样的

因为 $L,R$ 两者之间仅仅相差了一个最高位,所以对于每个对应位来说,$L,R$ 异号

所以,合并后前 $N/2$ 位为 $L$ ,后 $N/2$ 位为 $R-L$
仍然是递归处理自底向上合并(滑稽

到这里就能够完整的完成 $FWT$ 的异或部分啦!

void FWT_OR(int N , int f) {//f = 1 or - 1
    for(int l = 1 ; l < N ; l <<= 1)//the lenth of (L and R)
        for(int i = 0 ; i < N ; i += (l << 1))//the begin pos of L
            for(int j = i ; j < i + l ; ++ j) {//each pos
                int x = a[j] , y = a[j + l];
                a[j + l] = f * x + y;
            }
}

与运算($AND$)

$$ F\to F' $$

类似的,定义:

$$ A'[n]=\sum_{i\supseteq n} A[i] $$

仍然考虑合并 $L,R$

由于是计算每个位置 $X$ 的子集,那么属于 $R$ 的某个位置 $X$ 和属于 $L$ 的某个位置 $X-N/2$ 一一对应,有且仅有最高位 $N/2$ 不同
并且,对于位置 $X-N/2$ 来说,其不包含最高位 $N/2$ 的子集已经计算完毕,只需要计算包含 $N/2$ 位置的子集即可,也就是 $X$ (我肯定不是被复制过来的)

所以,合并后前 $N/2$ 位即 $L+R$ ,后 $N/2$ 位为 $R$
递归处理自底向上合并即可

$$ F'\to F $$

除了某些 $-1$ 项,这个东西和前面的 $FMT$ 是一样的
因为 $L,R$ 两者之间仅仅相差了一个最高位,所以对于每个对应位来说,$L,R$ 异号

所以,合并后前 $N/2$ 位为 $L-R$ ,后 $N/2$ 位为 $R$
仍然是递归处理自底向上合并(我也肯定不是被复制过来的)

void FWT_AND(int N , int f) {//f = 1 or - 1
    for(int l = 1 ; l < N ; l <<= 1)//the lenth of (L and R)
        for(int i = 0 ; i < N ; i += (l << 1))//the begin pos of L
            for(int j = i ; j < i + l ; ++ j) {//each pos
                int x = a[j] , y = a[j + l];
                a[j] = x + f * y;
            }
}

所以说其实与运算和或运算是一样的吧...

异或运算($XOR$)

不是很清楚,也木有看到比较妙的证明,如果有哪位看到了麻烦告诉我 $qwq$

在这里仅仅贴上构造方法

$$ F\to F' $$

合并后前 $N/2$ 位即 $L+R$ ,后 $N/2$ 位为 $L-R$

$$ F'\to F $$

合并后前 $N/2$ 位即 $\frac{L+R}{2}$ ,后 $N/2$ 位为 $\frac{L-R}{2}$

void FWT_XOR(int N , int f) {//f = 1 or - 1
    for(int l = 1 ; l < N ; l <<= 1)//the lenth of (L and R)
        for(int i = 0 ; i < N ; i += (l << 1))//the begin pos of L
            for(int j = i ; j < i + l ; ++ j) {//each pos
                int x = a[j] , y = a[j + l];
                a[j] = x + y; a[j + l] = x - y;
                if(f == - 1) {a[j] = a[j] * inv2; a[j + l] = a[j + l] * inv2;}
            }
}

WHY

为毛要写这么复(jian)杂(dan)的 $FWT$ 呢?

当然是加速集合卷积啦!

比如说搞一搞 子集卷积 啥的

$$ F[n]=\sum_{i|j=n\&\&|i|+|j|=|n|}A[i]*B[j] $$

当然可以转化为

$$ F[n][c]=\sum_{i|j=n}[|i|+|j|=|c|]A[i]*B[j] $$

处理过后大概就是这样

$$ F'[n][c]=[i+j=c](A'[n][i]+B'[n][j]) $$

然后对于每个 $c$ 做一遍 $FMT$ 就能从 $A[0\to n]$ 得到 $A'[0\to n][0\to c]$,合并 $A',B'$ 后得到 $F'$ 再对于每个 $c$ 做一遍逆变换即可

void FWT_ZJ(int N , int f) {
    for(int l = 1 ; j < N ; l <<= 1)
        for(int i = 0 ; i < N ; i += (l << 1))
            for(int j = i ; j < i + l ; ++ j) {
                for(int k = 0 ; k <= n ; ++ k) {
                    int x = a[j][k] , y = a[j + l][k];
                    a[j + l][k] = f * x + y;
                }
            }
}

板子

Luogu 4717

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const ll mod = 998244353;
const ll inv2 = 499122177;
ll read() {
    ll ans = 0 , flag = 1;
    char ch = getchar();
    while(ch > '9' || ch < '0') {if(ch == '-') flag = - flag; ch = getchar();}
    while(ch >= '0' && ch <= '9') {ans = ans * 10 + ch - '0'; ch = getchar();}
    return ans * flag;
}
void FWT_OR(ll *a , ll N , ll f) {
    for(int l = 1 ; l < N ; l <<= 1)
        for(int i = 0 ; i < N ; i += (l << 1))
            for(int j = i ; j < i + l ; ++ j) {
                ll x = a[j] , y = a[j + l];
                a[j + l] = (f * x + y + mod) % mod;
            }
}
void FWT_AND(ll *a , ll N , ll f) {
    for(int l = 1 ; l < N ; l <<= 1)
        for(int i = 0 ; i < N ; i += (l << 1))
            for(int j = i ; j < i + l ; ++ j) {
                ll x = a[j] , y = a[j + l];
                a[j] = (x + f * y + mod) % mod;
            }
}
void FWT_XOR(ll *a , ll N , ll f) {
    for(int l = 1 ; l < N ; l <<= 1)
        for(int i = 0 ; i < N ; i += (l << 1))
            for(int j = i ; j < i + l ; ++ j) {
                ll x = a[j] , y = a[j + l];
                a[j] = (x + y) % mod; a[j + l] = (x - y + mod) % mod;
                if(f == - 1) {a[j] = a[j] * inv2 % mod; a[j + l] = a[j + l] * inv2 % mod;}
            }
}
ll n , N , A[200000] , B[200000] , a[200000] , b[200000];
int main() {
    n = read(); N = 1 << n;
    for(int i = 0 ; i < N ; ++ i) A[i] = read();
    for(int i = 0 ; i < N ; ++ i) B[i] = read();

    for(int i = 0 ; i < N ; ++ i) a[i] = A[i];
    for(int i = 0 ; i < N ; ++ i) b[i] = B[i];
    FWT_OR(a , N , 1); FWT_OR(b , N , 1);
    for(int i = 0 ; i < N ; ++ i)
        a[i] = (a[i] * b[i]) % mod;
    FWT_OR(a , N , - 1);
    for(int i = 0 ; i < N ; ++ i) printf("%lld " , a[i]);

    puts("");
    memset(a , 0 , sizeof(a));
    memset(b , 0 , sizeof(b));
    
    for(int i = 0 ; i < N ; ++ i) a[i] = A[i];
    for(int i = 0 ; i < N ; ++ i) b[i] = B[i];
    FWT_AND(a , N , 1); FWT_AND(b , N , 1);
    for(int i = 0 ; i < N ; ++ i)
        a[i] = (a[i] * b[i]) % mod;
    FWT_AND(a , N , - 1);
    for(int i = 0 ; i < N ; ++ i) printf("%lld " , a[i]);
    
    puts("");
    memset(a , 0 , sizeof(a));
    memset(b , 0 , sizeof(b));

    for(int i = 0 ; i < N ; ++ i) a[i] = A[i];
    for(int i = 0 ; i < N ; ++ i) b[i] = B[i];
    FWT_XOR(a , N , 1); FWT_XOR(b , N , 1);
    for(int i = 0 ; i < N ; ++ i)
        a[i] = (a[i] * b[i]) % mod;
    FWT_XOR(a , N , - 1);
    for(int i = 0 ; i < N ; ++ i) printf("%lld " , a[i]);
    return 0;
}

具体例题的花可以看这个很妙的博客 K-XZY

(逃

Last modification:December 25th, 2018 at 08:08 pm

Leave a Comment