【题解】BZOJ 3771 Triple

请注意,本文编写于 256 天前,最后修改于 256 天前,其中某些信息可能已经过时。

Promlem

传送门 >ω<

题目大意:
有$n$件物品,每件物品有一个权值$a_i$,可以用$1,2,3$个价值不同的物品组合出一个总价值,问每种总价值有多少种组成方案

Solution

既然每种价值的物品只能选一个,那么不用管每种价值有多少个,只用关心有没有就好了。作为一个组合问题,使用普通型生成函数

考虑到直接算答案比较麻烦,利用容斥进行计算

$A(i)$表示选择一件物品的生成函数

$B(i)$表示选择两件相同物品的生成函数

$C(i)$表示选择三件相同物品的生成函数

由容斥原理可得

选择一件物品的贡献:$A(i)$

选择两件物品的贡献:$\frac{A^2(i) - B(i)}{2}$

选择三件物品的贡献:$\frac{A^3(i) - 3 * A * B(i) + 2 * C[i]}{6}$

输出三者系数之和即可


这次的实现用了NTT和快速乘,感觉跑的好慢的说...

代码如下

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const ll mod = (35ll << 31) + 1 , G = 3 , N = 525000;
ll rev[N] , a[N] , b[N] , c[N] , ANS[N];
ll n , m;
inline ll mul(ll a , ll b) {
    ll d  = (ll) double(a * (double)b / mod + 0.5);
    ll ret = a * b - d * mod;
    if(ret < 0) ret += mod;
    return ret;
}
ll read() {
    ll ans = 0 , flag = 1;
    char ch = getchar();
    while(ch > '9' || ch < '0') {if(ch == '-') flag = -1; ch = getchar();}
    while(ch >= '0' && ch <= '9') {ans = ans * 10 + ch - '0'; ch = getchar();}
    return ans * flag;
}
ll qpow(ll a , ll b) {
    ll ans = 1;
    while(b) {
        if(b & 1) ans = mul(ans , a);
        a = mul(a , a);
        b >>= 1;
    }
    return ans;
}
void dft(ll *now , ll n , ll f) {
    for(ll i = 0 ; i < n ; ++ i)
        if(i < rev[i]) swap(now[i] , now[rev[i]]);
    for(ll i = 1 ; i < n ; i <<= 1) {
        ll gn = qpow(G , (mod - 1) / (i << 1));
        if(f != 1) gn = qpow(gn , mod - 2);
        for(int j = 0 ; j < n ; j += (i << 1)) {
            ll x , y , g = 1;
            for(int k = 0 ; k < i ; ++ k , g = mul(g , gn)) {
                x = now[j + k] , y = mul(now[i + j + k] , g);
                now[j + k] = (x + y) % mod;
                now[i + j + k] = (x - y + mod) % mod;
            }
        }
    }
    if(f != 1) {
        ll ny = qpow(n , mod - 2);
        for(int i = 0 ; i < n ; ++ i) now[i] = mul(now[i] , ny);
    }
}
int main() {
    n = read();
    for(ll i = 0 ; i < n ; ++ i) {
        ll w = read();
        m = max(m , w);
        a[w] = b[w * 2] = c[w * 3] = 1;
    }
    m *= 3 + 1;
    ll nn , l =0;
    for(nn = 1 ; nn < m ; nn <<= 1) ++ l;
    for(int i = 0 ; i < nn ; ++ i) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (l - 1));
    dft(a , nn , 1); dft(b , nn , 1); dft(c , nn , 1);
    ll inv2 = qpow(2 , mod - 2) , inv6 = qpow(6 , mod - 2);

    for(int i = 0 ; i < nn ; ++ i) {
        ANS[i] += a[i];
        if(ANS[i] > mod) ANS[i] -= mod;

        ANS[i] += mul((mul(a[i] , a[i]) - b[i]) , inv2);
        if(ANS[i] > mod) ANS[i] -= mod;

        ANS[i] += mul(((mul(mul(a[i] , a[i]) , a[i]) - mul(mul(3 , a[i]) , b[i]) + mul(2 , c[i])) % mod + mod) % mod , inv6);
        if(ANS[i] > mod) ANS[i] -= mod;
    }

    dft(ANS , nn , -1);
    for(int i = 0 ; i < m ; ++ i)
        if(ANS[i]) printf("%d %lld\n" , i ,ANS[i]);
    return 0;
}
Comments

添加新评论