410. 分割数组的最大值

考点

  • 划分DP
  • 贪心
  • 二分

思路

3599. 划分数组得到最小 XOR的姐妹题,两道题建议同时对比学习

划分DP

令:

  • \(a_1, a_2, \dots, a_n\) 为原数组;
  • \(f[j][i]\):把前 \(i\) 个数 \(a_1..a_i\) 划分成 恰好 \(j\) 时, 所有段和中的最大值的 最小可能值

记区间和: \[ \text{sum}(l, r) = a_l + a_{l+1} + \dots + a_r \] 转移写成数学形式就是: \[ f[j][r] = \min_{j \le l \le r} \max\bigl(f[j-1][l-1],\ \text{sum}(l,r)\bigr) \] 为什么只用f[j-1][l-1]呢?道理很简单

假设我当前子数组的和为5,而前面子数组的最大和分别有4、6

5与4取max得到的是5,5与6取max得到的是6,那么答案应该选择5才对

这就意味着,我们只需要保存前面子数组最大和的最小值即可,如果该最小值小于等于5,那么我们就会取5

如果该最小值大于5,那么我们就会取该最小值

可以得到如下AC代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
class Solution {
public:
static const int inf = 0x7f7f7f7f;
int splitArray(vector<int>& nums, int k) {
int n = nums.size();
int f[1005][1005];
memset(f, 0x7f, sizeof(f));
f[0][0] = 0;
for (int j = 1; j <= k; ++j) {
for (int r = j; r <= n; ++r) {
int s = 0;
for (int l = r; l >= j; --l) {
s += nums[l - 1];
if (f[j - 1][l - 1] != inf)
f[j][r] = min(f[j][r], max(f[j - 1][l - 1], s));
}
}
}
return f[k][n];
}
};

本题不同于3599. 划分数组得到最小 XOR,用二分更优

二分

设计

我们关心的答案是:

把数组分成 k 段后,所有段的「段和中的最大值」最小是多少?

设这个值为 X。有个关键单调性:

  • 如果我能用 不超过 k 段,把数组分成若干段,并且每一段的和都 ≤ X, 那么对于任何 Y >= X,也一定能做到(约束变宽松了)。
  • 如果连 X 都做不到(无论怎么切,总会有一段和 > X), 那么对于任何 Y < X,更不可能做到。

所以「能否在 ≤k 段内,使每段和 ≤ mid」这个条件对 mid单调的,可以二分。


判定

判定函数:给定 mid,能否在 ≤k 段内满足「每段和 ≤ mid」?

贪心做法:

  1. 从左到右扫数组,用 cur 记录当前这一段的和,用 cnt 记录段数(初始 1 段)。
  2. 每来一个 x
    • cur + x <= mid:继续放在当前段:cur += x
    • 否则:开新的一段:++cnt; cur = x
  3. 扫完后,看 cnt 是否 ≤ k
    • ≤ k:说明在约束 mid 下能完成拆分 → 返回 true
    • 否则说明这个 mid 太小,段数不够用 → 返回 false

注意一个细节:如果存在某个 nums[i] > mid,那么无论怎么切,都会有一段至少包含它,所以这一段的和一定 > mid,直接 false。但我们二分的时候下界就设为 max(nums[i]),这样自然不会出现这种情况。


二分范围

  • 下界 L = max(nums[i]):最小也得容得下单个元素;
  • 上界 R = sum(nums[i]):最大就是不切分,整个数组一段。

[L, R] 上二分最小的可行值。

可以得到如下代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
class Solution {
public:
const static int inf = 0x3f3f3f3f;
int splitArray(vector<int>& nums, int k) {
int l = -1, r = 0, n = nums.size();
// 二分
auto check = [&](int mid) -> bool {
int cnt = 0, s = inf, i = 1;
while (i <= n) {
if (nums[i - 1] > mid) return 0;
if (s + nums[i - 1] > mid)
++cnt, s = nums[i - 1];
else
s += nums[i - 1];
i++;
}
return cnt <= k;
};
// 二分答案
for (int i = 1; i <= n; ++i) l = max(l, nums[i - 1]), r += nums[i - 1];
while (l < r) {
int mid = (l + r) >> 1;
if (check(mid))
r = mid;
else
l = mid + 1;
}
return l;
}
};