【题解】LOJ 2058 「TJOI / HEOI2016」求和

请注意,本文编写于 94 天前,最后修改于 94 天前,其中某些信息可能已经过时。

Problem

$$ f(n)=\sum_{i=0}^{n}\sum_{j=0}^{n}S(i,j)2^jj! $$

Thought

$$ S(i,j)=\frac{1}{j!}\sum_{k=0}^{j}(-1)^{j-k}{j\choose k}k^i\\ f(n)=\sum_{i=0}^{n}\sum_{j=0}^{n}\sum_{k=0}^{j}(-1)^{j-k}{j\choose k}2^jk^i\\ f(n)=\sum_{i=0}^{n}\sum_{j=0}^{n}\sum_{k=0}^{j}(-1)^{j-k}\frac{j!}{k!(j-k)!}2^jk^i\\ f(n)=\sum_{j=0}^{n}j!2^j\sum_{i=0}^{n}\sum_{k=0}^{j}\frac{(-1)^{j-k}}{(j-k)!}\frac{k^i}{k!}\\ f(n)=\sum_{j=0}^{n}j!2^j\sum_{k=0}^{j}\frac{(-1)^{j-k}}{(j-k)!}\frac{\sum_{i=0}^{n}k^i}{k!}\\ $$

设 $A(x)=\frac{(-1)^x}{x!},B(x)=\frac{\sum_{i=0}^{n}x^i}{x!}=\frac{x^{n+1}-1}{x!(x-1)}$(等比数列求和的闭式只对 $x\ge 2$ 成立;$x=0$ 时约定 $0^0=1$,故 $B(0)=1$;$x=1$ 时 $B(1)=n+1$,这两项需要单独处理)

$$ f(n)=\sum_{j=0}^{n}j!2^j\sum_{k=0}^{j}A(j-k)B(k)\\ $$

设 $C=A*B$

$$ f(n)=\sum_{j=0}^{n}j!2^jC(j) $$

时间复杂度 $O(n\log n)$

Code

#include <bits/stdc++.h>
using namespace std;
// Size bound for the lookup tables; prepare() fills indices 0..N-1.
const int N = 100010;
// Prime modulus used for all arithmetic in this file.
const int mod = 998244353;
// Base from which dft() derives its roots of unity (3 is the usual primitive root for this modulus).
const int G = 3;

// Transform buffers: getA() fills A, getB() fills B, getC() writes their product into C.
// N << 2 entries are enough for the power-of-two length chosen in getC() as long as n < N.
int A[N << 2], B[N << 2], C[N << 2];

// Bit-reversal permutation table, built in getC() and consumed by dft().
int rev[N << 2];

// Tables filled by prepare(), all modulo mod:
// fac[i] = i!, inv[i] = inverse of i, invPro[i] = inverse of i!, bin[i] = 2^i.
int fac[N], inv[N], invPro[N], bin[N];

// The problem's input n, read in main().
int n;

// Reads the next decimal integer from standard input.
// Every non-digit character before the number is skipped; each '-' met while
// skipping flips the sign. The first character after the digits is consumed.
// Returns the signed value, or 0 if input ends before any digit is found.
int read() {
    int ans = 0, flag = 1;
    // getchar() returns int so that EOF stays distinguishable from real characters;
    // without the EOF test the skip loop would never terminate on exhausted input.
    int ch = getchar();
    while(ch != EOF && (ch > '9' || ch < '0')) {
        if(ch == '-') flag = - flag;
        ch = getchar();
    }
    while(ch >= '0' && ch <= '9') {
        ans = ans * 10 + (ch - '0');
        ch = getchar();
    }
    return ans * flag;
}

// Fills the lookup tables for every index below N, all reduced modulo mod:
//   fac[i]    = i!
//   inv[i]    = multiplicative inverse of i
//   invPro[i] = inv[1] * ... * inv[i], i.e. the inverse of i!
//   bin[i]    = 2^i
void prepare() {
    // Indices 0 and 1 are seeded by hand; the recurrences start at 2.
    fac[0] = 1; fac[1] = 1;
    invPro[0] = 1; invPro[1] = 1;
    inv[1] = 1;
    bin[0] = 1; bin[1] = 2;
    for(int k = 2; k < N; ++ k) {
        fac[k] = 1ll * fac[k - 1] * k % mod;
        bin[k] = (bin[k - 1] << 1) % mod;
        // Linear-time inverse, derived from mod = (mod / k) * k + mod % k.
        inv[k] = 1ll * (mod - mod / k) * inv[mod % k] % mod;
        invPro[k] = 1ll * invPro[k - 1] * inv[k] % mod;
    }
}

