1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;

// floor(log2(v)) for v >= 1, evaluated at compile time. std::log2 is not
// constexpr in ISO C++, so the former `(int)(log2(maxn)) + 1` array bound only
// compiled through a GCC extension and is rejected by other compilers.
constexpr int floor_log2(int v) { return v < 2 ? 0 : 1 + floor_log2(v / 2); }

// maxn: capacity for vertices and edges (both are indexed from 1).
// lg:   number of binary-lifting levels; 2^lg > maxn, so every depth
//       difference in a tree with fewer than maxn vertices fits in lg bits.
const int maxn = 3e5 + 50, lg = floor_log2(maxn) + 1;
static_assert((1 << lg) > maxn, "lg levels cannot cover every tree depth");

queue<int> q;  // work queue for the BFS over the spanning tree

// n, m: vertex / edge counts.
// Forward-star adjacency list of the spanning tree: head[x] is the first edge
// id leaving x, nxt[id] the next one (0 ends the list), ver[id] the target
// vertex, edge[id] the weight; tot is the last edge id handed out.
int n, m, tot, head[maxn], nxt[maxn << 1], ver[maxn << 1], edge[maxn << 1];

// fa: union-find parent; d: BFS depth (0 = not visited);
// f[x][j]: 2^j-th ancestor of x (0 past the root);
// mx[x][j]: {max, strict second max} weight over the 2^j edges above x
//           (the second slot holds a sentinel when there is no such weight).
int fa[maxn], d[maxn], f[maxn][lg], mx[maxn][lg][2];
// One input edge between vertices x and y with weight z. Edges compare by
// weight only, so sort() puts them in the order Kruskal's algorithm consumes.
struct node {
    int x, y, z;
    bool operator<(const node &rhs) const { return z < rhs.z; }
};
node e[maxn];  // input edges, stored from index 1
bool v[maxn];  // v[i] is set when sorted edge e[i] is taken into the spanning tree
// Appends the directed edge x -> y with weight z to x's forward-star list.
// Edge ids start at 1, so id 0 can terminate a list.
void add(int x, int y, int z) {
    const int id = ++tot;
    ver[id] = y;
    edge[id] = z;
    nxt[id] = head[x];  // link in front of x's current first edge
    head[x] = id;
}
// Combines two {max, strict second max} pairs a and b into c, so that c holds
// the largest value of the union and the largest value strictly below it.
// All reads happen before any write, so c may alias a (or b).
void merge(int c[2], int a[2], int b[2]) {
    const int top = max(a[0], b[0]);
    int runner_up;
    if (a[0] > b[0]) {
        runner_up = max(a[1], b[0]);  // b's max is now strictly below the top
    } else if (a[0] < b[0]) {
        runner_up = max(a[0], b[1]);  // a's max is now strictly below the top
    } else {
        runner_up = max(a[1], b[1]);  // tie at the top: only the seconds compete
    }
    c[0] = top;
    c[1] = runner_up;
}
void bfs() { d[1] = 1; q.push(1); while (!q.empty()) { int x = q.front(); q.pop(); for (int i = head[x]; i; i = nxt[i]) { int y = ver[i]; if (d[y]) continue; d[y] = d[x] + 1; q.push(y); f[y][0] = x; mx[y][0][0] = edge[i]; for (int j = 1; j < lg; ++j) { f[y][j] = f[f[y][j - 1]][j - 1]; merge(mx[y][j], mx[y][j - 1], mx[f[y][j - 1]][j - 1]); } } } }
int lca(int ans[2], int x, int y) { ans[0] = ans[1] = 0; if (d[x] < d[y]) swap(x, y); for (int i = lg - 1; i >= 0; --i) { if (d[f[x][i]] >= d[y]) { merge(ans, ans, mx[x][i]); x = f[x][i]; } } if (x == y) return x; for (int i = lg - 1; i >= 0; --i) { if (f[x][i] != f[y][i]) { merge(ans, ans, mx[x][i]); merge(ans, ans, mx[y][i]); x = f[x][i], y = f[y][i]; } } merge(ans, ans, mx[x][0]); merge(ans, ans, mx[y][0]); return f[x][0]; }
int get(int x) { return x == fa[x] ? x : (fa[x] = get(fa[x])); }
int main() { cin >> n >> m; int x, y, z; for (int i = 1; i <= m; ++i) { cin >> x >> y >> z; e[i].x = x, e[i].y = y, e[i].z = z; } sort(e + 1, e + 1 + m); for (int i = 1; i <= n; ++i) fa[i] = i; ll mst = 0; for (int i = 1; i <= m; ++i) { x = get(e[i].x), y = get(e[i].y); if (x == y) continue; mst += e[i].z; fa[x] = y; v[i] = 1; add(e[i].x, e[i].y, e[i].z), add(e[i].y, e[i].x, e[i].z); } bfs(); int delta = INT_MAX; for (int i = 1; i <= m; ++i) { if (v[i] || e[i].x == e[i].y) continue; int t[2]; lca(t, e[i].x, e[i].y); if (e[i].z > t[0]) { delta = min(delta, e[i].z - t[0]); } else if (t[1]) { delta = min(delta, e[i].z - t[1]); } } cout << mst + delta << endl; return 0; }
|