Chlience

【题解】BZOJ 3992 [SDOI2015]序列统计
Probelm给定一个集合 $S=\{x|x\in[0,m-1],x\in\mathbb{Z}\}$,$m$ 为质...
扫描右侧二维码阅读全文
16
2019/01

【题解】BZOJ 3992 [SDOI2015]序列统计

Probelm

给定一个集合 $S=\{x|x\in[0,m-1],x\in\mathbb{Z}\}$,$m$ 为质数

现在可以生成一个长度为 $n$ 的数列 $a$,$a_i\in S$,问有多少个数列满足 $\prod_{i=1}^{n}a_i\equiv x\pmod m$

Thought

设 $f[i][j]$ 为前 $i$ 个数的乘积为 $j$ 的方案数,显然有

$$ f[i+1][j*k\bmod m]=f[i][j]*f[1][k]\\ f[i+i][j*k\bmod m]=f[i][j]*f[i][k] $$

考虑如何将乘法转化成加法,然后就可以 $FFT/NTT​$ 优化卷积
发现如果搞个原根,将其表示为幂的形式,就可以愉快的用加法了

$$ f[i+1][(j+k)\bmod m]=f[i][j]*f[1][k]\\ f[i+i][(j+k)\bmod m]=f[i][j]*f[i][k] $$

所以就很简单啦!
注意考虑 $x=0$ 的情况,因为要求的数不可能为 $0$ 所以直接忽略即可

Code

#include <bits/stdc++.h>
using namespace std;
int G = 3 , mod = 1004535809;
const int N = 100010;
bool np[N];
int pri[N] , pri_cnt;
int pf[N] , pri_fac;
int n , m , p , S;
int id[N];
int nn , l , rev[N << 2];
int I[N << 2];
int X[N << 2] , Y[N << 2] , Z[N << 2] , ANS[N << 2];
int read() {
    int ans = 0 , flag = 1;
    char ch = getchar();
    while(ch > '9' || ch < '0') {if(ch == '-') flag = - flag; ch = getchar();}
    while(ch >= '0' && ch <= '9') {ans = ans * 10 + ch - '0'; ch = getchar();}
    return ans * flag;
}
int qpow(int a , int b , int c) {
    int ans = 1;
    while(b) {
        if(b & 1) ans = 1ll * ans * a % c;
        a = 1ll * a * a % c;
        b >>= 1;
    }
    return ans;
}
int getPrimitivelRoot(int x) {
    int xx = x;
    for(int i = 2 ; i <= sqrt(x) ; ++ i) {
        if(!np[i])
            pri[++ pri_cnt] = i;
        for(int j = 1 ; j <= pri_cnt && i * pri[j] <= sqrt(x) ; ++ j) {
            np[i * pri[j]] = 1;
            if(i % pri[j] == 0) break;
        }
    }
    int phi = x - 1;
    x = phi;
    for(int i = 1 ; i <= pri_cnt && x >= pri[i] ; ++ i) {
        if(x % pri[i] == 0) {
            pf[++ pri_fac] = pri[i];
            while(x % pri[i] == 0) x /= pri[i];
        }
    }
    if(x != 1) pf[++ pri_fac] = x;
    for(int i = 2 ; ; ++ i) {
        int flag = 1;
        for(int j = 1 ; j <= pri_fac && flag ; ++ j)
            if(qpow(i , phi / pf[j] , xx) == 1)
                flag = 0;
        if(flag) return i;
    }
}
void NTT(int* now , int f) {
    for(int i = 0 ; i < nn ; ++ i)
        if(i < rev[i])
            swap(now[i] , now[rev[i]]);
    for(int i = 1 ; i < nn ; i <<= 1) {
        int gn = qpow(G , (mod - 1) / (i << 1) , mod);
        if(f != 1) gn = qpow(gn , mod - 2 , mod);
        for(int j = 0 ; j < nn ; j += (i << 1)) {
            int x , y , g = 1;
            for(int k = 0 ; k < i ; ++ k , g = 1ll * g * gn % mod) {
                x = now[j + k]; y = 1ll * g * now[i + j + k] % mod;
                now[j + k] = (x + y) % mod;
                now[i + j + k] = (x - y + mod) % mod;
            }
        }
    }
    if(f != 1) {
        int ny = qpow(nn , mod - 2 , mod);
        for(int i = 0 ; i < nn ; ++ i) now[i] = 1ll * now[i] * ny % mod;
    }
}
void times(int *a , int *b) {
    for(int i = 0 ; i < nn ; ++ i) {
        X[i] = a[i];
        Y[i] = b[i];
    }
    NTT(X , 1); NTT(Y , 1);
    for(int i = 0 ; i < nn ; ++ i)
        Z[i] = (1ll * X[i] * Y[i]) % mod;
    NTT(Z , - 1);
}
void qpow(int b) {
    ANS[0] = 1;
    while(b) {
        if(b & 1) {
            times(I , ANS);
            for(int i = 0 ; i < nn ; ++ i) ANS[i] = 0;
            for(int i = 0 ; i < nn ; ++ i) {
                ANS[i % (m - 1)] += Z[i];
                if(ANS[i % (m - 1)] >= mod) ANS[i % (m - 1)] -= mod;
            }
        }
        times(I , I);
        for(int i = 0 ; i < nn ; ++ i) I[i] = 0;
        for(int i = 0 ; i < nn ; ++ i) {
            I[i % (m - 1)] += Z[i];
            if(I[i % (m - 1)] >= mod) I[i % (m - 1)] -= mod;
        }
        b >>= 1;
    }
}
int main() {
    n = read(); m = read(); p = read(); S = read();
    int g = getPrimitivelRoot(m);
    for(int i = 0 ; i < m - 1 ; ++ i)
        id[qpow(g , i , m)] = i;
    for(int i = 1 ; i <= S ; ++ i) {
        int x = read();
        if(x) I[id[x]] = 1;
    }
    for(nn = 1 ; nn <= m * 2 ; nn <<= 1) ++ l;
    for(int i = 0 ; i < nn ; ++ i)
        rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (l - 1));
    qpow(n);
    printf("%d\n" , ANS[id[p]]);
    return 0;
}
Last modification:January 16th, 2019 at 05:31 pm

Leave a Comment