// Maps x to its representative in [0, mod).
int M(int x) {
    x %= mod;
    // The remainder of a negative value is non-positive, so shift it up once.
    if(x < 0) x += mod;
    return x;
}

// Fills A[k] = (-1)^k / k! (modulo mod) for k = 0..n.
void getA() {
    for(int k = 0; k <= n; ++ k) {
        const bool odd = (k % 2) != 0;
        // Odd indices carry the minus sign, folded back into [0, mod) by M.
        A[k] = odd ? M(- invPro[k]) : invPro[k];
    }
}

// Returns a^b modulo mod by binary exponentiation.
// b is assumed to be non-negative.
int qpow(int a, int b) {
    int result = 1;
    for(; b; b >>= 1) {
        // Multiply in the current power of a whenever the low bit of b is set.
        if(b & 1) result = 1ll * result * a % mod;
        a = 1ll * a * a % mod;
    }
    return result;
}

// Fills B[k] = (sum_{i=0..n} k^i) / k! (modulo mod) for k = 0..n.
void getB() {
    // k = 0: only the i = 0 term survives (0^0 counted as 1), so the sum is 1.
    B[0] = 1 * invPro[0] % mod;
    // k = 1: every term equals 1, so the sum is n + 1.
    B[1] = (n + 1) * invPro[1] % mod;
    // k >= 2: closed form of the geometric series, (k^(n+1) - 1) / (k - 1).
    for(int k = 2; k <= n; ++ k) {
        const int numerator = M(qpow(k, n + 1) - 1);
        B[k] = 1ll * numerator * inv[k - 1] % mod * invPro[k] % mod;
    }
}

// In-place number-theoretic transform of a[0..n) modulo mod.
// The parameter n is the transform length and shadows the global n; getC()
// passes a power of two and fills rev for that length beforehand.
// f == 1 runs the forward transform; f == -1 runs the inverse, including the
// final multiplication by the inverse of n.
void dft(int *a, int n, int f) {
    // Bit-reversal permutation; the i < j test swaps each pair only once.
    for(int i = 0; i < n; ++ i) {
        const int j = rev[i];
        if(i < j) swap(a[i], a[j]);
    }
    for(int half = 1; half < n; half <<= 1) {
        const int span = half << 1;
        // Root of unity of order `span`, inverted for the inverse transform.
        int w = qpow(G, (mod - 1) / span);
        if(f == - 1) w = qpow(w, mod - 2);
        for(int base = 0; base < n; base += span) {
            int tw = 1;
            for(int k = 0; k < half; ++ k) {
                const int lo = a[base + k];
                const int hi = 1ll * tw * a[base + half + k] % mod;
                a[base + k] = M(lo + hi);
                a[base + half + k] = M(lo - hi);
                tw = 1ll * tw * w % mod;
            }
        }
    }
    if(f != - 1) return;
    const int invLen = qpow(n, mod - 2);
    for(int i = 0; i < n; ++ i)
        a[i] = 1ll * a[i] * invLen % mod;
}

void getC() {
    int nn, l = 0;
    for(nn = 1; nn <= (n << 1); nn <<= 1) ++ l;
    for(int i = 0; i < nn; ++ i)
        rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (l - 1));
    dft(A, nn, 1); dft(B, nn, 1);
    for(int i = 0; i < nn; ++ i)
        C[i] = 1ll * A[i] * B[i] % mod;
    dft(A, nn, - 1); dft(B, nn, - 1);
    dft(C, nn, - 1);
}

// Prints sum_{j=0..n} j! * 2^j * C[j] (modulo mod), followed by a newline.
void getAns() {
    int total = 0;
    for(int j = 0; j <= n; ++ j) {
        // Each term is already reduced, so total + term stays below 2 * mod.
        const long long term = 1ll * fac[j] * bin[j] % mod * C[j] % mod;
        total = M(total + term);
    }
    printf("%d\n", total);
}

// Entry point: builds the tables, reads n, forms A and B, convolves them
// into C, and prints the answer.
int main() {
    prepare();
    n = read();
    getA();
    getB();
    getC();
    getAns();
    return 0;
}
Comments

添加新评论