// Line-number gutter (1-72) of the web listing this source was copied from; not code.
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;

// Number of bits needed to write v in binary (0 for v == 0).
// Compile-time replacement for (int)log2(v) + 1: log2 is not a constant
// expression in standard C++, so using it for an array bound only worked as a
// GCC extension (Clang rejects the resulting file-scope VLA).
constexpr int bitLength(int v) { return v == 0 ? 0 : 1 + bitLength(v >> 1); }

// maxn: capacity for 1-based node ids.
// lg: binary-lifting levels; 2^lg > maxn, so levels 0..lg-1 can express any
// jump distance below maxn. bitLength(200050) == 18, same as before.
constexpr int maxn = 200000 + 50, lg = bitLength(maxn);

int n, m;  // n: node count (n - 1 tree edges are read); m: extra edge count

// Forward-star adjacency list: edge e (1-based) goes to ver[e], nxt[e] is the
// next edge from the same source, head[x] is x's newest edge, 0 ends a list.
// Each undirected edge is stored twice, hence the doubled capacity.
int tot, head[maxn], nxt[maxn << 1], ver[maxn << 1];

// d[v]: BFS depth from node 1 (root has depth 1, 0 means unvisited).
// f[v][j]: 2^j-th ancestor of v, 0 once the jump goes above the root.
int d[maxn], f[maxn][lg];

// Per-node difference counts; after the subtree accumulation in dfs, s[v]
// covers the tree edge between v and its parent.
ll s[maxn];

queue<int> q;  // BFS frontier
// Appends the directed edge x -> y to x's adjacency list.
// The new edge takes the next id (ids start at 1, so 0 can terminate a list),
// is linked in front of x's previous first edge, and becomes x's head.
void add(int x, int y) {
    const int edge = ++tot;
    ver[edge] = y;
    nxt[edge] = head[x];
    head[x] = edge;
}
// Breadth-first traversal from node 1 that fills the depth array d[] (root at
// depth 1) and the binary-lifting table f[][] (f[v][j] = 2^j-th ancestor of v,
// 0 past the root). A node's whole ancestor row is built the moment it is
// discovered; this is safe because its parent was discovered, and therefore
// filled in, earlier.
void bfs() {
    d[1] = 1;
    q.push(1);
    while (!q.empty()) {
        const int u = q.front();
        q.pop();
        for (int e = head[u]; e != 0; e = nxt[e]) {
            const int v = ver[e];
            if (!d[v]) {  // first time we reach v
                d[v] = d[u] + 1;
                q.push(v);
                f[v][0] = u;
                for (int j = 1; j < lg; ++j) {
                    const int halfway = f[v][j - 1];
                    f[v][j] = f[halfway][j - 1];
                }
            }
        }
    }
}
// Lowest common ancestor of x and y, using d[] and f[][] built by bfs().
// Step 1 lifts the deeper node up to the other's depth. Step 2 lifts both in
// lockstep while their ancestors still differ; they then sit just below the
// LCA, which is their common parent. Node 0 has depth 0, so jumps past the
// root are never taken in step 1.
int lca(int x, int y) {
    int deep = x, shallow = y;
    if (d[deep] < d[shallow]) std::swap(deep, shallow);

    for (int k = lg - 1; k >= 0; --k) {
        const int up = f[deep][k];
        if (d[up] >= d[shallow]) deep = up;
    }
    if (deep == shallow) return deep;

    for (int k = lg - 1; k >= 0; --k) {
        const int a = f[deep][k];
        const int b = f[shallow][k];
        if (a != b) {
            deep = a;
            shallow = b;
        }
    }
    return f[deep][0];
}
// Accumulates s[] bottom-up over the tree rooted at x. Afterwards, s[v] for
// every node v in x's subtree holds the sum of the original s values in v's
// subtree. `parent` is the node above x (0 for the root). The edge back to it
// is skipped, and s[parent] itself is never modified.
//
// Iterative on purpose: the recursive form nests once per tree level, so a
// path-shaped tree with ~2e5 nodes could overflow the call stack.
void dfs(int x, int parent) {
    // Pre-order list of (node, parent): every node appears after its parent.
    std::vector<std::pair<int, int> > order;
    std::vector<std::pair<int, int> > pending(1, std::make_pair(x, parent));
    while (!pending.empty()) {
        const std::pair<int, int> cur = pending.back();
        pending.pop_back();
        order.push_back(cur);
        for (int i = head[cur.first]; i; i = nxt[i]) {
            const int y = ver[i];
            if (y != cur.second) pending.push_back(std::make_pair(y, cur.first));
        }
    }
    // Walking the pre-order backwards visits each node after all of its
    // descendants, so a child's subtree sum is final before it is added to its
    // parent. Index 0 is x itself, which is not folded into `parent`.
    for (std::size_t k = order.size(); k-- > 1;) {
        s[order[k].second] += s[order[k].first];
    }
}
// Input: n m, then n - 1 tree edges "x y", then m extra edges "x y"
// (1-based node ids). Each extra edge (x, y) covers the tree path x..y.
// Every tree edge (v, parent(v)), for v = 2..n, contributes:
//   - m if no extra edge covers it,
//   - 1 if exactly one does,
//   - 0 otherwise.
// The total is printed. Presumably this counts the ways to delete one tree
// edge plus one extra edge so the graph falls apart; confirm against the
// problem statement.
// Returns 1 without printing if the input ends early.
int main() {
    if (scanf("%d%d", &n, &m) != 2) return 1;
    int x = 0, y = 0;
    for (int i = 1; i < n; ++i) {
        if (scanf("%d%d", &x, &y) != 2) return 1;  // truncated input
        add(x, y);
        add(y, x);
    }
    bfs();  // depths and ancestor table, rooted at node 1

    // Tree difference array: +1 at both endpoints, -2 at their LCA. Once dfs()
    // sums each subtree, s[v] is the number of extra edges whose path uses
    // the tree edge (v, parent(v)).
    for (int i = 1; i <= m; ++i) {
        if (scanf("%d%d", &x, &y) != 2) return 1;
        s[x] += 1;
        s[y] += 1;
        s[lca(x, y)] -= 2;
    }
    dfs(1, 0);

    ll ans = 0;
    for (int v = 2; v <= n; ++v) {  // each non-root node owns one tree edge
        if (s[v] == 0) ans += m;
        else if (s[v] == 1) ans += 1;
    }
    printf("%lld\n", ans);
    return 0;
}