参考《算法竞赛进阶指南》、AcWing题库

Trie 字典树

Trie (字典树) 是一种用于实现字符串快速检索的多叉树结构。Trie 的每个节点都拥有若干个字符指针, 若在插入或检索字符串时扫描到一个字符 cc, 就沿着当前节点的 cc 字符指针, 走向该指针指向的节点。下面我们来详细讨论 Trie 的基本操作过程。

初始化

一棵空 Trie 仅包含一个根节点, 该点的字符指针均指向空。

插入

当需要插入一个字符串 SS 时, 我们令一个指针 PP 起初指向根节点。然后, 依次扫描 SS 中的每个字符 cc :

  1. PPcc 字符指针指向一个已经存在的节点 QQ, 则令 P=QP=Q
  2. PPcc 字符指针指向空, 则新建一个节点 QQ, 令 PPcc 字符指针指向 QQ, 然后令 P=QP=Q

SS 中的字符扫描完毕时, 在当前节点 PP 上标记它是一个字符串的末尾。

检索

当需要检索一个字符串 SS 在 Trie 中是否存在时, 我们令一个指针 PP 起初指向根节
点, 然后依次扫描 SS 中的每个字符 cc :

  1. PPcc 字符指针指向空, 则说明 SS 没有被插入过 Trie, 结束检索。
  2. PPcc 字符指针指向一个已经存在的节点 QQ, 则令 P=QP=Q

SS 中的字符扫描完毕时, 若当前节点 PP 被标记为一个字符串的末尾, 则说明 SS
在 Trie 中存在, 否则说明 SS 没有被插入过 Trie。

在上图所示的例子中, 需要插入和检索的字符串都由小写字母构成, 所以 Trie 的每个节点具有 26 个字符指针, 分别为 a\mathrm{a}z\mathrm{z} 。上图展示了在一棵空 Trie 中依次插入 “cab” “cos” “car” “cat” “cate” 和 “rain” 后的 Trie 的形态, 灰色标记了单词的末尾节点。可以看出在 Trie 中, 字符数据都体现在树的边 (指针) 上, 树的节点仅保存一些额外信息, 例如单词结尾标记等。其空间复杂度是 O(NC)O(N C), 其中 NN 是节点个数, CC 是字符集的大小。

cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
// 假设字符串由小写字母构成
int trie[SIZE][26], tot = 1; // 根节点1, 终止点0

// Trie的插入
void insert(char* str) {
int len = strlen(str), p = 1;
for (int k = 0; k < len; k++) {
int ch = str[k]-'a';
if (trie[p][ch] == 0) trie[p][ch] = ++tot;
p = trie[p][ch];
}
end[p] = true;
}

// Trie的检索
bool search(char* str) {
int len = strlen(str), p = 1;
for (int k = 0; k < len; k++) {
p = trie[p][str[k]-'a'];
if (p == 0) return false;
}
return end[p];
}
cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
int son[MAX_StrLen][26], cnt[MAX_StrLen], idx; // 0 即表示根节点,也是终止点

void insert(char *str) { // 插入字符串
int p = 0;
for (int i = 0; str[i]; ++i) {
int u = str[i] - 'a';
if (!son[p][u]) son[p][u] = ++idx;
p = son[p][u];
}
++cnt[p];
}

int query(char *str) { // 查询字符串出现次数
int p = 0;
for (int i = 0; str[i]; ++i ) {
int u = str[i] - 'a';
if (!son[p][u]) return 0;
p = son[p][u];
}
return cnt[p];
}
cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
class Trie {
public:
struct Node {
bool isEnd = false;
int cnt = 0;
Node *son[26] = {nullptr};
} *root;

Trie() {
root = new Node();
}

void insert(string word) {
auto p = root;
for (auto c : word) {
int u = c - 'a';
if (!p->son[u]) p->son[u] = new Node();
p = p->son[u];
}
p->isEnd = true;
p->cnt++;
}

bool search(string word) {
auto p = root;
for (auto c : word) {
int u = c - 'a';
if (!p->son[u]) return false;
p = p->son[u];
}
return p->isEnd;
}

bool startsWith(string prefix) {
auto p = root;
for (auto c : prefix) {
int u = c - 'a';
if (!p->son[u]) return false;
p = p->son[u];
}
return true;
}
};

例题

142. 前缀统计

