后缀数组

后缀数组其实概念很好理解,就是给出一个字符串,长度是n,对它所有的n个后缀编号从1到n进行排序,排序后,最小的那个后缀的编号假设是m1,那么sa[1] = m1,类似地,第二小的是m2的话,sa[2] = m2,sa这个数组就是我们所需要的后缀数组。根据这个,我们可以直接用sort算出sa,以下为最简单的实现

struct SA_simple
{
    vector<int> sa;
    int s_size;
    const char* p_s;
    int size() const
    {
        return s_size;
    }
    static bool cmp(const char* x, const char* y)
    {
        return strcmp(x, y) < 0;
    }
    void init(char * str)
    {
        int n = strlen(str);
        s_size = n;
        p_s = str - 1;

        sa.resize(n + 1);
        vector< const char* > rp;
        rp.resize(n + 1);

        for (int i = 1; i <= n; ++i)
        {
            rp[i] = p_s + i;
        }
        sort(rp.begin() + 1, rp.end(), cmp);
        for (int i = 1; i <= n; ++i)
        {
            sa[i] = rp[i] - p_s;
        }
    }
};

这个实现的时间复杂度 $O(n^2logn)$

要注意的一点是下标从1开始。有了这个,可以做点什么呢?例如给你一个串p,求出p在主串s中出现了多少次。那么在有了sa的情况下,因为sa是有序的,问题就变成了二分搜索,分别用lower_boundupper_bound通过sa搜索p,两个相减便得出现次数。

rank 数组

光有sa其实还不够用,我们还需要rank数组,rank[m]的值是p的话,那么表示字符串中编号m的后缀,它的排名是p,即与sa数组是互逆,所以我们可以得到 sa[rank[i]] == rank[sa[i]] == i ,也就是说通过rank,可以快速判断某两个后缀的大小关系。

height 数组

height[i]的值表示的是,sa[i-1]sa[i]这两个后缀的相同前缀长度,特别地,height[1] == 0,求解height需要用到rank数组和sa数组,以及如下引理

$$height[rank[i]]\ge height[rank[i-1]]-1$$

通过以上引理直接暴力实现即可,复杂度 $O(n)$ ,这里不做证明。

三个数组的完整实现如下

struct SA_simple
{
    vector<int> sa, rk, ht;
protected:
    int s_size;
    const char* p_s;
public:
    int size() const
    {
        return s_size;
    }

    static bool cmp(const char* x, const char* y)
    {
        return strcmp(x, y) < 0;
    }

    void init(char * str, bool h = true)
    {
        int n = strlen(str);
        s_size = n;
        p_s = str - 1;

        sa.resize(n + 1);
        rk.resize(n + 1);
        vector<const char*> rp;
        rp.resize(n + 1);

        for (int i = 1; i <= n; ++i)
        {
            rp[i] = p_s + i;
        }
        sort(rp.begin() + 1, rp.end(), cmp);
        for (int i = 1; i <= n; ++i)
        {
            sa[i] = rp[i] - p_s;
            rk[sa[i]] = i;
        }
        if (h) create_height();
    }
    void create_height()
    {
        ht.resize(s_size + 1);
        for (int i = 1, k = 0; i <= s_size; ++i)
        {
            if (k) --k;
            while (p_s[i + k] == p_s[sa[rk[i] - 1] + k])
                ++k;
            ht[rk[i]] = k;
        }
    }
};

应用1 可重叠最长重复子串

这个题目在本博客讲kmp部分已经有介绍,例如eabcaefabcabc,最长重复子串是abca,长度是4,这里介绍用后缀数组的解法。其实所谓的最长重复子串,就是找到两个后缀,让它们的公共前缀最长,那这就简单了,我们只要在height数组里找最大值就可以了,查找时间 $O(n)$ 。

应用2 不同子串的个数

来看这道题 SPOJ-DISUBSTR ,说的是统计一个字符串里有多少不同的子串。

