Chlience

【题解】LOJ 2496. 「AHOI / HNOI2018」毒瘤
Problem给定 $n$ 个点 $m$ 条边的联通图,求有多少种不同的选择方案使得两两点间不相邻$n-1\leq...
扫描右侧二维码阅读全文
12
2018/12

【题解】LOJ 2496. 「AHOI / HNOI2018」毒瘤

Problem

给定 $n$ 个点 $m$ 条边的联通图,求有多少种不同的选择方案使得两两点间不相邻

$n-1\leq m\leq n+10$

Solution

因为这个 $n,m$ 实在是太接近了,估计有些什么蛇皮操作

如果是一棵树的话那就 $DP$ 吧,设 $f[i][0]$ 为没有限制的总方案, $f[i][1]$ 为不选 $i$ 号节点的总方案

$f[i][0] = \prod f[son[i]][0]$
$f[i][1]=\prod f[son[i]][1]$

每加入一条边可能从一个点连向某个祖先,或者是连向另外的子树,需要考虑这两种情况

先提供一个很 $naive$ 的想法:枚举边所连接的 $2*(m-n+1)$ 个点,然后暴力更新
具体更新方法为枚举 $2^{2(m-n+1)}$ 种排列,每种排列可以 $O(n)$ 解决
时间复杂度 $O(2^{2(m-n+1)}n)$ 可以通过 $55pt$ 的数据(话说这个暴力分咋这么低?)

然鹅,当 $m-n+1=11$ 时就会导致枚举部分复杂度太高,所以挂掉了 qwq

正解是怎样的呢?

因为每次 $DP$ 复杂度实在不够优秀,并且其实大部分的点都是没有啥用的,所以考虑使用虚树来降低每次 $DP​$ 的复杂度

显然建树不必多说,利用单调栈即可

接下来考虑两个节点之间的转移:

显然两个相邻节点之间只有四种状态 $(0,0),(1,0),(0,1),(1,1)$ 因为要将点缩起来所以上面选点的状态会影响下面锁点后的状态的最上面节点选 $or$ 不选

设 $k0[i][0]$ 表示第 $i$ 个节点(特殊)最上面的节点 不选,当前特殊节点 不选 的方案数
设 $k1[i][0]$ 表示第 $i$ 个节点(特殊)最上面的节点 ,当前特殊节点 不选 的方案数
设 $k0[i][1]$ 表示第 $i$ 个节点(特殊)最上面的节点 不选,当前特殊节点 的方案数
设 $k1[i][1]$ 表示第 $i$ 个节点(特殊)最上面的节点 ,当前特殊节点 的方案数

那么就将原图中仅仅只有根节点为关键节点的地方往上缩
关键节点上面的部分也往下缩

对于非关键点会有转移

$f[i][0]=\prod (f[son[i]][1]+f[son[i]][0])$
$f[i][1]=\prod f[son[i]][0]$

对于关键点,下面的点可以直接合并到 $f$ 上
然后上面的点合并到 $k$ 里面

具体合并方法,可以一步一步往上爬,每次假设上面就是某个关键点

注意若这个点有其他子树,里面肯定是没有关键点的(否则这个点作为 $lca​$ 也应该是关键点)
那么可以直接用乘法合并一下

第一步因为最上面现在就是自己,那么 $k0[x][0]=1,k1[x][1]=1,k0[x][1]=0,k1[x][0]=0$
接下来继续合并上方一个点,考虑计算出当前点除了这个子树外的贡献 $f[i][0],f[i][1]$

那么有

$k0[x][0]=(k0[x][0]+k1[x][0])*f[i][0]$
$k0[x][1]=(k0[x][1]+k1[x][1])*f[i][0]$
$k1[x][0]=k0[x][0]*f[i][1]$
$k1[x][1]=k0[x][1]*f[i][1]​$

搞出来美滋滋

最终一样的枚举状态 $DP$ 即可

虚树上 $DP$ 方程为:

$f0 = k1[to][0] * g[to][0] + k1[to][1] * g[to][1]$ 代表该儿子最上面点不选
$f1 = k0[to][0] * g[to][0] + k0[to][1] * g[to][1]$ 代表该儿子最上面点选

那么 $g[x][0] = g[x][0] * (f0 + f1),g[x][1] = g[x][1] * f1$

完美解决~撒花~

Code

#include <bits/stdc++.h>
using namespace std;
int read() {
    int ans = 0 , flag = 1;
    char ch = getchar();
    while(ch > '9' || ch < '0') {if(ch == '-') flag = - flag; ch = getchar();}
    while(ch >= '0' && ch <= '9') {ans = ans * 10 + ch - '0'; ch = getchar();}
    return ans * flag;
}
typedef long long LL;
const int N = 100020;
const LL mod = 998244353;
struct edge {
    int t[N << 1] , n[N << 1];
    bool ban[N << 1];
    int head[N] , tot;
    void addedge(int u , int v) {
        ++ tot;
        t[tot] = v;
        n[tot] = head[u];
        head[u] = tot;
    }
}t1 , t2;
int n , m;
bool vis[N];
int dfn[N] , DFN;
int fa[N][23] , dep[N];

int U[23] , V[23] , num;
int p[23] , NUM;
bool in[N];//key point;
int sta[N] , top;//key sta

LL k0[N][2] , k1[N][2];
LL f[N][2] , g[N][2];
LL bin[23];
int mu[N];