给定 NN 个字符串 S1,S2SNS_1,S_2…S_N,接下来进行 MM 次询问,每次询问给定一个字符串 TT,求 S1SNS_1 \sim S_N 中有多少个字符串是 TT 的前缀。

输入字符串的总长度不超过 10610^6,仅包含小写字母。

输入格式

第一行输入两个整数 NMN,M

接下来 NN 行每行输入一个字符串 SiS_i

接下来 MM 行每行一个字符串 TT 用以询问。

输出格式

对于每个询问,输出一个整数表示答案。

每个答案占一行。

输入样例:

plaintext
1
2
3
4
5
6
3 2
ab
bc
abc
abc
efg

输出样例:

plaintext
1
2
2
0

算法分析

把这 NN 个字符串插入一棵 Trie 树, Trie 树的每个节点上存储一个整数 cntc n t, 记录该节点是多少个字符串的末尾节点。(为了处理插入重复字符串的情况, 这里要记录个数, 而不能只做结尾标记)

对于每个询问, 在 Trie 树中检索 TT, 在检索过程中累加途径的每个节点的 cntc n t 值, 就是该询问的答案。

cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
const int SIZE=1000010;
int trie[SIZE][26], tot = 1; // 初始化,假设字符串由小写字母构成
int ed[SIZE];
int n, m;
char str[SIZE];

void insert(char* str) { // 插入一个字符串
int len = strlen(str), p = 1;
for (int k = 0; k < len; k++) {
int ch = str[k]-'a';
if (trie[p][ch] == 0) trie[p][ch] = ++tot;
p = trie[p][ch];
}
ed[p]++;
}

int search(char* str) {
int len = strlen(str), p = 1;
int ans = 0;
for (int k = 0; k < len; k++) {
p = trie[p][str[k]-'a'];
if (p == 0) return ans;
ans += ed[p];
}
return ans;
}

int main() {
cin>>n>>m;
for(int i=1;i<=n;i++) {
scanf("%s",str);
insert(str);
}
for(int i=1;i<=m;i++) {
scanf("%s",str);
printf("%d\n", search(str));
}
}
cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
//Author:XuHt
#include <cstdio>
#include <cstring>
#include <iostream>
using namespace std;
const int N = 1000006;
int trie[N][27], tot = 1;
char s[N];

void add() {
int len = strlen(s), p = 1;
for (int i = 0; i < len; i++) {
int k = s[i] - 'a' + 1;
if (!trie[p][k]) trie[p][k] = ++tot;
p = trie[p][k];
}
++trie[p][0];
}

void get() {
int ans = 0, len = strlen(s), p = 1;
for (int i = 0; i <= len; i++) {
ans += trie[p][0];
p = trie[p][s[i]-'a'+1];
}
cout << ans << endl;
}

int main() {
memset(trie, 0, sizeof(trie));
int n, m;
cin >> n >> m;
while (n--) {
scanf("%s", s);
add();
}
while (m--) {
scanf("%s", s);
get();
}
return 0;
}

Solution


143. 最大异或对

在给定的 NN 个整数 A1A2ANA_1,A_2……A_N 中选出两个进行 xorxor(异或)运算,得到的结果最大是多少?

输入格式

第一行输入一个整数 NN

第二行输入 NN 个整数 A1A_1ANA_N

输出格式

输出一个整数表示答案。

数据范围

1N1051 \le N \le 10^5,
0Ai<2310 \le A_i < 2^{31}

输入样例:

plaintext
1
2
3
1 2 3

输出样例:

plaintext
1
3 

算法分析

我们考虑所有的二元组 (i,j)(i, j)i<ji<j, 那么本题的目标就是在其中找到 Al  xor  AjA_{l}\; \mathrm{xor} \;A_{j} 的最大值。也就是说, 对于每个 i(1iN)i(1 \leq i \leq N), 我们希望找到一个 j(1j<i)j(1 \leq j<i), 使 Ai  xor  AjA_{i} \; \mathrm{xor} \; A_{j} 最大, 并求出这个最大值。

我们可以把每个整数看作长度为 32 的二进制 01 串 (数值较小时在前边补 0 ), 并且把 A1Ai1A_{1} \sim A_{i-1} 对应的 32 位二进制串插入一棵 Trie 树 (其中最低二进制位为叶子节点)。接下来, 对于 AiA_{i} 对应的 32 位二进制串, 我们在 Trie 中进行一次与检索类似的过程, 每一步都尝试沿着 “与 AiA_{i} 的当前位相反的字符指针” 向下访问。若 “与 AiA_{i} 的当前位相反的字符指针” 指向空节点, 则只好访问与 AiA_{i} 当前位相同的字符指针。根据 xor 运算 “相同得 0 , 不同得 1 ” 的性质, 该方法即可找出与 AlA_{l}xor\mathrm{xor} 运算结果最大的 AjA_{j}