这里我们就需要用到height数组,由于它表示的正是和前一个后缀的相同前缀长度,那么我们对任意的后缀sa[i],取这个后缀的长度,即len(s)-sa[i],减去height[i]再加上1,即表示sa[i]这个后缀有多少个前缀与sa[i-1]不相同,所以我们累加即可。核心代码也就这么几行

int sum_h = 0;
for (int i = 1; i <= sa.size(); ++i)
{
    sum_h += sa.size() - sa.sa[i] - sa.ht[i] + 1;
}
printf("%d\n", sum_h);

也许有人会发现问题了,这个题直接用hash实现,才 $O(n^2)$ 的复杂度,这个后缀数组的实现,光是生成就 $O(n^2logn)$ ,不是还更慢吗?单看时间复杂度的确是这样,但事实上后缀数组可以0ms通过,hash实现约400ms左右。

优化

直接排序的后缀数组确实过于暴力了,虽然不少题目已经足够AC,但我们还有更好的,这里简要介绍倍增法。假设对字符串”ababaabb”求后缀数组,那么先对每一个字符做排序,计算出它们的rank,注意相同串的rank结果要相同,结果在下表的”排序1”,然后我们对每个ii+1在”排序1”上的rank组合起来,这个组合的key再做排序,如下表

下标 1 2 3 4 5 6 7 8
a b a b a a b b
排序1 1 2 1 2 1 1 2 2
组合1 1 2 2 1 1 2 2 1 1 1 1 2 2 2 2 0
排序2 2 4 2 4 1 2 5 3

事实上这样得到了所有后缀中,前缀长度为2的排名,接下来,我们步长翻倍,对每个ii+2在”排序2”上的rank组合起来再排序

下标 1 2 3 4 5 6 7 8
排序2 2 4 2 4 1 2 5 3
组合2 2 2 4 4 2 1 4 2 1 5 2 3 5 0 3 0
排序3 3 7 2 6 1 4 8 5

事实上这样得到了所有后缀中,前缀长度为4的排名,接下来,我们步长翻倍,对每个ii+4在”排序3”上的rank组合起来再排序

下标 1 2 3 4 5 6 7 8
排序3 3 7 2 6 1 4 8 5
组合3 3 1 7 4 2 8 6 5 1 0 4 0 8 0 5 0
排序4 3 7 2 6 1 4 8 5

至此,再下一轮的步长是8,已经大于等于字符串长度的时候,rank数组便计算完成了。以下是使用此思路的实现代码

struct SA_2_sort
{
    vector<int> sa, ht;
    int *rk;
protected:
    vector<int> rk1, rk2;
    int s_size;
    int *p_rk, *o_rk;
    const char* p_s;

    struct SA_2_sort_cmp
    {
        int *rk, w;
        SA_2_sort_cmp(int *_rk, int _w) :rk(_rk), w(_w) {}
        bool operator()(int x, int y) const
        {
            return rk[x] == rk[y] ? rk[x + w] < rk[y + w] : rk[x] < rk[y];
        }
    };
public:
    bool cmp(int x, int y, int w)
    {
        return o_rk[x] == o_rk[y] && o_rk[x + w] == o_rk[y + w];
    }
    int size() const
    {
        return s_size;
    }
    void init(char * str, bool h = true)
    {
        int n = strlen(str);
        s_size = n;
        p_s = str - 1;
        sa.resize(n + 1);
        rk1.clear(); rk1.resize(n * 2 + 2);
        rk2.clear(); rk2.resize(n * 2 + 2);
        p_rk = &*rk1.begin();
        o_rk = &*rk2.begin();

        for (int i = 1; i <= n; ++i) p_rk[i] = p_s[i];

        for (int w = 1, i, p; w < n; w <<= 1)
        {
            // init sa
            for (int i = 1; i <= n; ++i) sa[i] = i;

            sort(sa.begin() + 1, sa.end(), SA_2_sort_cmp(p_rk, w));

            // write new rank
            for (std::swap(p_rk, o_rk), p = 0, i = 1; i <= n; ++i)
                p_rk[sa[i]] = cmp(sa[i], sa[i - 1], w) ? p : ++p;
        }
        rk = p_rk;
        if (n == 1) sa[1] = rk[1] = 1;
        if (h) create_height();
    }
    void create_height()
    {
        int n = s_size;
        ht.resize(n + 1);
        for (int i = 1, k = 0; i <= n; ++i)
        {
            if (k) --k;
            while (p_s[i + k] == p_s[sa[p_rk[i] - 1] + k])
                ++k;
            ht[p_rk[i]] = k;
        }
    }
};

