考点
题解
见思路
思路
分治的经典教学题,OIWiki已经讲得很好了,见链接;其中的复杂度证明部分一定要理解!
下面讲一下编程细节:
sort实现
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44
| #include <bits/stdc++.h> using namespace std; const int maxn = 2e5 + 50; int n;
struct node { double x_, y_; } arr[maxn], t[maxn];
bool cmpx(node &a, node &b) { return a.x_ < b.x_ || (a.x_ == b.x_ && a.y_ < b.y_); }
bool cmpy(node &a, node &b) { return a.y_ < b.y_ || (a.y_ == b.y_ && a.x_ < b.x_); }
double calc(node &a, node &b) { return sqrt((a.x_ - b.x_) * (a.x_ - b.x_) + (a.y_ - b.y_) * (a.y_ - b.y_)); }
double merge(int l, int r) { if (l == r) return INT_MAX; int cnt = 0, mid = (l + r) / 2; double mi = min(merge(l, mid), merge(mid + 1, r)); for (int i = l; i <= r; ++i) { if (fabs(arr[i].x_ - arr[mid].x_) < mi) t[++cnt] = arr[i]; } sort(t + 1, t + 1 + cnt, cmpy); for (int i = 1; i < cnt; ++i) { for (int j = i + 1; j <= cnt && fabs(t[i].y_ - t[j].y_) < mi; ++j) { mi = min(mi, calc(t[i], t[j])); } } return mi; }
int main() { cin >> n; for (int i = 1; i <= n; ++i) cin >> arr[i].x_ >> arr[i].y_; sort(arr + 1, arr + 1 + n, cmpx); printf("%.4lf", merge(1, n)); }
|
重点讲一下为什么必须要让新数组t
对y轴排序,这也是这部分代码为什么不是平方数量级的原因。
1 2 3 4 5
| for (int i = 1; i < cnt; ++i) { for (int j = i + 1; j <= cnt && fabs(t[i].y_ - t[j].y_) < mi; ++j) { mi = min(mi, calc(t[i], t[j])); } }
|
因为x、y轴方向上都满足的点个数总是常数,其余再多的点都是白瞎,你要想办法去除它们的干扰。
只要在范围内的就判断距离,不在范围内的直接进入下一次循环就好了。
为了满足这种单调性,你必须对y轴排序;乱序你还怎么筛呢?
归并实现
归并实现的时间复杂度为\(O\left( n\log n
\right)\),上面sort实现的时间复杂度为\(O\left( n\log ^2n \right)\)
因为sort实现方法每次都调用了sort函数排序,归并实现时主体就是归并,不需要额外排序。
要注意!sort实现时,我们是新开了一个数组t
来保存新的集合,并不会对原数组造成影响;
归并实现时,就是对原数组进行归并排序,原先的中值mid
位置就乱了!
所以在排序前,一定要先保存好原始的中值:
1 2 3 4 5
| pivot = arr[mid].x_ ··· for (int i = l; i <= r; ++i) { if (fabs(arr[i].x_ - pivot) < mi) t[++cnt] = arr[i]; }
|
这样才不会影响后续对x
轴方向集合的判断!
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48
| #include <bits/stdc++.h> using namespace std; const int maxn = 2e5 + 50; int n;
struct node { double x_, y_; } arr[maxn], t[maxn], tt[maxn];
double calc(node &a, node &b) { return sqrt((a.x_ - b.x_) * (a.x_ - b.x_) + (a.y_ - b.y_) * (a.y_ - b.y_)); }
void merge_sort(int l, int mid, int r) { int i = l, j = l, k = mid + 1; while (j <= mid && k <= r) tt[i++] = (arr[j].y_ < arr[k].y_) ? arr[j++] : arr[k++]; while (j <= mid) tt[i++] = arr[j++]; while (k <= r) tt[i++] = arr[k++]; for (i = l; i <= r; ++i) arr[i] = tt[i]; }
double divide(int l, int r) { if (l == r) return INT_MAX; int cnt = 0, mid = (l + r) / 2, pivot = arr[mid].x_; double mi = min(divide(l, mid), divide(mid + 1, r)); merge_sort(l, mid, r); for (int i = l; i <= r; ++i) { if (fabs(arr[i].x_ - pivot) < mi) t[++cnt] = arr[i]; } for (int i = 1; i < cnt; ++i) { for (int j = i + 1; j <= cnt && fabs(t[i].y_ - t[j].y_) < mi; ++j) { mi = min(mi, calc(t[i], t[j])); } } return mi; }
int main() { cin >> n; for (int i = 1; i <= n; ++i) cin >> arr[i].x_ >> arr[i].y_; sort(arr + 1, arr + 1 + n, [](node &a, node &b) -> bool { return a.x_ < b.x_ || (a.x_ == b.x_ && a.y_ < b.y_); }); printf("%.4lf", divide(1, n)); }
|