如下图所示, 在一棵插入了 2(010),5(101),7(111)2(010), 5(101), 7(111) 三个数的 Trie 中, 分别查询与 6(110),3(011)6(110), 3(011) 做 xor 运算结果最大的数。(为了简便, 图中使用了 3 位二进制数代替 32 位二进制数)

综上所述, 我们可以循环 i=1Ni=1 \sim N, 对于每个 ii, Trie 树中应该存储了 A1Ai1A_{1} \sim A_{i-1} 对应的 32 位二进制串 (实际上每次 ii 增长前, 把 AiA_{i} 插入 Trie 即可)。根据我们刚才提到的 “尽量走相反的字符指针” 的检索策略, 就可以找到所求的 AjA_{j} , 更新答案。

cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
#include <algorithm>
#include <cstdio>
#include <cstring>
#include <iostream>
using namespace std;
const int SIZE = 100010;
int trie[SIZE * 32 + 5][2], tot = 1; // 初始化,假设字符串由小写字母构成
int a[SIZE], n, ans;

void insert(int val) { // 插入一个二进制数
int p = 1;
for (int k = 30; k >= 0; k--) {
int ch = val >> k & 1;
if (trie[p][ch] == 0) trie[p][ch] = ++tot;
p = trie[p][ch];
}
}

int search(int val) {
int p = 1;
int ans = 0;
for (int k = 30; k >= 0; k--) {
int ch = val >> k & 1;
if (trie[p][ch ^ 1]) { // 走相反的位
p = trie[p][ch ^ 1];
ans |= 1 << k;
} else { // 只能走相同的位
p = trie[p][ch];
}
}
return ans;
}

int main() {
cin >> n;
for (int i = 1; i <= n; i++) {
scanf("%d", &a[i]);
insert(a[i]);
ans = max(ans, search(a[i]));
}
cout << ans << endl;
}
cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
// Author:XuHt
#include <algorithm>
#include <cstdio>
#include <iostream>
using namespace std;
const int N = 100006 * 33;
int trie[N][2];

int main() {
int n;
cin >> n;
int ans = 0, t = 1;
for (int i = 1; i <= n; i++) {
int a;
scanf("%d", &a);
int p = 1;
for (int j = 31; j >= 0; j--) {
int k = (a >> j) & 1;
if (!trie[p][k]) trie[p][k] = ++t;
p = trie[p][k];
}
p = 1;
if (i > 1) {
int w = 0;
for (int j = 31; j >= 0; j--) {
int k = (a >> j) & 1;
if (trie[p][k ^ 1]) {
w = (w << 1) + (k ^ 1);
p = trie[p][k ^ 1];
} else {
w = (w << 1) + k;
p = trie[p][k];
}
}
ans = max(ans, w ^ a);
}
}
cout << ans << endl;
return 0;
}

Solution


144. 最长异或值路径

给定一个树,树上的边都具有权值。

树中一条路径的异或长度被定义为路径上所有边的权值的异或和:

xorlength(p)=epw(e)_{xor}length(p)=\oplus_{e \in p} w(e)

为异或符号。

给定上述的具有 nn 个节点的树,你能找到异或长度最大的路径吗?

输入格式

第一行包含整数 nn,表示树的节点数目。

接下来 n1n-1 行,每行包括三个整数 uvwu,v,w,表示节点 uu 和节点 vv 之间有一条边权重为 ww

输出格式

输出一个整数,表示异或长度最大的路径的最大异或和。

数据范围

1n1000001 \le n \le 100000,
0u,v<n0 \le u,v < n,
0w<2310 \le w <2^{31}

输入样例:

plaintext
1
2
3
4
4
0 1 3
1 2 4
1 3 6

输出样例:

plaintext
1
7 

样例解释

样例中最长异或值路径应为 0->1->2,值为 7(=34)7 (=3 ⊕ 4)

算法分析

D[x]D[x] 表示根节点到 xx 的路径上所有边权的 xor 值, 显然有:

D[x]=D[father(x)] xor weight(x, father(x))D[x]=D[\text {father}(x)] \text { xor weight}(x \text {, father}(x))

