AC自动机

听到AC自动机很多人第一次听到的反应往往是很兴奋的。但其实并不是你们想的那种东西。它的全称是Aho-Corasick algorithm,另外,自动机的英文是Automaton,所以AC自动机即 AC Automaton。为了解释这个算法,首先我们来回顾KMP,你需要很理解KMP的原理,不然看后面的内容就会变得妙不可读

KMP自动机

本质上KMP其实就是一种自动机。这次我们改用自动机的形式来理解。所谓自动机,一般指的是确定有限状态自动机,你可以看作一个黑箱,每次输入一个数据,它就会改变它的内部状态,并有相应的输出。如果你知道Trie,那么它其实就是一个典型的自动机。我们还是拿字符串abacabab作为例子,如果是生成next数组,结果如下:

string a b a c a b a b \0
next -1 0 0 1 0 1 2 3 2

为了方便变成自动机的方式理解,我们把这个改成有向图

graph LR;
linkStyle default interpolate basis
0[Start]--a-->00[1]
00--b-->1[2]
1--a-->2[3]
2--c-->3[4]
3--a-->4[5]
4--b-->5[6]
5--a-->6[7]
6--b-->7[8]
00-.->0
1-.->0
2-.->00
3-.->0
4-.->00
5-.->1
6-.->2
%%7[b]-.->3[c]
style 0 fill:#f9f,stroke-dasharray: 5, 5
style 7 fill:#f9f,stroke-dasharray: 5, 5

上图中,实箭头表示匹配,虚箭头表示不匹配要返回的前面的节点,紫色节点表示起止节点。首先我们的状态只要一个指针,先指向start,在匹配的时候,如果与它的下一个字符匹配,那么指针就沿实箭头移动;如果与下一个字符不匹配,在有虚线的情况下,那就沿虚线走一步,然后再尝试一次匹配。以下我们模拟一下匹配”ababa”的过程。

  1. 初始化指向start,start的下一个是a,匹配第1个字符,指针移动到节点1
  2. 节点1的下一个是b,匹配第2个字符,指针移动到节点2
  3. 节点2的下一个是a,匹配第3个字符,指针移动到节点3
  4. 节点3的下一个是c,不匹配第4个字符,回退到节点1,这时候匹配,指针移动到节点2
  5. 节点2的下一个是a,匹配第5个字符,指针移动到节点3

以上的虚线箭头就是fail指针的指向

KMP自动机的生成

这个的生成规则其实非常简单,首先建立start,然后向右添加字符,用实箭头连接,第一个节点就虚箭头直接指回start

graph LR;
linkStyle default interpolate basis
0[Start]--a-->00[1]
00-.->0
style 0 fill:#f9f,stroke-dasharray: 5, 5

然后插入第二个字符,第二个字符的虚箭头看它的父节点的虚箭头所指向的节点的子节点是不是与第二个字符相等,这个描述有点绕,假设当前节点是c,父节点是p,其虚箭头所指节点是fail[x],那么看的是fail[p]的子节点是否等于c的值,如果等于,那么fail[c] = fail[p]->next,如果不等于,那么令p=fail[p]再次判断其子节点,直到p=start,于是加入第二节点时,它的fail指针指向start

graph LR;
linkStyle default interpolate basis
0[Start]--a-->00[1]
00--b-->1[2]
00-.->0
1-.->0
style 0 fill:#f9f,stroke-dasharray: 5, 5

到第3个字符a的时候,它的父节点的fail指向start,而start的子节点也是a,所以它的a就指向第1个字符

graph LR;
linkStyle default interpolate basis
0[Start]--a-->00[1]
00--b-->1[2]
1--a-->2[3]
00-.->0
1-.->0
2-.->00
style 0 fill:#f9f,stroke-dasharray: 5, 5

最后讲一下插入最后的字符b的时候,首先它的父节点的fail指向第3个节点a,但那个节点的下一个字符并不是b,所以用它的父节点的fail替代,然后再看第3个节点的fail,指向第1个节点,而第1个节点的下一个字符是b,所以要指向第2个节点,得到下图

