字典树的应用-存储二进制信息

leetcode 周赛中遇到的一道需要使用 字典树的题目,总结分析一下。

题目:查询最大基因差, 给你一棵 n 个节点的有根树,节点编号从 0 到 n - 1。每个节点的编号表示这个节点的 独一无二的基因值 (也就是说节点 x 的基因值为 x)。两个基因值的 基因差 是两者的 异或值 。给你整数数组 parents ,其中 parents[i] 是节点 i 的父节点。如果节点 x 是树的 根 ,那么 parents[x] == -1 。

给你查询数组 queries ,其中 queries[i] = [nodei, vali] 。对于查询 i ,请你找到 vali 和 pi 的 最大基因差 ,其中 pi 是节点 nodei 到根之间的任意节点(包含 nodei 和根节点)。更正式的,你想要最大化 vali XOR pi

请你返回数组 ans ,其中 ans[i] 是第 i 个查询的答案。

分析

题目中,将结点之间的树形关系通过 parent 数组进行存储,对于该结构,我们可以从一个结点出发,非常方便的寻找到树的根部。因此,最简单的想法是,直接从 nodei 出发,遍历从 nodei 到根结点之间的所有结点,选取其中 与 vali 异或值最大的值作为查询的答案即可。该暴力解法代码如下。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
class Solution {
public int[] maxGeneticDifference(int[] parents, int[][] queries) {
int[] res = new int[queries.length];
for(int i = 0; i < queries.length; ++i){
int[] query = queries[i];
int node = query[0];
int val = query[1];
while(node != -1){
res[i] = Math.max(res[i], val ^ node);
node = parents[node];
}
}
return res;
}
}

不出所料,暴力方法直接超过时间限制。那么如何降低算法的时间复杂度呢?

一般地,对于这种比较型的题目,降低时间复杂度的方法是 在遍历比较的过程中,保存前面比较的一些信息,用于后续 使用(例如 kmp 算法)。对于这里的树形结构,结点与它的父结点之间是有重复信息的,如下图所示, 结点 4 需要遍历 蓝色部分包围的 4 个结点,而其父节点需要遍历的结点与结点4 有3个结点为重复的。

因此,一种想法是 使用 dfs 进行遍历,将遍历前序结点时的“信息“保存下来,后续结点只需要添加一些增加信息就可以进行判断。这里又有一个难点,就是该如何保存遍历到的结点的信息?字典树结构可以在这里派上用场。

具体地,我们使用字典树来保存结点的二进制信息,将一个数字保存到 字典树的过程如下:字典树以 root 为起点,待添加的数字从其二进制最高位开始依次向低位进行遍历,如果该位上数字为1,则 root 的左节点加1,并将左节点设置为新的 root;如果该位上数字为0,则root 的右结点加1,并将右结点设置为新的root。如下图所示。

FigSource

将结点信息存放到树中之后,求解与特定 val 值异或结果的最大值就非常简单了,具体的,我们设异或结果的最大值为 res,遵循下面的规则确定各二进制位上的数字即可:

  • 如果 val 当前二进制位上的值为 1:
    • 若字典树中当前节点有 右子节点(有0),则 res 该位上为 1,root 结点变更为 右子结点;
    • 若字典树中当前节点无 右子节点,res 该位上为 0,root 结点变更为 左子结点;
  • 如果 val 当前二进制位上的值为 0:
    • 若字典树中当前节点有 左子节点(有1),则 res 该位上为 1,root 结点变更为 左子节点;
    • 若字典树中当前节点无 左子节点,res 该位上为 0,root 结点变更为 右子节点;

代码分析

整个的代码主要分为三个部分:

首先是将题目中给定的信息重新组织,将parent 数组表示的树结构转换为根结点引出子结点的结构,可以使用 map 存储,key 值为根结点值,value 为一个 list,存放根结点的叶子结点。但由于这里的 node 是从0到 n-1 的确定的值,我们也可以使用 数组来存放 树信息,数组下标为 node 的值,数组中的内容为该 node 的子结点。此外,对于查询,我们将针对相同 node 的查询汇总起来,同样放置到一个 map 中,key 值为 node 值,value 为一个list,放置针对该node 的查询。具体代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
class Solution {
int[] res;
public int[] maxGeneticDifference(int[] parents, int[][] queries) {
res = new int[queries.length];
int root = -1;
for(int i = 0; i < parents.length; ++i){
if(parents[i] == -1){
root = i;
break;
}
}

// 放置树的信息。变为从根结点到子结点的结构
List<Integer>[] treeInfo = new ArrayList[parents.length];
// 错误写法!!Arrays.fill(treeInfo, new ArrayList<>());
// 这样 ArrayList<>() 是同一个对象引用。
for(int i = 0; i < treeInfo.length; ++i){
treeInfo[i] = new ArrayList<>();
}
for(int i = 0; i < parents.length; ++i){
int parent = parents[i];
if(parent == -1) continue;
treeInfo[parent].add(i);
}

// 将针对相同 node 的 queries 放置到一起
Map<Integer, List<Integer>> queryMap = new HashMap<>();
for(int i = 0; i < queries.length; ++i){
int[] query = queries[i];
List<Integer> list = queryMap.getOrDefault(query[0], new ArrayList<>());
list.add(i);
queryMap.put(query[0], list);
}

TrieTree tt = new TrieTree();
// dfs 寻找
dfs(treeInfo, queryMap, queries, tt, root);
return res;
}
void dfs(List<Integer>[] treeInfo, Map<Integer, List<Integer>> queryMap, int[][] queries,TrieTree tt, int root){
// TODO
}
}

第二部分,是进行树的 dfs 遍历,具体就是以经典的递归方式实现使用程序栈进行遍历。需要注意的是,在遍历的过程中,有一个回溯的过程,因为当前结点仅和他的子结点们公用一些信息,而与其他的结点并不公用信息。因此,完成一个结点的所有子结点的遍历后,需要将该结点的信息从 trie tree 中删除。具体代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
void dfs(List<Integer>[] treeInfo, Map<Integer, List<Integer>> queryMap, int[][] queries,TrieTree tt, int root){
// step1: 将 root 对应的值放入 trie tree
tt.insert(root);

// step2: 遍历查询。计算各个查询的结果
if(queryMap.containsKey(root)){
for(Integer queryIndex : queryMap.get(root)){
int[] query = queries[queryIndex];
res[queryIndex] = tt.search(query[1]);
}
}


// step3: dfs, 遍历 root 的子节点。
if(treeInfo[root] != null){
for(Integer node : treeInfo[root]){
dfs(treeInfo, queryMap, queries, tt, node);
}
}

// step4: 将 root 的值从 trie tree 中删除
tt.delete(root);
}

最后,就是实现一个 trie tree 来从二进制位的角度,将结点值进行存储。我们使用一个单独的类来封装这个结构,并且提供 insert 结点值,delete 结点值,以及查询目标值与当前存储的元素的最大异或值的函数 search。具体代码实现如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
class TrieTree {
TrieTree left = null;
TrieTree right = null;
TrieTree cur = this;
// 记录当前 trie tree 结点出现的次数
int count = 0;

public void insert(int val){
cur = this;
for(int i = 31; i >= 0; --i){
if(((val >> i) & 0x1) == 0x1){
if(cur.left == null) cur.left = new TrieTree();
cur = cur.left;
cur.count++;
}else{
if(cur.right == null) cur.right = new TrieTree();
cur = cur.right;
cur.count++;
}
}
}

public void delete(int val){
cur = this;
for(int i = 31; i >= 0; --i){
if(((val >> i) & 1) == 1){
cur = cur.left;
cur.count--;
}else{
cur = cur.right;
cur.count--;
}
}
}

public int search(int val){
cur = this;
int res = 0;
for(int i = 31; i >= 0; --i){
if(((val >> i) & 0x1) == 0x1){
if(cur.right != null && cur.right.count > 0){
res += (1 << i);
cur = cur.right;
}else{
cur = cur.left;
}
}else{
if(cur.left != null && cur.left.count > 0){
res += (1 << i);
cur = cur.left;
}else{
cur = cur.right;
}
}
}
return res;
}
}