Mst | BZOJ 2238

题目链接
给出 \(n\) 个点,\(m\) 条边的无向带权图,\(q\) 次询问,询问在图中删掉一条边后的 \(\text{MST}\) 的边权和。询问独立。
\(n \leq 5 \times 10^4, m \leq 10^5\)

记原图的 \(\text{MST} = (E_{\text{MST}}, V_{\text{MST}})\)

对于 \(e(u, v, w) \not \in E_{\text{MST}}\)(下文称为非树边),将它删去后显然不会对答案造成任何影响。

对于 \(e(u, v, w) \in E_{\text{MST}}\)(下文称为树边),将它删去后,为了使得点 \(u, v\) 仍然连通,我们必须要找一条非树边代替之,且这条非树边 \(e'(u', v', w')\) 所连接的顶点 \((u', v')\),在 \(\text{MST}\) 上的路径必定覆盖了 \((u, v)\)

自然的,我们想到枚举每一条非树边,并将其所连接的两个节点在 \(\text{MST}\) 上的路径中的所有树边更新。

更具体的,记 \(f_e\)(其中 \(e\) 为一条树边)为能代替 \(e\) 的非树边的最小权值。一开始 \(f_e = +\infty\)。对于枚举到的非树边 \(e'(u', v', w')\),更新所有 \(e \in E'\)(其中 \(E'\) 代表 \((u', v')\)\(\text{MST}\) 上的路径)的 \(f_e \leftarrow \min(f_e, w')\)

问题转化为如何维护这个过程。

一个经典的解法是利用树链剖分与线段树,网络上大多数的题解也是如此。不过这样做的复杂度是 \(O(n \log^2 n)\) 的,且代码长度较长。

我们采用一种码量更少,复杂度更为优秀的 \(O(n \log n)\) 算法,树上倍增来解决。

记录倍增数组 \(\text{fa}(u, k)\) 表示 \(u\)\(2^k\) 级祖先。

令标记 \(\text{tag}(u, k)\) 表示从 \(u\) 到其 \(2^k\) 级祖先的链上被更新的延时标记。易知整个算法就是要回答 \(\text{tag}(u, 0)\)

考虑倍增求 LCA 的过程,同样的,我们不断从 \(u', v'\) 向上跳,直到相遇,同时打上标记即可。

最后将标记下传,即

\[\text{tag}(u, i - 1) \leftarrow \min(\text{tag}(u, i - 1), \text{tag(u, i)})\\ \] \[\text{tag}(\text{fa}(u, i - 1), i - 1) \leftarrow \min(\text{tag}(\text{fa}(u, i - 1), i - 1), \text{tag}(u, i))\]

感性理解起来就是将 \(u\)\(\text{fa}(u, i)\) 的标记下传给上下两半。

至此,对于删除树边 \(e(u, v, w)\),其答案为:

\[ \text{MST}_w - w + \text{tag}(u, 0) \]

(这里我们假设 \(u\)\(\text{MST}\) 上的深度更深一点)。

代码实现上有一些区别。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
#include <bits/stdc++.h>
#define rep(i, l, r) for(int i = (l); i <= (r); i++)
#define per(i, r, l) for(int i = (r); i >= (l); i--)
#define upd(a, b) a = min(a, b)

using namespace std;
const int N = 5e4 + 5, MI = 1e6;
const char* Nc = "Not connected";
int n, m, q, cnt, mst, head[N], d[N], fa[N][16], tag[N][16], dwn[N*2], pr[N], w[N*2];
struct Edge {
int u, v, w, id;
bool operator <(const Edge& b)const { return w < b.w; }
} E[N*2];
struct edge { int v, nxt, id; } e[N*2];

int find(int x) { return x == pr[x] ? x : pr[x] = find(pr[x]); }
void add(int u, int v, int id) {
e[++cnt] = (edge){ v, head[u], id };
head[u] = cnt;
}
void dfs(int u) { // 求 d, fa, dwn 数组, dwn[i] 是第 i 条边的下端点
d[u] = d[fa[u][0]] + 1;
rep(i, 1, 15) fa[u][i] = fa[fa[u][i-1]][i-1];
for(int i = head[u]; i; i = e[i].nxt) {
int v = e[i].v;
if(v == fa[u][0]) continue;
fa[dwn[e[i].id] = v][0] = u;
dfs(v);
}
}
void update(int u, int v, int w) { // 倍增打标记
if(d[u] < d[v]) swap(u, v);
per(j, 15, 0) if(d[u] - (1 << j) >= d[v])
upd(tag[u][j], w), u = fa[u][j];
if(u == v) return;
per(j, 15, 0) if(fa[u][j] != fa[v][j]) {
upd(tag[u][j], w), upd(tag[v][j], w);
u = fa[u][j], v = fa[v][j];
}
upd(tag[u][0], w); upd(tag[v][0], w);
}

int main() {
scanf("%d%d", &n, &m);
rep(i, 1, m) {
int u, v; scanf("%d%d%d", &u, &v, &w[i]);
E[i] = (Edge){ u, v, w[i], i };
}
sort(E + 1, E + m + 1);
rep(i, 1, n) pr[i] = i;
rep(i, 1, m) {
int f1 = find(E[i].u), f2 = find(E[i].v);
if(f1 == f2) continue;
pr[f2] = f1; mst += E[i].w;
add(E[i].u, E[i].v, E[i].id); add(E[i].v, E[i].u, E[i].id);
}
scanf("%d", &q);
if(cnt / 2 < n - 1) { while(q--) puts(Nc); return 0; }
dfs(1);
memset(tag, 0x3f, sizeof tag);
rep(i, 1, m) if(!dwn[E[i].id]) update(E[i].u, E[i].v, E[i].w);
per(i, 15, 1) rep(j, 1, n) { // 标记下传到底
upd(tag[j][i-1], tag[j][i]);
upd(tag[fa[j][i-1]][i-1], tag[j][i]);
}
while(q--) {
int T; scanf("%d", &T);
if(!dwn[T]) printf("%d\n", mst);
else {
int ans = tag[dwn[T]][0];
if(ans == 0x3f3f3f3f) puts(Nc);
else printf("%d\n", mst + ans - w[T]);
}
}
return 0;
}