以上实现的时间复杂度是 $O(nlog^2n)$ ,如果你想要更快,那就用 $O(n)$ 的计数排序吧,便可把整体时间复杂度下降到 $O(nlogn)$

再优化

绝大多数情况下,使用以上方法 $O(nlogn)$ 复杂度已经够用了,但如果你是一个更有追求的人,可以继续学习 $O(n)$ 复杂度建立后缀数组的办法,名字叫做SA-ISDC3,你可以通过搜索以上两个名字得到更具体的介绍,本文就只介绍到这里。

$O(nlogn)$ 模板

struct SA_2
{
    vector<int> sa, ht;
    int *rk;
protected:
    vector<int> rk1, rk2;
    int s_size;
    int *p_rk, *o_rk;
    const char* p_s;
public:
    bool cmp(int x, int y, int w)
    {
        return o_rk[x] == o_rk[y] && o_rk[x + w] == o_rk[y + w];
    }
    int size() const
    {
        return s_size;
    }
    void init(char * str, bool h = true)
    {
        int n = strlen(str);
        s_size = n;
        p_s = str - 1;
        int cnt_size = max(256, n) + 1;
        vector<int> vid, vpx, vcnt;
        sa.resize(n + 1);
        rk1.clear(); rk1.resize(n * 2 + 2);
        rk2.clear(); rk2.resize(n * 2 + 2);
        vid.resize(n + 1); vpx.resize(n + 1);
        vcnt.resize(cnt_size);
        int* id = &*vid.begin();
        int* px = &*vpx.begin();
        int* cnt = &*vcnt.begin();
        p_rk = &*rk1.begin();
        o_rk = &*rk2.begin();

        int m = 128, p = 0;
        for (int i = 1; i <= n; ++i) ++cnt[p_rk[i] = p_s[i]];
        for (int i = 1; i <= m; ++i) cnt[i] += cnt[i - 1];
        for (int i = n; i >= 1; --i) sa[cnt[p_rk[i]]--] = i;

        for (int w = 1, i; w < n; w <<= 1, m = p)
        {
            // init id
            for (p = 0, i = n; i > n - w; --i)
                id[++p] = i;
            for (int i = 1; i <= n; ++i)
                if (sa[i] > w) id[++p] = sa[i] - w;

            sort(cnt, id, px, n, m);

            // write new rank
            for (std::swap(p_rk, o_rk), p = 0, i = 1; i <= n; ++i)
                p_rk[sa[i]] = cmp(sa[i], sa[i - 1], w) ? p : ++p;
        }
        rk = p_rk;
        if (n == 1) sa[1] = rk[1] = 1;
        if (h) create_height();
    }
    void sort(int* cnt, int* id, int* px, int n, int m)
    {
        memset(cnt, 0, sizeof(int) * (m + 1));
        for (int i = 1; i <= n; ++i)
            ++cnt[px[i] = p_rk[id[i]]];
        for (int i = 1; i <= m; ++i)
            cnt[i] += cnt[i - 1];
        for (int i = n; i >= 1; --i)
            sa[cnt[px[i]]--] = id[i];
    }
    void create_height()
    {
        int n = s_size;
        ht.resize(n + 1);
        for (int i = 1, k = 0; i <= n; ++i)
        {
            if (k) --k;
            while (p_s[i + k] == p_s[sa[p_rk[i] - 1] + k])
                ++k;
            ht[p_rk[i]] = k;
        }
    }
};
Avatar
抱抱熊

一个喜欢折腾和研究算法的大学生

Related

comments powered by Disqus