后缀数组
后缀数组其实概念很好理解,就是给出一个字符串,长度是n,对它所有的n个后缀编号从1到n进行排序,排序后,最小的那个后缀的编号假设是m1,那么sa[1] = m1
,类似地,第二小的是m2的话,sa[2] = m2
,sa这个数组就是我们所需要的后缀数组。根据这个,我们可以直接用sort算出sa,以下为最简单的实现
struct SA_simple
{
vector<int> sa;
int s_size;
const char* p_s;
int size() const
{
return s_size;
}
static bool cmp(const char* x, const char* y)
{
return strcmp(x, y) < 0;
}
void init(char * str)
{
int n = strlen(str);
s_size = n;
p_s = str - 1;
sa.resize(n + 1);
vector< const char* > rp;
rp.resize(n + 1);
for (int i = 1; i <= n; ++i)
{
rp[i] = p_s + i;
}
sort(rp.begin() + 1, rp.end(), cmp);
for (int i = 1; i <= n; ++i)
{
sa[i] = rp[i] - p_s;
}
}
};
这个实现的时间复杂度 $O(n^2logn)$
要注意的一点是下标从1开始。有了这个,可以做点什么呢?例如给你一个串p,求出p在主串s中出现了多少次。那么在有了sa的情况下,因为sa是有序的,问题就变成了二分搜索,分别用lower_bound
和upper_bound
通过sa搜索p,两个相减便得出现次数。
rank 数组
光有sa其实还不够用,我们还需要rank数组,rank[m]
的值是p的话,那么表示字符串中编号m的后缀,它的排名是p,即与sa数组是互逆,所以我们可以得到 sa[rank[i]] == rank[sa[i]] == i
,也就是说通过rank,可以快速判断某两个后缀的大小关系。
height 数组
height[i]
的值表示的是,sa[i-1]
与sa[i]
这两个后缀的相同前缀长度,特别地,height[1] == 0
,求解height需要用到rank数组和sa数组,以及如下引理
$$height[rank[i]]\ge height[rank[i-1]]-1$$
通过以上引理直接暴力实现即可,复杂度 $O(n)$ ,这里不做证明。
三个数组的完整实现如下
struct SA_simple
{
vector<int> sa, rk, ht;
protected:
int s_size;
const char* p_s;
public:
int size() const
{
return s_size;
}
static bool cmp(const char* x, const char* y)
{
return strcmp(x, y) < 0;
}
void init(char * str, bool h = true)
{
int n = strlen(str);
s_size = n;
p_s = str - 1;
sa.resize(n + 1);
rk.resize(n + 1);
vector<const char*> rp;
rp.resize(n + 1);
for (int i = 1; i <= n; ++i)
{
rp[i] = p_s + i;
}
sort(rp.begin() + 1, rp.end(), cmp);
for (int i = 1; i <= n; ++i)
{
sa[i] = rp[i] - p_s;
rk[sa[i]] = i;
}
if (h) create_height();
}
void create_height()
{
ht.resize(s_size + 1);
for (int i = 1, k = 0; i <= s_size; ++i)
{
if (k) --k;
while (p_s[i + k] == p_s[sa[rk[i] - 1] + k])
++k;
ht[rk[i]] = k;
}
}
};
应用1 可重叠最长重复子串
这个题目在本博客讲kmp部分已经有介绍,例如eabcaefabcabc
,最长重复子串是abca
,长度是4,这里介绍用后缀数组的解法。其实所谓的最长重复子串,就是找到两个后缀,让它们的公共前缀最长,那这就简单了,我们只要在height数组里找最大值就可以了,查找时间 $O(n)$ 。
应用2 不同子串的个数
来看这道题 SPOJ-DISUBSTR ,说的是统计一个字符串里有多少不同的子串。
这里我们就需要用到height数组,由于它表示的正是和前一个后缀的相同前缀长度,那么我们对任意的后缀sa[i]
,取这个后缀的长度,即len(s)-sa[i]
,减去height[i]
再加上1,即表示sa[i]
这个后缀有多少个前缀与sa[i-1]
不相同,所以我们累加即可。核心代码也就这么几行
int sum_h = 0;
for (int i = 1; i <= sa.size(); ++i)
{
sum_h += sa.size() - sa.sa[i] - sa.ht[i] + 1;
}
printf("%d\n", sum_h);
也许有人会发现问题了,这个题直接用hash实现,才 $O(n^2)$ 的复杂度,这个后缀数组的实现,光是生成就 $O(n^2logn)$ ,不是还更慢吗?单看时间复杂度的确是这样,但事实上后缀数组可以0ms通过,hash实现约400ms左右。
优化
直接排序的后缀数组确实过于暴力了,虽然不少题目已经足够AC,但我们还有更好的,这里简要介绍倍增法。假设对字符串”ababaabb”求后缀数组,那么先对每一个字符做排序,计算出它们的rank,注意相同串的rank结果要相同,结果在下表的”排序1”,然后我们对每个i
和i+1
在”排序1”上的rank组合起来,这个组合的key再做排序,如下表
下标 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 |
---|---|---|---|---|---|---|---|---|
原 | a | b | a | b | a | a | b | b |
排序1 | 1 | 2 | 1 | 2 | 1 | 1 | 2 | 2 |
组合1 | 1 2 | 2 1 | 1 2 | 2 1 | 1 1 | 1 2 | 2 2 | 2 0 |
排序2 | 2 | 4 | 2 | 4 | 1 | 2 | 5 | 3 |
事实上这样得到了所有后缀中,前缀长度为2的排名,接下来,我们步长翻倍,对每个i
和i+2
在”排序2”上的rank组合起来再排序
下标 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 |
---|---|---|---|---|---|---|---|---|
排序2 | 2 | 4 | 2 | 4 | 1 | 2 | 5 | 3 |
组合2 | 2 2 | 4 4 | 2 1 | 4 2 | 1 5 | 2 3 | 5 0 | 3 0 |
排序3 | 3 | 7 | 2 | 6 | 1 | 4 | 8 | 5 |
事实上这样得到了所有后缀中,前缀长度为4的排名,接下来,我们步长翻倍,对每个i
和i+4
在”排序3”上的rank组合起来再排序
下标 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 |
---|---|---|---|---|---|---|---|---|
排序3 | 3 | 7 | 2 | 6 | 1 | 4 | 8 | 5 |
组合3 | 3 1 | 7 4 | 2 8 | 6 5 | 1 0 | 4 0 | 8 0 | 5 0 |
排序4 | 3 | 7 | 2 | 6 | 1 | 4 | 8 | 5 |
至此,再下一轮的步长是8,已经大于等于字符串长度的时候,rank数组便计算完成了。以下是使用此思路的实现代码
struct SA_2_sort
{
vector<int> sa, ht;
int *rk;
protected:
vector<int> rk1, rk2;
int s_size;
int *p_rk, *o_rk;
const char* p_s;
struct SA_2_sort_cmp
{
int *rk, w;
SA_2_sort_cmp(int *_rk, int _w) :rk(_rk), w(_w) {}
bool operator()(int x, int y) const
{
return rk[x] == rk[y] ? rk[x + w] < rk[y + w] : rk[x] < rk[y];
}
};
public:
bool cmp(int x, int y, int w)
{
return o_rk[x] == o_rk[y] && o_rk[x + w] == o_rk[y + w];
}
int size() const
{
return s_size;
}
void init(char * str, bool h = true)
{
int n = strlen(str);
s_size = n;
p_s = str - 1;
sa.resize(n + 1);
rk1.clear(); rk1.resize(n * 2 + 2);
rk2.clear(); rk2.resize(n * 2 + 2);
p_rk = &*rk1.begin();
o_rk = &*rk2.begin();
for (int i = 1; i <= n; ++i) p_rk[i] = p_s[i];
for (int w = 1, i, p; w < n; w <<= 1)
{
// init sa
for (int i = 1; i <= n; ++i) sa[i] = i;
sort(sa.begin() + 1, sa.end(), SA_2_sort_cmp(p_rk, w));
// write new rank
for (std::swap(p_rk, o_rk), p = 0, i = 1; i <= n; ++i)
p_rk[sa[i]] = cmp(sa[i], sa[i - 1], w) ? p : ++p;
}
rk = p_rk;
if (n == 1) sa[1] = rk[1] = 1;
if (h) create_height();
}
void create_height()
{
int n = s_size;
ht.resize(n + 1);
for (int i = 1, k = 0; i <= n; ++i)
{
if (k) --k;
while (p_s[i + k] == p_s[sa[p_rk[i] - 1] + k])
++k;
ht[p_rk[i]] = k;
}
}
};
以上实现的时间复杂度是 $O(nlog^2n)$ ,如果你想要更快,那就用 $O(n)$ 的计数排序吧,便可把整体时间复杂度下降到 $O(nlogn)$
再优化
绝大多数情况下,使用以上方法 $O(nlogn)$ 复杂度已经够用了,但如果你是一个更有追求的人,可以继续学习 $O(n)$ 复杂度建立后缀数组的办法,名字叫做SA-IS
和DC3
,你可以通过搜索以上两个名字得到更具体的介绍,本文就只介绍到这里。
$O(nlogn)$ 模板
struct SA_2
{
vector<int> sa, ht;
int *rk;
protected:
vector<int> rk1, rk2;
int s_size;
int *p_rk, *o_rk;
const char* p_s;
public:
bool cmp(int x, int y, int w)
{
return o_rk[x] == o_rk[y] && o_rk[x + w] == o_rk[y + w];
}
int size() const
{
return s_size;
}
void init(char * str, bool h = true)
{
int n = strlen(str);
s_size = n;
p_s = str - 1;
int cnt_size = max(256, n) + 1;
vector<int> vid, vpx, vcnt;
sa.resize(n + 1);
rk1.clear(); rk1.resize(n * 2 + 2);
rk2.clear(); rk2.resize(n * 2 + 2);
vid.resize(n + 1); vpx.resize(n + 1);
vcnt.resize(cnt_size);
int* id = &*vid.begin();
int* px = &*vpx.begin();
int* cnt = &*vcnt.begin();
p_rk = &*rk1.begin();
o_rk = &*rk2.begin();
int m = 128, p = 0;
for (int i = 1; i <= n; ++i) ++cnt[p_rk[i] = p_s[i]];
for (int i = 1; i <= m; ++i) cnt[i] += cnt[i - 1];
for (int i = n; i >= 1; --i) sa[cnt[p_rk[i]]--] = i;
for (int w = 1, i; w < n; w <<= 1, m = p)
{
// init id
for (p = 0, i = n; i > n - w; --i)
id[++p] = i;
for (int i = 1; i <= n; ++i)
if (sa[i] > w) id[++p] = sa[i] - w;
sort(cnt, id, px, n, m);
// write new rank
for (std::swap(p_rk, o_rk), p = 0, i = 1; i <= n; ++i)
p_rk[sa[i]] = cmp(sa[i], sa[i - 1], w) ? p : ++p;
}
rk = p_rk;
if (n == 1) sa[1] = rk[1] = 1;
if (h) create_height();
}
void sort(int* cnt, int* id, int* px, int n, int m)
{
memset(cnt, 0, sizeof(int) * (m + 1));
for (int i = 1; i <= n; ++i)
++cnt[px[i] = p_rk[id[i]]];
for (int i = 1; i <= m; ++i)
cnt[i] += cnt[i - 1];
for (int i = n; i >= 1; --i)
sa[cnt[px[i]]--] = id[i];
}
void create_height()
{
int n = s_size;
ht.resize(n + 1);
for (int i = 1, k = 0; i <= n; ++i)
{
if (k) --k;
while (p_s[i + k] == p_s[sa[p_rk[i] - 1] + k])
++k;
ht[p_rk[i]] = k;
}
}
};