回文树
回文树(EER Tree,Palindromic Tree),有点类似Trie,但它并不是匹配字符串的,很多人称之为回文自动机,但它一点也不像自动机,不过我还是按习惯的来,使用PAM为简称。为了表示一个回文,我们只表示一边的一个单链即可,这时就类似Trie。但不同之处是,回文区分奇数长度和偶数长度,所以这里我们使用两个根,分别来表示奇数长度和偶数长度。所以,在奇数根里,链ba表示aba,而在偶数根里的ba表示abba。
首先我们来直观地看看PAM的图形化,以下是字符串abcbbc
的PAM
graph TD;
linkStyle default interpolate basis
subgraph root
0-.->1[-1]
end
subgraph node0
0-->6((bb))
6-->7((cbbc))
end
subgraph node1
1-->2((a))
1-->3((b))
1-->4((c))
4-->5((bcb))
end
2-.->0
3-.->0
4-.->0
6-.->3
5-.->3
7-.->4
style 0 fill:#f9f
style 1 fill:#f9f
实线方向就是子节点方向,虚线是fail指针,指向这个节点最长的回文后缀节点。图有点乱,但又不希望画得过于简单导致说不清楚,将就一下吧。
为了能顺利构造,每个节点上面要存储以下必要数据:
- next: 类似Trie,表示子结点的指针,即图上实线
- fail: fail指针,即图上虚线
- len: 这个节点所表示的回文串的长度
- cnt: 这个节点所表示的回文串在原串中出现的次数
- num: 这个节点所表示的回文串中,有多少个后缀也是回文
构造方法有点类似AC自动机或SAM,它用增量式构造,假设前k个字符已经构造好,最后构造的节点假设是aba
,新加的字符是c
,那么我们通过那个节点,获取长度是3,于是判断一下原字符串在aba
的前面的字符看是不是和新字符一样,如果是,就在它下面新加入节点,否则,就跳到它的fail节点再做相同的判断,其实这个比SAM简单多了。
PAM构造过程模拟
我们使用字符串abcbbc
作为开始,首先先把abc都插入
1 插入abc
graph TD;
linkStyle default interpolate basis
subgraph root
0-.->1[-1]
end
subgraph node1
1-->2((a))
1-->3((b))
1-->4((c))
end
2-.->0
3-.->0
4-.->0
style 0 fill:#f9f
style 1 fill:#f9f
2 再插入b
graph TD;
linkStyle default interpolate basis
subgraph root
0-.->1[-1]
end
subgraph node1
1-->2((a))
1-->3((b))
1-->4((c))
4-->5((bcb))
end
2-.->0
3-.->0
4-.->0
5-.->3
style 0 fill:#f9f
style 1 fill:#f9f
新增加的b构成bcb回文,于是加在c后面,而fail指针的查找,就是从c的fail节点开始,判断那个节点对称的位置是不是新增加的字符b,是的话连上,不是的话再向上一层的fail节点继续找。
3 再插入b
graph TD;
linkStyle default interpolate basis
subgraph root
0-.->1[-1]
end
subgraph node0
0-->6((bb))
end
subgraph node1
1-->2((a))
1-->3((b))
1-->4((c))
4-->5((bcb))
end
2-.->0
3-.->0
4-.->0
6-.->3
5-.->3
style 0 fill:#f9f
style 1 fill:#f9f
而最后插入a后的结果和开头的图是相同的,这里不重复了
关于节点计数
这个节点计数,就是前面所说的cnt,即此节点所表示的回文在整个串中出现次数,但这个并不是一次性统计好的,要分两步。首先,在创建回文树过程中的节点计数,然后就是从叶子开始,把自己的cnt数值加到它的fail节点上,为什么是做加法呢,例如节点ababa
,它的fail节点表示的必然是aba
,那么在构建的过程中,每次遇到ababa
的左侧三个字母的时候,就会把cnt加到aba
节点上,而缺少的另一半,正是节点ababa
的出现次数。另外有的博客说cnt是本质不同的回文字符串数量,这是不正确的。
模板
struct PAM
{
struct node
{
map<char, int> next;
int fail, len, cnt, num;
node(int l = 0) : len(l), fail(0), cnt(0), num(0) {}
};
vector<char> s;
vector<node> nodes;
int match_p;
void init(int size)
{
s.clear();
s.reserve(size + 1);
s.push_back('\200');
nodes.push_back(node());
nodes.push_back(node(-1));
nodes[0].fail = 1;
match_p = 0;
}
int getfail(int x)
{
while (s[s.size() - nodes[x].len - 2] != s.back())
x = nodes[x].fail;
return x;
}
void extend(char c)
{
s.push_back(c);
int p = getfail(match_p);
if (!nodes[p].next.count(c))
{
int ch = nodes.size();
nodes.push_back(node(nodes[p].len + 2));
nodes[ch].fail = nodes[getfail(nodes[p].fail)].next[c];
nodes[p].next[c] = ch;
nodes[ch].num = nodes[nodes[ch].fail].num + 1;
}
match_p = nodes[p].next[c];
nodes[match_p].cnt++;
}
void done()
{
for (int i = nodes.size() - 1; i > 0; i--)
nodes[nodes[i].fail].cnt += nodes[i].cnt;
}
void build(const char* s)
{
init(strlen(s));
for (;*s;++s) extend(*s);
done();
}
};
直接调用build即可,得到的nodes就是回文树
PAM应用
1 求字符串中所有本质不同的回文子串的数量
生成PAM统计非根的节点数量即可
2 求字符串中所有回文子串的数量
生成PAM后,对除了根外的所有节点的cnt累加就行了
3 求公共回文子串数量
对两个字符串都生成PAM后,同时dfs遍历其相同的节点,累加两边节点的cnt的乘积即可