void dfs(int x , int ff) {
    fa[x][0] = ff; dep[x] = dep[ff] + 1;
    dfn[x] = ++ DFN;
    vis[x] = 1;
    for(int i = 1 ; i <= 20 ; ++ i)
        fa[x][i] = fa[fa[x][i - 1]][i - 1];
    for(int i = t1.head[x] ; i ; i = t1.n[i])
        if(t1.t[i] != ff && !t1.ban[i]) {
            if(!vis[t1.t[i]])
                dfs(t1.t[i] , x);
            else {
                ++ num;
                U[num] = x;
                V[num] = t1.t[i];
                t1.ban[i] = t1.ban[i ^ 1] = 1;
            }
        }
        
}
bool cmp(int x , int y) {return dfn[x] < dfn[y];}
int lca(int x , int y) {
    if(dep[x] < dep[y]) swap(x , y);
    for(int i = 20 ; i >= 0 ; -- i)
        if(dep[fa[x][i]] >= dep[y])
            x = fa[x][i];
    if(x == y) return x;
    for(int i = 20 ; i >= 0 ; -- i)
        if(fa[x][i] != fa[y][i]) {
            x = fa[x][i];
            y = fa[y][i];
        }
    return fa[x][0];
}
void build() {
    for(int i = 1 ; i <= num ; ++ i) {
        if(!in[U[i]]) {in[U[i]] = 1; p[++ NUM] = U[i];}
        if(!in[V[i]]) {in[V[i]] = 1; p[++ NUM] = V[i];}
    }
    sort(p + 1 , p + NUM + 1 , cmp);
    in[1] = 1; sta[++ top] = 1;
    for(int i = 1 ; i <= NUM ; ++ i) {
        if(p[i] == 1) continue;
        int l = lca(p[i] , sta[top]);
        if(l != sta[top]) {
            while(dfn[l] < dfn[sta[top - 1]]) {
                t2.addedge(sta[top - 1] , sta[top]);
                sta[top --] = 0;
            }
            t2.addedge(l , sta[top]);
            sta[top --] = 0;
            if(sta[top] != l) {
                sta[++ top] = l;
                in[l] = 1;
            }
        }
        sta[++ top] = p[i];
    }
    while(top > 1) {
        t2.addedge(sta[top - 1] , sta[top]);
        sta[top --] = 0;
    }
}

void getf(int x , int disable) {
    f[x][0] = f[x][1] = 1; in[x] = 1;
    for(int i = t1.head[x] ; i ; i = t1.n[i]) {
        int to = t1.t[i];
        if(in[to] || to == fa[x][0] || t1.ban[i] || to == disable) continue;
        getf(to , disable);
        f[x][0] = f[x][0] * (f[to][0] + f[to][1]) % mod;
        f[x][1] = f[x][1] * f[to][0] % mod;
    }
}
void getk(int x , int top) {
    k0[x][0] = k1[x][1] = 1;
    for(int i = x ; fa[i][0] != top ; i = fa[i][0]) {
        getf(fa[i][0] , i);
        LL k00 = k0[x][0] , k01 = k0[x][1];
        k0[x][0] = (k0[x][0] + k1[x][0]) * f[fa[i][0]][0] % mod;
        k0[x][1] = (k0[x][1] + k1[x][1]) * f[fa[i][0]][0] % mod;
        k1[x][0] = k00 * f[fa[i][0]][1] % mod;
        k1[x][1] = k01 * f[fa[i][0]][1] % mod;
    }
}
void prepare(int x) {
    for(int i = t2.head[x] ; i ; i = t2.n[i]) {
        prepare(t2.t[i]);
        getk(t2.t[i] , x);
        //merge t2.t[i] to x;
        //at the same time, please mark the point
    }
    //cal the in[x] != 1 part
    f[x][1] = f[x][0] = 1;
    for(int i = t1.head[x] ; i ; i = t1.n[i]) {
        int to = t1.t[i];
        if(in[to] || to == fa[x][0] || t1.ban[i]) continue;
        getf(to , 0);
        f[x][0] = f[x][0] * (f[to][0] + f[to][1]) % mod;
        f[x][1] = f[x][1] * f[to][0] % mod;
    }
}
void dp(int x) {
    g[x][0] = f[x][0]; g[x][1] = f[x][1];
    for(int i = t2.head[x] ; i ; i = t2.n[i]) {
        int to = t2.t[i];
        dp(to);
        LL f0 = (k1[to][0] * g[to][0] % mod + k1[to][1] * g[to][1] % mod) % mod;
        LL f1 = (k0[to][0] * g[to][0] % mod + k0[to][1] * g[to][1] % mod) % mod;
        g[x][0] = g[x][0] * (f0 + f1) % mod;
        g[x][1] = g[x][1] * f1 % mod;
    }
    if(mu[x] == 1) g[x][0] = 0;
    if(mu[x] == - 1) g[x][1] = 0;
}
int main() {
    n = read(); m = read(); t1.tot = 1;
    for(int i = 1 ; i <= m ; ++ i) {
        int u = read() , v = read();
        t1.addedge(u , v); t1.addedge(v , u);
    }
    dfs(1 , 0);
    build();
    prepare(1);
    LL ans = 0;
    bin[1] = 1;
    for(int i = 2 ; i <= NUM + 1 ; ++ i) bin[i] = bin[i - 1] << 1;
    for(int i = 0 ; i < bin[NUM + 1] ; ++ i) {
        for(int j = 1 ; j <= NUM ; ++ j) {
            if(i & bin[j]) mu[p[j]] = 1;
            else mu[p[j]] = - 1;
        }
        int flag = 0;
        for(int j = 1 ; j <= num ; ++ j) {
            if(mu[U[j]] == 1 && mu[V[j]] == 1) {
                flag = 1;
                break;
            }
        }
        if(flag) continue;
        dp(1);
        ans = (ans + g[1][0] + g[1][1]) % mod;
    }
    printf("%lld\n" , ans);
    return 0;
}
Last modification:February 11th, 2019 at 08:36 am

Leave a Comment