graph LR;
linkStyle default interpolate basis
0[Start]--a-->00[1]
00--b-->1[2]
1--a-->2[3]
2--c-->3[4]
3--a-->4[5]
4--b-->5[6]
5--a-->6[7]
6--b-->7[8]
00-.->0
1-.->0
2-.->00
3-.->0
4-.->00
5-.->1
6-.->2
7-.->1
style 0 fill:#f9f,stroke-dasharray: 5, 5
style 7 fill:#f9f,stroke-dasharray: 5, 5

在弄懂以上过程后,我们来正式介绍AC自动机

AC自动机与Trie

AC自动机的本质,就是在Trie上套KMP,就这一句话。我们怎么理解呢,其实Trie能匹配字符串的前缀,但如果我们需要匹配任意位置,又不希望回溯,那就依照KMP的方法,在匹配失败的时候,跳转到假如回溯能匹配到的Trie的位置。为了更好说明,这里我们使用he,she,the,there,here来演示生成过程。

先生成Trie

graph LR;
linkStyle default interpolate basis
0((0))--h-->1
1--e-->2((2))
2--r-->3
3--e-->4((4))
0--s-->5
5--h-->6
6--e-->7((7))
0--t-->8
8--h-->9
9--e-->10((10))
10--r-->11
11--e-->12((12))
style 0 fill:#f9f,stroke-dasharray: 5, 5
style 2 fill:#f9f,stroke-dasharray: 5, 5
style 4 fill:#f9f,stroke-dasharray: 5, 5
style 7 fill:#f9f,stroke-dasharray: 5, 5
style 10 fill:#f9f,stroke-dasharray: 5, 5
style 12 fill:#f9f,stroke-dasharray: 5, 5

然后做BFS,第一层的fail都指向0

graph LR;
linkStyle default interpolate basis
0((0))--h-->1
1--e-->2((2))
2--r-->3
3--e-->4((4))
0--s-->5
5--h-->6
6--e-->7((7))
0--t-->8
8--h-->9
9--e-->10((10))
10--r-->11
11--e-->12((12))
1-.->0
5-.->0
8-.->0
style 0 fill:#f9f,stroke-dasharray: 5, 5
style 2 fill:#f9f,stroke-dasharray: 5, 5
style 4 fill:#f9f,stroke-dasharray: 5, 5
style 7 fill:#f9f,stroke-dasharray: 5, 5
style 10 fill:#f9f,stroke-dasharray: 5, 5
style 12 fill:#f9f,stroke-dasharray: 5, 5

然后,第二层,与KMP自动机的建立规则相同,另外为了让图形上的线不那么乱,虚线指向start的省略

graph LR;
linkStyle default interpolate basis
0((0))--h-->1
subgraph here
1--e-->2((2))
2--r-->3
3--e-->4((4))
end
0--s-->5
subgraph she
5--h-->6
6--e-->7((7))
end
0--t-->8
subgraph there
8--h-->9
9--e-->10((10))
10--r-->11
11--e-->12((12))
end
6-.->1
9-.->1
style 0 fill:#f9f,stroke-dasharray: 5, 5
style 2 fill:#f9f,stroke-dasharray: 5, 5
style 4 fill:#f9f,stroke-dasharray: 5, 5
style 7 fill:#f9f,stroke-dasharray: 5, 5
style 10 fill:#f9f,stroke-dasharray: 5, 5
style 12 fill:#f9f,stroke-dasharray: 5, 5

接着,第三层

graph LR;
linkStyle default interpolate basis
0((0))--h-->1
subgraph here
1--e-->2((2))
2--r-->3
3--e-->4((4))
end
0--s-->5
subgraph she
5--h-->6
6--e-->7((7))
end
0--t-->8
subgraph there
8--h-->9
9--e-->10((10))
10--r-->11
11--e-->12((12))
end
6-.->1
9-.->1
7-.->2
10-.->2

