730. 统计不同回文子序列

考点

  • 区间DP
  • 回文序列

思路


1. 问题建模

给定一个只包含 'a' ~ 'd' 的字符串 s,要求统计其中不同的回文子序列的数量,答案对 \[ \text{mod} = 10^9 + 7 \] 取模。

特点:

  • 子序列,不是子串,可以跳着选。
  • 要求不同,不能把同一个回文子序列重复计数。
  • 回文结构天然是「左右对称」,很适合用区间 DP 从短区间推到长区间。

因此,可以把问题建模成「对每个子串区间 \([i, j]\),求该区间内所有不同回文子序列的数量」。


2. 状态定义

令字符串长度为 \(n\),采用 1-based 下标。

定义状态: \[ f[i][j] = \text{区间 } s[i..j] \text{ 中,不同回文子序列的个数} \] 基本情况:

  • \(i = j\) 时,区间里只有一个字符,本身就是一个长度为 1 的回文: \[ f[i][i] = 1 \]

目标答案是: \[ f[1][n] \]


3. 辅助数组 nxt / pre 的含义

为了在转移时判定「区间内部某字符出现的次数」,需要提前预处理每个位置相同字符的前后出现位置。

在代码中:

  • nxt[i] 表示:位置 i 右边最近的一个、且与 s[i] 相同的字符位置。 若右边没有,则为 0
  • pre[i] 表示:位置 i 左边最近的一个、且与 s[i] 相同的字符位置。 若左边没有,则为 0

预处理做法:

  • 从右往左扫,更新 nxt
  • 清空计数,再从左往右扫,更新 pre

这样,在区间 DP 时,可以通过 nxt[i]pre[j] 快速定位「与两端字符相同的内部位置」。


4. 区间 DP 的遍历顺序

外层按区间长度 len = 2 ... n 枚举,内层枚举左端点 i

  • 区间右端点为 j = i + len - 1
  • 对每个 [i, j],根据首尾字符是否相同,做不同的转移。

5. 情况一:首尾字符不同

s[i - 1] != s[j - 1](注意代码是 0-based 访问),则区间 [i, j] 内的回文子序列来自两部分:

  1. 完全落在 [i+1, j] 内;
  2. 完全落在 [i, j-1] 内;

这两部分的交集,正是 [i+1, j-1] 内的回文子序列。 因此用容斥即可: \[ f[i][j] = f[i+1][j] + f[i][j-1] - f[i+1][j-1] \] 代码里再加上 + mod 然后 % mod,避免负数。


6. 情况二:首尾字符相同(核心)

s[i - 1] == s[j - 1] 时,记这个公共字符为 c

这一类区间是本题的核心,因为可以用 c 把中间的回文子序列「包一层」,产生新的回文。

6.1 先看「统一的基础贡献」:乘 2 从哪里来?

先不去区分内部有几个 c,只盯着「中间区间」:

  • 中间区间为 [i+1, j-1]
  • 其中的回文子序列数量为 f[i+1][j-1]

对于中间的每一个回文子序列 p,在整个区间 [i, j] 中,会有两种形式出现:

  1. 不用外层的 c:保留为 p
  2. 用外层的 c 包一层:变成 c + p + c

这两种一定是不同的回文(长度也不同),因此:

中间每一个回文子序列,都可以派生出两种回文形式。

于是有一个统一的「基础翻倍」贡献: \[ \text{基础贡献} = 2 \times f[i+1][j-1] \] 代码里记为:

1
long long m = 2 * f[i + 1][j - 1] % mod;

后面所有分类讨论,都是在这个 m 的基础上再 +2 / +1 / - something

6.2 如何判断内部有几个字符 c

现在只关心字符 c = s[i] = s[j] 在区间 (i, j) 内的分布情况。

利用预处理的两个数组:

  • nx = nxt[i]:从 i 向右的最近同字符位置;
  • pr = pre[j]:从 j 向左的最近同字符位置。

然后在区间内部做约束:

1
2
int l = nx <= j && nx >= i ? nx : 0;
int r = pr <= j && pr >= i ? pr : 0;

可以理解为:

  • nx 落在 [i, j] 内,就认为是区间内部和 s[i] 相同的一个位置,把它记作 l
  • pr 落在 [i, j] 内,就认为是区间内部和 s[j] 相同的一个位置,把它记作 r
  • 如果不在区间内,则视为不存在,用 0 表示。

这样,lr 可以用于大致刻画「内部有几个 c」:

  • l == 0 && r == 0l > r:内部没有 c
  • 某些组合可判断为「内部恰好 1 个 c」;
  • 其余情况可视为「内部至少 2 个 c」。

下面就是对这三种情况的详细转移。


7. 三种子情况的分类与转移

s[i-1] == s[j-1] 的前提下,代码中:

1
long long m = 2 * f[i + 1][j - 1] % mod;

是统一的基础量,然后根据 lr 分情况。

7.1 情况 A:内部没有同字符(c 未在中间出现)

判断条件:

1
if ( (l == 0 && r == 0) || l > r )

可以理解为:在 (i, j) 内没有有效的与两端相同的字符。

此时,除了前面说的「中间回文翻倍」外,外层字符 c 自己还能贡献两种新的回文:

  • 单独的 "c"
  • 两端组合成的 "cc"

这两个回文在更小的区间中都没有出现过,因此是纯新增的: \[ f[i][j] = 2 \cdot f[i+1][j-1] + 2 \] 代码对应:

1
f[i][j] = (m + 2) % mod;

7.2 情况 B:内部恰好一个同字符

判断条件:

1
else if ( (l == 0 && r) || (l && r == 0) || (l == r) )

直观理解:

  • l == rnxt[i]pre[j] 在区间内指向同一个位置,即内部只找到一个 c
  • (l == 0 && r)(l && r == 0):从左(或从右)只能找到一个有效位置,从另一个方向找不到第二个位置,也视为内部只有一个 c

在「内部恰好一个 c」时,会发生如下现象:

  • 单字符 "c" 这个回文,其实在内部那个 c 的位置已经统计过一次;
  • 两端合成的 "cc" 是新的。

也就是说,相比「没有内部 c」的情况,只多出一个新回文,而不是两个: \[ f[i][j] = 2 \cdot f[i+1][j-1] + 1 \] 代码对应:

1
f[i][j] = (m + 1) % mod;

7.3 情况 C:内部至少两个同字符

其余情况统一归到 else

1
2
3
else {
f[i][j] = (m - f[l + 1][r - 1] + mod) % mod;
}

含义是:如果在 (i, j) 内至少找到了两个不同位置的字符 c,那么在区间 [l, r] 的内部,已经存在一批被 c 包裹的回文形态,这些回文在当前区间再次用 c 包裹时会被重复计数。

具体来说:

  • 中间所有回文翻倍给出 2 * f[i+1][j-1]
  • 其中有一部分回文来源于子区间 [l+1, r-1],并且在这里被重复包裹了一遍;
  • 所以要减去 f[l+1][r-1],把那部分重复计数的回文去掉。

于是得到: \[ f[i][j] = 2 \cdot f[i+1][j-1] - f[l+1][r-1] \] 代码中配合 + mod% mod,保证结果非负。


8. 复杂度分析

  • 区间 DP 共有 \(O(n^2)\) 个状态;

  • 每个状态的转移为 \(O(1)\)

  • 因此总时间复杂度为: \[ O(n^2) \]

  • DP 数组为 f[maxn][maxn],空间复杂度为: \[ O(n^2) \]

\(n \le 10^3\) 的限制下,该算法可以通过。


9. AC代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
class Solution {
public:
static const int maxn = 1e3 + 5, mod = 1e9 + 7;

int countPalindromicSubsequences(string s) {
// f[i][j]: s[i..j] 中不同回文子序列的个数(1-based)
long long f[maxn][maxn];

int n = s.size();

// nxt[i]: 位置 i 右侧最近的同字符位置(否则为 0)
int nxt[maxn] = {};
// pre[i]: 位置 i 左侧最近的同字符位置(否则为 0)
int pre[maxn] = {};

// idx[c]: 最近一次出现字符 c 的位置(c ∈ [0,3],对应 'a'~'d')
int idx[4] = {};

// 预处理 nxt 数组:从右往左扫
for (int i = n, d; i >= 1; --i) {
d = s[i - 1] - 'a';
nxt[i] = idx[d]; // 当前 i 的下一个同字符位置
idx[d] = i; // 更新该字符最近出现位置
}

// 预处理 pre 数组:从左往右扫
memset(idx, 0, sizeof(idx));
for (int i = 1, d; i <= n; ++i) {
d = s[i - 1] - 'a';
pre[i] = idx[d]; // 当前 i 的前一个同字符位置
idx[d] = i; // 更新该字符最近出现位置
}

// 区间长度为 1 时,单个字符本身就是一个回文
for (int i = 1; i <= n; ++i)
f[i][i] = 1;

int nx, pr; // 临时变量,存 nxt[i] / pre[j]

// 按区间长度枚举
for (int len = 2; len <= n; ++len) {
// 枚举左端点 i
for (int i = 1; i + len - 1 <= n; ++i) {
int j = i + len - 1; // 右端点 j

if (s[i - 1] != s[j - 1]) {
// 首尾字符不同:用容斥
f[i][j] =
(f[i + 1][j] + f[i][j - 1] - f[i + 1][j - 1] + mod) % mod;
} else {
// 首尾字符相同:设公共字符为 c = s[i] = s[j]
nx = nxt[i]; // 从左端 i 往右的最近同字符位置
pr = pre[j]; // 从右端 j 往左的最近同字符位置

// 将 nx, pr 限制在当前区间 [i, j] 内,否则视为不存在
int l = (nx <= j && nx >= i ? nx : 0);
int r = (pr <= j && pr >= i ? pr : 0);

// 中间区间 [i+1, j-1] 的回文翻倍,形成基础贡献
long long m = 2 * f[i + 1][j - 1] % mod;

if ((l == 0 && r == 0) || l > r) {
// 情况 A:中间没有同字符 c
// 新增 "c" 和 "cc" 两个回文
f[i][j] = (m + 2) % mod;
} else if ((l == 0 && r) || (l && r == 0) || (l == r)) {
// 情况 B:中间恰好出现一次 c
// "c" 已出现过一次,只新增 "cc" 一个回文
f[i][j] = (m + 1) % mod;
} else {
// 情况 C:中间至少出现两次 c
// [l+1, r-1] 内的某些被 c 包裹的回文在当前区间会重复,需要减掉
f[i][j] = (m - f[l + 1][r - 1] + mod) % mod;
}
}
}
}

return f[1][n];
}
};