而 AC 自动机很好的解决了这个问题,AC 自动机(Aho-Corasick Automaton)是一种同样利用了 KMP 思想,以 Tire 树为结构的多模式匹配算法。
算法过程
AC 自动机的过程分为两步:建立 Trie 树以及预处理失配指针。
建立 Trie 树
不多赘述,Trie 树上的每一个节点表示某一个模式串的前缀。
1 2 3 4 5 6 7 8 9 10
inlinevoidinsert(char *str) { int cur = 0, len = strlen(str); for (int i = 0; i < len; i ++) { int k = str[i] - 'a'; if (!tr[cur][k]) tr[cur][k] = ++ idx; cur = tr[cur][k]; } }
处理失配指针
失配指针,又称为 fail 指针,本质即为 KMP 算法中的 next 辅助数组。
我们知道,KMP 算法中的 next[] 中记录着每一个前缀字符串的最长 border 值,AC 自动机的 fail 数组同样如此。failu 表示状态 u 的指针,指向着状态 u 的最大 border 值所处的状态(因为存在多个模式串,他们的 border 相互交叉),然而 Trie 树上每一个节点的状态刚好就是每一个模式串可能的所有前缀集合。
像这样,后面的 fail 如果跳到了 cur 上的 sonk,就会直接跳向 failcur 上的 sonk 了。
之所以不用考虑边界问题,是因为不存在的默认置 0 了。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
inlinevoidbuild() { hh = 0, tt = -1; for (int k = 0; k < 26; k ++) if (tr[0][k]) fail[tr[0][k]] = 0, que[++ tt] = tr[0][k]; while (hh <= tt) { int cur = que[hh ++]; for (int k = 0; k < 26; k ++) { if (tr[cur][k]) fail[tr[cur][k]] = tr[fail[cur]][k], que[++ tt] = tr[cur][k]; else tr[cur][k] = tr[fail[cur]][k]; } } }
多模式匹配
对于查询的文本串 T,我们只需要在处理好的 Trie 树上走一遍,对于每一个走到的状态,这个状态都是 T 的前缀,因此这个状态和包括所有的 fail 指针都匹配上了 T,按照要求处理即可。
intquery(char *str) { int cur = 0, ans = 0, len = strlen(str); for (int i = 0; i < len; i ++) { int k = str[i] - 'a'; cur = tr[cur][k]; for (int temp = cur; temp && ~cnt[temp]; temp = fail[temp]) ans += cnt[temp], cnt[temp] = -1; } return ans; }
voidquery(char *str) { int cur = 0, len = strlen(str); for (int i = 0; i < len; i ++) { cur = tr[cur][str[i] - 'a']; for (int temp = cur; temp; temp = fail[temp]) ans[rec[temp]].cnt ++; } }
实际上这样的复杂度是错误的,和 KMP 算法中的一样,不断地跳 next 时间复杂度为 O(∣S∣),因此此时的 AC 自动机时间复杂度达到了 O(∑∣S∣×∣T∣),因此我们需要优化。
structACAutomaton { int tr[N][26], fail[N], idx; int que[N], hh = 0, tt = -1;
inlinevoidinsert(char *str, int id) { int cur = 0, len = strlen(str); for (int i = 0; i < len; i ++) { int k = str[i] - 'a'; if (!tr[cur][k]) tr[cur][k] = ++ idx; cur = tr[cur][k]; } rec[id] = cur; }
inlinevoidbuild() { hh = 0, tt = -1; for (int k = 0; k < 26; k ++) if (tr[0][k]) fail[tr[0][k]] = 0, que[++ tt] = tr[0][k]; while (hh <= tt) { int cur = que[hh ++]; for (int k = 0; k < 26; k ++) { if (tr[cur][k]) fail[tr[cur][k]] = tr[fail[cur]][k], que[++ tt] = tr[cur][k]; else tr[cur][k] = tr[fail[cur]][k]; } } for (int ver = 1; ver <= idx; ver ++) add(fail[ver], ver); }
voidquery(char *str) { int cur = 0, len = strlen(str); for (int i = 0; i < len; i ++) { cur = tr[cur][str[i] - 'a']; sum[cur] ++; } } } AC;
voiddfs(int ver) { for (int i = h[ver], to = e[i]; ~i; i = ne[i], to = e[i]) { dfs(to); sum[ver] += sum[to]; } }