P5094. MooFest G

考点

  • cdq分治
  • 树状数组

题解

见思路

思路

树状数组

先对v排序。

这样一来,当前元素i与其之前的所有元素j相比较,都会取vi

若xi大于xj,取xi - xj;若xi小于等于xj,取xj - xi;所以要分类讨论不同ji的贡献。

假设xj中小于等于xi的个数为cnt1,大于xi的个数为cnt2;

xj中小于等于xi的数字之和为sum1,大于xi的数字之和为sum2

显然,ji的贡献有如下等式:

  • xj中小于等于xi的部分

    cnt1 * xi - sum1

  • xj中大于xi的部分

    sum2 - cnt2 * xi

两部分相加即可。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn = 5e4 + 50;
// sum:统计和的树状数组;cnt:统计个数的树状数组
ll n, ans, sum[maxn], cnt[maxn];

int lowbit(int x) { return x & -x; }

void add(ll bit[], int x, int v) {
while (x <= maxn) bit[x] += v, x += lowbit(x);
}

ll query(ll bit[], int x) {
ll res = 0;
while (x > 0) res += bit[x], x -= lowbit(x);
return res;
}

struct node {
int v_, x_;
bool operator<(node &b) { return v_ < b.v_ || (v_ == b.v_ && x_ < b.x_); }
} a[maxn];

int main() {
cin >> n;
for (int i = 1; i <= n; ++i) cin >> a[i].v_ >> a[i].x_;
sort(a + 1, a + 1 + n);
// tot不能改成query(sum, maxn),会TLE
ll tot = 0;
for (int i = 1; i <= n; ++i) {
// cnt1:小于等于我的个数;cnt2:大于我的个数
ll cnt1 = query(cnt, a[i].x_), cnt2 = i - 1 - cnt1;
// sum1:小于等于我的数字之和,sum2:大于我的数字之和
ll sum1 = query(sum, a[i].x_), sum2 = tot - sum1;
ans += a[i].v_ * (cnt1 * a[i].x_ - sum1);
ans += a[i].v_ * (sum2 - cnt2 * a[i].x_);
tot += a[i].x_;
add(sum, a[i].x_, a[i].x_), add(cnt, a[i].x_, 1);
}
cout << ans;
return 0;
}

cdq分治

上述思路可以发现,实际上就是将数组拆成两半,计算左半部分对右半部分的贡献,这正是cdq分治的模板题。

先对v排序,左半部分的v肯定小于等于右半部分的v,然后左右部分各执行cdq分治。

左右部分合并时,先各自对x排序,令左半部分游标为j,右半部分游标为k

每次j找到第一个x值大于k时停止,此时[l, j - 1]区间内的任意x值必小于等于k的x值,该区间和即为sum1;

sum2 = 左半部分整体和 - sum1,左半部分整体和可以直接扫一遍[l, mid]累加得到,也才线性复杂度。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn = 5e4 + 50;
ll n, ans;

struct node {
int v_, x_;
} a[maxn];

bool cmpv(node &a, node &b) {
return a.v_ < b.v_ || (a.v_ == b.v_ && a.x_ < b.x_);
}

bool cmpx(node &a, node &b) {
return a.x_ < b.x_ || (a.x_ == b.x_ && a.v_ < b.v_);
}

void cdq(int l, int r) {
if (l == r) return;
int mid = (l + r) / 2;
cdq(l, mid), cdq(mid + 1, r);
// sum1:左半部分小于等于xk的总和;sum2:左半部分大于xk的总和
ll sum1 = 0, sum2 = 0;
// cnt1:左半部分小于等于xk的个数;cnt2:左半部分大于xk的个数
int cnt1 = 0, cnt2 = 0;
// j:左半部分游标;k:右半部分游标
int j = l, k = mid + 1;
for (int i = l; i <= mid; ++i) sum2 += a[i].x_;
while (j <= mid && k <= r) {
while (j <= mid && a[j].x_ <= a[k].x_)
sum1 += a[j].x_, sum2 -= a[j].x_, ++j;
cnt1 = j - l, cnt2 = mid - j + 1;
ans += a[k].v_ * (a[k].x_ * cnt1 - sum1);
ans += a[k].v_ * (sum2 - a[k].x_ * cnt2);
++k;
}
while (k <= r) ans += a[k].v_ * ((mid - l + 1) * a[k].x_ - sum1), ++k;
sort(a + l, a + r + 1, cmpx);
}

int main() {
cin >> n;
for (int i = 1; i <= n; ++i) cin >> a[i].v_ >> a[i].x_;
sort(a + 1, a + 1 + n, cmpv);
cdq(1, n);
cout << ans;
return 0;
}