P1429. 平面最近点对

考点

  • 分治

题解

见思路

思路

分治的经典教学题,OIWiki已经讲得很好了,见链接;其中的复杂度证明部分一定要理解!

下面讲一下编程细节:

sort实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
#include <bits/stdc++.h>
using namespace std;
const int maxn = 2e5 + 50;
int n;

struct node {
double x_, y_;
} arr[maxn], t[maxn];

bool cmpx(node &a, node &b) {
return a.x_ < b.x_ || (a.x_ == b.x_ && a.y_ < b.y_);
}

bool cmpy(node &a, node &b) {
return a.y_ < b.y_ || (a.y_ == b.y_ && a.x_ < b.x_);
}

double calc(node &a, node &b) {
return sqrt((a.x_ - b.x_) * (a.x_ - b.x_) + (a.y_ - b.y_) * (a.y_ - b.y_));
}

double merge(int l, int r) {
if (l == r) return INT_MAX;
int cnt = 0, mid = (l + r) / 2;
double mi = min(merge(l, mid), merge(mid + 1, r));
for (int i = l; i <= r; ++i) {
if (fabs(arr[i].x_ - arr[mid].x_) < mi) t[++cnt] = arr[i];
}
// 一定要按y轴排序!
sort(t + 1, t + 1 + cnt, cmpy);
for (int i = 1; i < cnt; ++i) {
for (int j = i + 1; j <= cnt && fabs(t[i].y_ - t[j].y_) < mi; ++j) {
mi = min(mi, calc(t[i], t[j]));
}
}
return mi;
}

int main() {
cin >> n;
for (int i = 1; i <= n; ++i) cin >> arr[i].x_ >> arr[i].y_;
sort(arr + 1, arr + 1 + n, cmpx);
printf("%.4lf", merge(1, n));
}

重点讲一下为什么必须要让新数组t对y轴排序,这也是这部分代码为什么不是平方数量级的原因。

1
2
3
4
5
for (int i = 1; i < cnt; ++i) {
for (int j = i + 1; j <= cnt && fabs(t[i].y_ - t[j].y_) < mi; ++j) {
mi = min(mi, calc(t[i], t[j]));
}
}

因为x、y轴方向上都满足的点个数总是常数,其余再多的点都是白瞎,你要想办法去除它们的干扰。

只要在范围内的就判断距离,不在范围内的直接进入下一次循环就好了。

为了满足这种单调性,你必须对y轴排序;乱序你还怎么筛呢?

归并实现

归并实现的时间复杂度为\(O\left( n\log n \right)\),上面sort实现的时间复杂度为\(O\left( n\log ^2n \right)\)

因为sort实现方法每次都调用了sort函数排序,归并实现时主体就是归并,不需要额外排序。

要注意!sort实现时,我们是新开了一个数组t来保存新的集合,并不会对原数组造成影响;

归并实现时,就是对原数组进行归并排序,原先的中值mid位置就乱了!

所以在排序前,一定要先保存好原始的中值

1
2
3
4
5
pivot = arr[mid].x_
···
for (int i = l; i <= r; ++i) {
if (fabs(arr[i].x_ - pivot) < mi) t[++cnt] = arr[i];
}

这样才不会影响后续对x轴方向集合的判断!

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
#include <bits/stdc++.h>
using namespace std;
const int maxn = 2e5 + 50;
int n;

struct node {
double x_, y_;
} arr[maxn], t[maxn], tt[maxn];

double calc(node &a, node &b) {
return sqrt((a.x_ - b.x_) * (a.x_ - b.x_) + (a.y_ - b.y_) * (a.y_ - b.y_));
}

void merge_sort(int l, int mid, int r) {
int i = l, j = l, k = mid + 1;
while (j <= mid && k <= r)
tt[i++] = (arr[j].y_ < arr[k].y_) ? arr[j++] : arr[k++];
while (j <= mid) tt[i++] = arr[j++];
while (k <= r) tt[i++] = arr[k++];
for (i = l; i <= r; ++i) arr[i] = tt[i];
}

double divide(int l, int r) {
if (l == r) return INT_MAX;
// pivot保存原本的中间x值
int cnt = 0, mid = (l + r) / 2, pivot = arr[mid].x_;
double mi = min(divide(l, mid), divide(mid + 1, r));
// y轴归并排序
merge_sort(l, mid, r);
for (int i = l; i <= r; ++i) {
if (fabs(arr[i].x_ - pivot) < mi) t[++cnt] = arr[i];
}
for (int i = 1; i < cnt; ++i) {
for (int j = i + 1; j <= cnt && fabs(t[i].y_ - t[j].y_) < mi; ++j) {
mi = min(mi, calc(t[i], t[j]));
}
}
return mi;
}

int main() {
cin >> n;
for (int i = 1; i <= n; ++i) cin >> arr[i].x_ >> arr[i].y_;
sort(arr + 1, arr + 1 + n, [](node &a, node &b) -> bool {
return a.x_ < b.x_ || (a.x_ == b.x_ && a.y_ < b.y_);
});
printf("%.4lf", divide(1, n));
}