446. 等差数列划分 II - 子序列

考点

  • 线性DP

思路

1. 问题描述

给定一个整数数组 \[ \texttt{nums} = [a_0, a_1, \dots, a_{n-1}] \] 需要统计其中 长度至少为 3 的等差子序列 的个数。

注意:

  • 子序列不要求连续;
  • 等差数列要求相邻元素差值相等;
  • \(n \le 1000\),元素范围为 32 位有符号整数。

2. 问题建模思路

2.1 为什么不能直接枚举?

若直接枚举所有子序列:

  • 子序列数量为 \(2^n\),完全不可行;
  • 即使只枚举三元组或更长序列,也无法有效保证等差性。

因此需要利用 动态规划,逐步“构造”等差子序列。


2.2 建模的关键观察

设某个等差子序列的最后两个元素下标为 \(j < i\),差值为 \[ d = a_i - a_j \] 若此前存在一个\(a_j\) 结尾、差值同为 \(d\) 的等差子序列,那么把 \(a_i\) 接在其末尾后,仍然是等差子序列。

因此,等差子序列可以通过“按差值分类的状态转移”来构造。


3. 状态设计

3.1 状态定义

定义状态: \[ f[i][d] = \text{以 } a_i \text{ 结尾,公差为 } d \text{ 的等差子序列个数(长度 ≥ 2)} \] 说明:

  • 状态只统计 长度 ≥ 2 的等差子序列;
  • 长度为 2 的序列是后续构造合法答案(长度 ≥ 3)的“种子”;
  • 最终答案并不直接来自 \(f\),而是在状态转移过程中累加。

实现上,使用:

1
vector<unordered_map<long long, long long>> f;

其中:

  • 外层下标是结尾位置 i
  • 内层 unordered_map 的 key 是差值 d,value 是计数。

3.2 为什么不直接存长度 ≥ 3?

因为:

  • 长度为 3 的等差子序列必然由“长度为 2 的序列 + 一个新元素”得到;
  • 若只存长度 ≥ 3,会丢失扩展所需的中间状态;
  • 统一存“长度 ≥ 2”,转移时自然区分是否计入答案。

4. 状态转移方程

4.1 枚举转移来源

固定右端点 \(i\),枚举左端点 \(j < i\)\[ d = a_i - a_j \] 若此前存在状态 \(f[j][d]\),表示:

  • \(f[j][d]\) 个等差子序列(长度 ≥ 2)
  • 它们都可以接上 \(a_i\)

4.2 转移逻辑

  1. 答案累加

    原来以 \(a_j\) 结尾、长度 ≥ 2 的序列, 接上 \(a_i\) 后长度 ≥ 3,全部是合法答案\[ \text{ans} \;+=\; f[j][d] \]

  2. 状态更新

    新的以 \(a_i\) 结尾、公差为 \(d\) 的序列包括:

    • \(f[j][d]\) 扩展而来的 \(f[j][d]\) 个;
    • 新生成的长度为 2 的序列 \((a_j, a_i)\) 一个。

    因此: \[ f[i][d] \;+=\; f[j][d] + 1 \]


4.3 完整状态转移方程

\[ \begin{aligned} \text{cnt} &= f[j][d] \quad (\text{若不存在则为 } 0) \\ \text{ans} &+= \text{cnt} \\ f[i][d] &+= \text{cnt} + 1 \end{aligned} \]


5. 边界与实现细节

5.1 差值溢出问题

数组元素是 int,但差值可能超过 int 范围,因此必须使用: \[ d = (long\ long)a_i - (long\ long)a_j \] 否则会触发 有符号整数溢出(Undefined Behavior)


5.2 时间与空间复杂度

  • 时间复杂度: \[ O(n^2) \] 枚举所有 \((j, i)\) 对。

  • 空间复杂度: \[ O(n^2) \] 最坏情况下,每个位置可能存 \(O(n)\) 个不同差值。


6. AC代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
class Solution {
public:
int numberOfArithmeticSlices(vector<int>& nums) {
int n = nums.size();
long long res = 0;

// 长度小于 3,不可能存在合法等差子序列
if (n < 3)
return res;

// f[i][d]: 以 nums[i] 结尾、公差为 d 的等差子序列个数(长度 >= 2)
vector<unordered_map<long long, long long>> f(n);

long long d, cnt;

// 枚举右端点 i
for (int i = 0; i < n; ++i) {
// 枚举左端点 j
for (int j = i - 1; j >= 0; --j) {
// 必须在 long long 范围内计算差值,避免 int 溢出
d = 1LL * nums[i] - nums[j];
cnt = 0;

// 查找是否存在以 j 结尾、差值为 d 的等差子序列
auto it = f[j].find(d);
if (it != f[j].end())
cnt = it->second;

// 扩展后的序列长度 >= 3,计入答案
res += cnt;

// 更新以 i 结尾的状态:
// 1) 从 f[j][d] 扩展来的 cnt 个
// 2) 新生成的长度为 2 的序列 (nums[j], nums[i])
f[i][d] += cnt + 1;
}
}

return res;
}
};