LCA 解题技巧

练习题

模板

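以「聚会」一题为例：给定一棵树和 m 组询问，每组询问给出三个点 x、y、z，要求输出一个集合点，使 x、y、z 走到该点的距离之和最小，并输出这个距离和。关键结论：x、y、z 两两之间的三个 LCA 中至少有两个是同一个点，剩下的那个深度不小于它们；取这个最深的 LCA 作为集合点即可，此时距离和为 d[x] + d[y] + d[z] − d[lca(x,y)] − d[lca(x,z)] − d[lca(y,z)]。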

倍增（在线求 LCA：每读入一组询问立即回答）

#include <bits/stdc++.h>
using namespace std;
constexpr int maxn = 500000 + 50;

// Number of binary-lifting levels: floor(log2(v)) + 1 for v >= 1, evaluated at
// compile time. std::log2 is not constexpr in ISO C++ (before C++26), so the
// former `(int)(log2(maxn)) + 1` only worked as an array bound via a GCC
// builtin extension; this helper is portable and gives the same value (19).
constexpr int bit_levels(int v) { return v > 1 ? bit_levels(v / 2) + 1 : 1; }
constexpr int lg = bit_levels(maxn);

int n, m;
// Chained adjacency list; every undirected tree edge is added in both directions,
// hence twice as many edge slots as nodes.
int tot, head[maxn], nxt[maxn << 1], ver[maxn << 1];
// d[x]: depth of x (bfs gives the root depth 1; d[0] == 0 is the "no ancestor" sentinel).
// f[x][j]: the 2^j-th ancestor of x, or 0 when it does not exist.
int d[maxn], f[maxn][lg];
std::queue<int> q;

// Appends the directed edge x -> y to x's chained edge list. Edge ids start at 1
// so that id 0 can terminate every list.
void add(int x, int y) {
  ++tot;
  ver[tot] = y;
  nxt[tot] = head[x];
  head[x] = tot;
}

// Number of edges on the tree path between x and y; z must be lca(x, y).
int dis(int x, int y, int z) {
  const int up_from_x = d[x] - d[z];
  const int up_from_y = d[y] - d[z];
  return up_from_x + up_from_y;
}

// Breadth-first traversal from node 1 that fills depths d[] and the
// binary-lifting table f[][]. BFS reaches a parent before its children, so every
// ancestor row read while filling f[w][k] is already complete. d[] doubles as
// the visited mark: a node with d == 0 has not been reached yet.
void bfs() {
  d[1] = 1;
  q.push(1);
  while (!q.empty()) {
    const int u = q.front();
    q.pop();
    for (int e = head[u]; e != 0; e = nxt[e]) {
      const int w = ver[e];
      if (d[w] != 0) continue;  // already discovered (this includes u's parent)
      d[w] = d[u] + 1;
      f[w][0] = u;
      for (int k = 1; k < lg; ++k) {
        const int half = f[w][k - 1];  // 2^(k-1)-th ancestor
        f[w][k] = f[half][k - 1];
      }
      q.push(w);
    }
  }
}

// Lowest common ancestor of x and y by binary lifting.
// Step 1 lifts the deeper node to the depth of the shallower one; step 2 lifts
// both in lockstep while their ancestors differ, ending just below the LCA.
// Jumps past the root land on sentinel 0 (depth 0), which is never deep enough
// to be accepted in step 1 and is identical for both nodes in step 2.
int lca(int x, int y) {
  int deep = x, shallow = y;
  if (d[deep] < d[shallow]) std::swap(deep, shallow);
  for (int k = lg - 1; k >= 0; --k) {
    const int up = f[deep][k];
    if (d[up] >= d[shallow]) deep = up;
  }
  if (deep == shallow) return deep;
  for (int k = lg - 1; k >= 0; --k) {
    const int a = f[deep][k];
    const int b = f[shallow][k];
    if (a != b) {
      deep = a;
      shallow = b;
    }
  }
  return f[deep][0];
}

// Reads n, m, the n - 1 tree edges, then m queries "x y z". For each query it
// prints the chosen meeting node and the total edges walked from x, y and z.
// Of lca(x,y), lca(x,z), lca(y,z), at least two are the same node and the other
// is at least as deep; the deepest one is taken as the meeting point `pos`.
// Two of the query nodes climb straight up to pos; the third reaches pos through
// `join`, which equals lca(third node, pos).
int main() {
  scanf("%d%d", &n, &m);
  int x, y, z;
  for (int e = 1; e < n; ++e) {
    scanf("%d%d", &x, &y);
    add(x, y);
    add(y, x);
  }
  bfs();
  while (m--) {
    scanf("%d%d%d", &x, &y, &z);
    const int xy = lca(x, y), xz = lca(x, z), yz = lca(y, z);
    int pos, u, w, rest, join;
    if (d[xz] >= d[xy] && d[xz] >= d[yz]) {
      pos = xz; u = x; w = z; rest = y; join = xy;
    } else if (d[yz] >= d[xy] && d[yz] >= d[xz]) {
      pos = yz; u = y; w = z; rest = x; join = xy;
    } else {
      pos = xy; u = x; w = y; rest = z; join = xz;
    }
    const int cost = dis(u, pos, pos) + dis(w, pos, pos) + dis(pos, rest, join);
    printf("%d %d\n", pos, cost);
  }
  return 0;
}

Tarjan（离线求 LCA：先读入全部询问，再通过一次 DFS 统一求出）

#include <bits/stdc++.h>
using namespace std;
constexpr int maxn = 500000 + 50;
int n, m;
// Chained adjacency list; each undirected tree edge is stored in both directions.
int tot, head[maxn], nxt[maxn << 1], ver[maxn << 1];
// The three nodes of query i (1-based) as read from input.
int a[maxn], b[maxn], c[maxn];
// Pairwise LCA answers, stored in thirds:
// lca[i] = lca(x, y), lca[i + m] = lca(x, z), lca[i + 2 * m] = lca(y, z).
int lca[3 * maxn];
// Offline query lists, same chained layout as the edges: entry e in the list of
// node u (starting at hq[u]) means "write lca(u, vq[e]) into lca[idx[e]]".
// Each pair is linked from both endpoints: 3 pairs * 2 = 6 slots per query.
int tq, hq[maxn], nq[6 * maxn], vq[6 * maxn], idx[6 * maxn];
// fa: union-find parent; v: visit state (0 unvisited, 1 in progress, 2 finished);
// d: depth.
int fa[maxn], v[maxn], d[maxn];

// Appends the directed edge x -> y to x's chained edge list (ids start at 1;
// 0 terminates a list).
void add(int x, int y) {
  const int e = ++tot;
  ver[e] = y;
  nxt[e] = head[x];
  head[x] = e;
}

void add_q(int x, int y, int i) {
// 特判一下,因为v[x] == v[y] == 1时,不会更新lca
if (x == y) {
lca[i] = x;
return;
}
vq[++tq] = y, nq[tq] = hq[x], idx[tq] = i, hq[x] = tq;
vq[++tq] = x, nq[tq] = hq[y], idx[tq] = i, hq[y] = tq;
}

// Edge count of the tree path x -> y, where z must be lca(x, y).
int dis(int x, int y, int z) { return (d[x] - d[z]) + (d[y] - d[z]); }

// Union-find root lookup with full path compression, written iteratively.
// In tarjan, fa[] chains can be as long as the tree is deep (up to ~n nodes on
// a path-shaped tree), so the recursive form `fa[x] = get(fa[x])` could overflow
// the call stack. Result and final fa[] state are unchanged: every node on the
// walked path ends up pointing directly at the root.
int get(int x) {
  int root = x;
  while (fa[root] != root) root = fa[root];
  while (fa[x] != root) {
    const int up = fa[x];
    fa[x] = root;
    x = up;
  }
  return root;
}

void tarjan(int x) {
v[x] = 1;
for (int i = head[x]; i; i = nxt[i]) {
int y = ver[i];
if (v[y]) continue;
d[y] = d[x] + 1;
tarjan(y);
fa[y] = x;
}
for (int i = hq[x]; i; i = nq[i]) {
int y = vq[i];
if (v[y] == 2) lca[idx[i]] = get(y);
}
v[x] = 2;
}

// Reads the tree and all m queries "x y z" first, answers the 3 * m pairwise LCA
// queries in a single offline Tarjan pass from node 1, then prints for every
// query the chosen meeting node (the deepest of its three pairwise LCAs) and
// the total number of edges walked from x, y and z to that node.
int main() {
  scanf("%d%d", &n, &m);
  int x, y, z;
  for (int e = 1; e < n; ++e) {
    scanf("%d%d", &x, &y);
    add(x, y);
    add(y, x);
  }
  for (int u = 1; u <= n; ++u) fa[u] = u;  // every node starts as its own set
  for (int i = 1; i <= m; ++i) {
    scanf("%d%d%d", &x, &y, &z);
    a[i] = x;
    b[i] = y;
    c[i] = z;
    add_q(x, y, i);
    add_q(x, z, i + m);
    add_q(y, z, i + 2 * m);
  }
  tarjan(1);
  for (int i = 1; i <= m; ++i) {
    const int xy = lca[i], xz = lca[i + m], yz = lca[i + 2 * m];
    // s and t climb straight to pos; rest reaches pos via join == lca(rest, pos).
    int pos, s, t, rest, join;
    if (d[xz] >= d[xy] && d[xz] >= d[yz]) {
      pos = xz; s = a[i]; t = c[i]; rest = b[i]; join = xy;
    } else if (d[yz] >= d[xy] && d[yz] >= d[xz]) {
      pos = yz; s = b[i]; t = c[i]; rest = a[i]; join = xy;
    } else {
      pos = xy; s = a[i]; t = b[i]; rest = c[i]; join = xz;
    }
    const int cost = dis(s, pos, pos) + dis(t, pos, pos) + dis(pos, rest, join);
    printf("%d %d\n", pos, cost);
  }
  return 0;
}