style 0 fill:#f9f,stroke-dasharray: 5, 5
style 2 fill:#f9f,stroke-dasharray: 5, 5
style 4 fill:#f9f,stroke-dasharray: 5, 5
style 7 fill:#f9f,stroke-dasharray: 5, 5
style 10 fill:#f9f,stroke-dasharray: 5, 5
style 12 fill:#f9f,stroke-dasharray: 5, 5

构建完毕的图

graph LR;
linkStyle default interpolate basis
0((0))--h-->1
subgraph here
1--e-->2((2))
2--r-->3
3--e-->4((4))
end
0--s-->5
subgraph she
5--h-->6
6--e-->7((7))
end
0--t-->8
subgraph there
8--h-->9
9--e-->10((10))
10--r-->11
11--e-->12((12))
end
6-.->1
9-.->1
7-.->2
10-.->2
11-.->3
12-.->4

style 0 fill:#f9f,stroke-dasharray: 5, 5
style 2 fill:#f9f,stroke-dasharray: 5, 5
style 4 fill:#f9f,stroke-dasharray: 5, 5
style 7 fill:#f9f,stroke-dasharray: 5, 5
style 10 fill:#f9f,stroke-dasharray: 5, 5
style 12 fill:#f9f,stroke-dasharray: 5, 5

以上就是一个最简单的AC自动机,由于只能通过fail指针在失配时做转移,所以遇到匹配失败的时候不能一步到位,需要一个循环来找下一个位置,但在不少场合已足够使用。习题:HDU-2896

字典图

前面刚说过,由于fail指针只有一个,所以遇到匹配失败的时候不能一步到位,那我们如果想一步到位呢?那事实上就成为了一个有向图,我们在跳转时不使用fail指针,而直接用next指针替代,每遇到一个字符就按next来跳转,这样状态转移时间非常稳定且速度更快,而且成为有向图有一个额外的好处,就是能变成图论问题来解,这个后面再来讨论。

要实现字典图,fail指针还是需要的,但在构建的时候代码写起来反而更简单,因为fail的指向不再需要写循环,可以利用前面的结果一步到位,假设当前节点是c,要更新的字符是i,那分两种情况:

  1. 如果c的next[i]非空,那么c的next[i]节点的fail指针就指向 c的fail指针节点的next[i]
  2. 如果c的next[i]为空,那么c的next[i]节点就指向 c的fail指针节点的next[i]

也就是说,不管哪种,都是指向c的fail的next[i]

模板

应用以下模板时,你很可能需要做的调整包括charset的大小,以及getindex函数的实现,这两部分你也可以通过template改写