根据上式, 我们可以对树进行一次深度优先遍历, 求出所有的 D[x]D[x] 。不难发现, 树上 xxyy 的路径上所有边权的 xor 结果就等于 D[x] xor D[y]D[x] \text{ xor } D[y] 。这是因为根据 xor 运算的性质 (a xor a=0)(a \text{ xor } a=0) ,“ xx 到根 ” 和 “ yy 到根 ” 这两条路径重叠的部分恰好抵消掉。

所以, 问题就变成了从 D[1]D[N]D[1] \sim D[N]NN 个数中选出两个, xor\text{xor} 的结果最大, 即上一道例题。可以用 Trie 树来快速求解。

cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
#include <cstdio>
#include <cstring>
#include <iostream>
using namespace std;
const int u = 100010;
int ver[2 * u], edge[2 * u], next[2 * u], head[u], v[u], val[u * 32],
a[u * 32][2], f[u];
int n, tot, i, ans, x, y, z;

void add(int x, int y, int z) {
ver[++tot] = y;
edge[tot] = z;
next[tot] = head[x];
head[x] = tot;
}

void dfs(int x) {
v[x] = 1;
for (int i = head[x]; i; i = next[i])
if (!v[ver[i]]) {
f[ver[i]] = f[x] ^ edge[i];
dfs(ver[i]);
}
}

void ins(int x, int y, int temp) {
if (y < 0) {
val[x] = temp;
return;
}
int z = (temp >> y) & 1;
if (!a[x][z]) a[x][z] = ++tot;
ins(a[x][z], y - 1, temp);
}

int get(int x, int y, int temp) {
if (y < 0) return val[x];
int z = (temp >> y) & 1;
if (a[x][z ^ 1])
return get(a[x][z ^ 1], y - 1, temp);
else
return get(a[x][z], y - 1, temp);
}

int main() {
while (cin >> n) {
memset(head, 0, sizeof(head));
memset(f, 0, sizeof(f));
memset(v, 0, sizeof(v));
tot = 0;
for (i = 1; i < n; i++) {
scanf("%d%d%d", &x, &y, &z);
x++, y++;
add(x, y, z);
add(y, x, z);
}
dfs(1);
tot = 1;
ans = 0;
memset(a, 0, sizeof(a));
ins(1, 30, 0);
for (i = 1; i <= n; i++) {
ans = max(ans, f[i] ^ get(1, 30, f[i]));
ins(1, 30, f[i]);
}
cout << ans << endl;
}
return 0;
}
cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
// Author:XuHt
#include <algorithm>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <vector>
using namespace std;
const int N = 100006;
int n, d[N], trie[N * 33][2], tot;
vector<pair<int, int> > e[N];
int Head[N], Edge[N * 2], Leng[N * 2], Next[N * 2], num;
bool v[N];

void dfs(int x) {
for (int i = Head[x]; i; i = Next[i]) {
int y = Edge[i], z = Leng[i];
if (v[y]) continue;
v[y] = 1;
d[y] = d[x] ^ z;
dfs(y);
}
}

void add(int x, int y, int z) {
Edge[++tot] = y;
Leng[tot] = z;
Next[tot] = Head[x];
Head[x] = tot;
}

void The_xor_longest_Path() {
memset(d, 0, sizeof(d));
memset(trie, 0, sizeof(trie));
memset(v, 0, sizeof(v));
memset(Head, 0, sizeof(Head));
num = 0;
v[0] = 1;
tot = 1;
for (int i = 1; i < n; i++) {
int u, v, w;
scanf("%d %d %d", &u, &v, &w);
add(u, v, w);
add(v, u, w);
}
dfs(0);
int ans = 0;
for (int i = 0; i < n; i++) {
int p = 1;
for (int j = 31; j >= 0; j--) {
int k = (d[i] >> j) & 1;
if (!trie[p][k]) trie[p][k] = ++tot;
p = trie[p][k];
}
p = 1;
if (i) {
int w = 0;
for (int j = 31; j >= 0; j--) {
int k = (d[i] >> j) & 1;
if (trie[p][k ^ 1]) {
w = (w << 1) + (k ^ 1);
p = trie[p][k ^ 1];
} else {
w = (w << 1) + k;
p = trie[p][k];
}
}
ans = max(ans, w ^ d[i]);
}
}
cout << ans << endl;
}

int main() {
while (cin >> n) The_xor_longest_Path();
return 0;
}

Solution