const int charset = 26;
struct TrieGraph
{
    struct trie_node
    {
        int next[charset];
        int fail;
        int cnt;
        int end;
        trie_node() : end(0), fail(0), cnt(0) {}
        void init() { memset(next, 0, sizeof(next)); }
    };
    vector<trie_node> nodes;
    vector<int> bfs_q;
    vector<int> match_cnt;
    map<int, int> id2node;
    map<int, int> str_size;
    int match_p;
    void init(int size)
    {
        nodes.clear();
        nodes.reserve(size);
        nodes.push_back(trie_node());
        nodes.back().init();
        match_p = 0;
    }
    static inline int getindex(char c) { return c - 'a'; }
    void insert(const char* s, int id)
    {
        const char* s0 = s;
        int p = 0;
        for (;*s; ++s)
        {
            if (nodes[p].next[getindex(*s)])
            {
                p = nodes[p].next[getindex(*s)];
            }
            else
            {
                int np = nodes.size();
                nodes[p].next[getindex(*s)] = np;
                nodes.push_back(trie_node());
                nodes.back().init();
                p = np;
            }
        }
        ++nodes[p].cnt;
        id2node[id] = p;
        str_size[id] = s - s0;
    }
    void build()
    {
        bfs_q.clear();
        bfs_q.reserve(nodes.size());
        queue<int> q;
        for (int i = 0; i < charset; ++i)
            if (nodes[0].next[i]) q.push(nodes[0].next[i]);
        while (!q.empty())
        {
            int p = q.front();
            q.pop();
            bfs_q.push_back(p);
            for (int i = 0; i < charset; ++i)
            {
                if (nodes[p].next[i])
                {
                    nodes[nodes[p].next[i]].fail = nodes[nodes[p].fail].next[i];
                    q.push(nodes[p].next[i]);
                }
                else nodes[p].next[i] = nodes[nodes[p].fail].next[i];
            }
        }
    }
    int match(char c)
    {
        match_p = nodes[match_p].next[getindex(c)];
        return match_p;
    }
    int query(const char* s) // 有多少个出现
    {
        int ret = 0;
        match_cnt.resize(nodes.size());
        for (int i = match_cnt.size() - 1; i >= 0; --i)
        {
            match_cnt[i] = nodes[i].cnt;
        }
        for (const char* ps = s; *ps; ps++)
        {
            for (int p = match(*ps); p && ~match_cnt[p]; p = nodes[p].fail)
                ret += match_cnt[p], match_cnt[p] = -1;
        }
        return ret;
    }
    ll query_sum(const char* s, int wc[]) // 每个分别出现多少
    {
        vector<int> sum;
        sum.resize(nodes.size());
        match_p = 0;
        for (const char* ps = s; *ps; ps++)
        {
            for (int p = match(*ps); p; p = nodes[p].fail)
            {
                sum[p] += nodes[p].cnt;
            }
        }
        ll ret = 0;
        for (map<int, int>::iterator it = id2node.begin(); it != id2node.end(); ++it)
        {
            wc[it->first] = sum[it->second];
            ret += sum[it->second];
        }
        return ret;
    }
    const char* find(const char* s, int& match_id)
    {
        match_p = 0;
        for (const char* ps = s; *ps; ps++)
        {
            for (int p = match(*ps); p; p = nodes[p].fail)
            {
                if (nodes[p].cnt == 0) continue;
                for (map<int, int>::iterator it = id2node.begin(); it != id2node.end(); ++it)
                {
                    if (it->second == p)
                    {
                        match_id = it->first;
                        break;
                    }
                }
                return ps - str_size[match_id] + 1;
            }
        }
        return 0;
    }
};

使用方式,先调用init预分配空间,然后调用insert插入所有用到的字符串,注意字符串的id必须从1开始,再调用build生成字典树,最后调用query匹配目标字符串,如果只需要知道有多少个串在目标中出现,那么调用单个参数的,如果需要知道每个分别出现多少次,那么使用有wc参数的版本,通过参数wc返回的是原始字符串每一个的匹配数量,而如果只需要找最初匹配的位置,那用find函数。习题 HYSBZ-3172

扩展:字符串生成的可能数量

典型题目为POJ-2778,即生成长度为n的字符串,且不包含给定的m个子串。这时候就要用上前面所构造的字典图,我们要先转成邻接矩阵,例如m[i][j][c]如果为1,表示节点i能通过字符c连接到节点j,为0则不通。不过实际计算的时候,我们并不关心i和j之间通过什么连接,只关心连接数量,那累加m[i][j][c],c取字符集的范围,累加值写到矩阵M = matrix[i][j],然后我们只要计算$M’ = M ^ n$,在矩阵M'中,M'[i][j]的值就表示从i到j恰好n步共有多少种走法。回到上面的题目,因为部分节点不能走,所以我们只要在生成矩阵的时候删除那些不能走的节点,求出M'后,累加M'[0][j]的结果就是答案,所以算法复杂度是 $O(Mlogn)$ ,其中M是一次矩阵乘法的时间复杂度,这个也是有向图里面求k步到达指定节点的路线数量所用的算法。

Avatar
抱抱熊

一个喜欢折腾和研究算法的大学生

Related

comments powered by Disqus