平衡二叉树的lua实现

特征

  • 有序性 所有结点都是左子树比自己小,右子树比自己大
  • 平衡性 所有结点左右子树高度差绝对值小于等于一。
  • 唯一性 结点key唯一,不能重复。

优势

  • 适合有序性数据存储 二叉树的天然有序性
  • 适合范围查找 基于有序的范围查找时间复杂度为o(log n) + m

缺点

  • 内存地址分散 每个结点都是独立数据块,存储地址分散,操作速度非常依赖存储设备的寻址速度,基于机械硬盘运行将会非常慢。

操作时间复杂度

  • 插入,查询,删除 时间复杂度都是o(log n)

实现

通过动手实现一个平衡二叉树,加深对平衡二叉树的了解,其实二年前用C++也实现过,所以这次想尝试用lua实现,发现用lua实现和用c++实现的区别是,在删除和增加结点的的时候,因为c++有指针,所以可以直接这样写 node = new(tree_node),lua
没有就只能parent.left = new(tree_node),所以递归函数传参需要把父节点的table传递进入,而c++只需要传递父节点成员的指针引用或者二级指针,实现的时候一定要理清逻辑,脑子里想不明白的时候就画图,特别注意别出现环引用。

一个平衡二叉树包含基本插入,删除,查询三个API,lua的实现我增加了一个range范围查询的API。

结构定义

标准的二叉树定义,左右子结点,结点高度,k,v值,基于k有序。

1
2
3
4
5
6
7
8
9
local function new_node(k,v)
return {
left = nil,
right = nil,
depth = 1,
k = k,
v = v,
}
end

查询

开胃菜查询,查询是平衡二叉树里面最简单的,根据k值决定是去左边找还是右边找,类似二分查找。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
local function find_node(node,k)
if node.k == k then
return node.v
end

if node.k > k then
if node.left then
return find_node(node.left,k)
end
else
if node.right then
return find_node(node.right,k)
end
end

return nil
end

比例下图👇查找4,路径就是root[5]->left[3]->right[4]

插入

普通的二叉树插入也非常简单,跟查询一样的寻路,找到空的树枝挂上去就行。平衡二叉树就需要操作平衡调整。
为什么要做平衡操作,试想一下,普通二叉树我们插入1到10,会一直插入到最右边的结点,就会退化成链表了。

此时的二叉树就跟链表差不多,查询,插入,删除基本上就是o(n)的时间复杂度了。
我们再看看平衡二叉树。

  • 失衡 所有结点左右子树高度差绝对值小于等于一。
  • 检测失衡 插入结点经过的路径结点都需要检测。按照压栈的顺序执行,先入后出
  • 更新树高 插入结点经过的路径结点都需要更新树高,检测失衡前要先更新树高,按照压栈的顺序执行,先入后出

插入结点3后,树失衡,此时对二叉树进行左旋调整,插入结点4,5后又再次失衡,再次进行左旋调整
这里我提到了左旋调整右旋调整,我们需要列出树的失衡状态,根据失衡的定义我们可以列出如下6种失衡4种调整

我们只需要实现左旋调整右旋调整先左后右先右后左就是调用左、右调整,只不过是先调整子节点,再调整父节点。

我们以左旋调整为例,看图。

调整过程我们需要改变3个指针指向。

  • 父指针 从指向自身改为指向右子节点(从指向1改为指向3)
  • 自身右指针 从指向自身右子节点改为指向右子节点的左子节点(从指向3改为指向2)
  • 右子结点左指针 从指向右子节点左子结点改为指向自身(从指向2改为指向1)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
--获取高度
local function get_depth(node)
if not node then
return 0
end
return node.depth
end
--更新高度
local function update_depth(node)
local l_depth = get_depth(node.left)
local r_depth = get_depth(node.right)
node.depth = l_depth > r_depth and l_depth + 1 or r_depth + 1
end
--获取平衡值
local function get_node_balance(node)
local l_depth = get_depth(node.left)
local r_depth = get_depth(node.right)
return l_depth - r_depth
end
--左旋
local function ll_rotate(parent,node)
local son = node.right
--父指针change
if parent.root then
parent.root = son
elseif parent.left == node then
parent.left = son
else
parent.right = son
end
local son_l = son.left
--右子结点左指针change
node.right.left = node
--自身右指针change
node.right = son_l
--更新结点高度
update_depth(node)
update_depth(son)
end
--右旋
local function rr_rotate(parent,node)
local son = node.left
--父指针change
if parent.root then
parent.root = son
elseif parent.left == node then
parent.left = son
else
parent.right = son
end
local son_r = son.right
--左子结点右指针change
node.left.right = node
--自身右指针change
node.left = son_r
--更新结点高度
update_depth(node)
update_depth(son)
end

--先左后右
local function lr_rotate(parent,node)
ll_rotate(node,node.left)
rr_rotate(parent,node)
end
--先右后左
local function rl_rotate(parent,node)
rr_rotate(node,node.right)
ll_rotate(parent,node)
end
--左边失衡
local function balance_left(parent,node)
local l_balance = get_node_balance(node.left)
if l_balance > 0 then
rr_rotate(parent,node)
else
lr_rotate(parent,node)
end
end
--右边失衡
local function balance_right(parent,node)
local r_balance = get_node_balance(node.right)
if r_balance > 0 then
rl_rotate(parent,node)
else
ll_rotate(parent,node)
end
end
--检测失衡并调整
local function avl_node(parent,node)
local balance = get_node_balance(node)
if balance > 1 then
balance_left(parent,node)
elseif balance < -1 then
balance_right(parent,node)
end
end

local function add_node(parent,node,k,v)
if node.k == k then
return
end

if node.k > k then
if node.left then
add_node(node,node.left,k,v)
else
node.left = new_node(k,v)
end
else
if node.right then
add_node(node,node.right,k,v)
else
node.right = new_node(k,v)
end
end

update_depth(node)
avl_node(parent,node)
end

完整的插入过程可以分为2个步骤,按照压栈的过程,先入后出。

  • 进的时候栈会记录路径
  • 把路过的结点进行更新树高检测失衡和调整失衡

删除

删除结点有4种情况需要考虑🤔

  1. 叶子结点 没有左右子结点的,直接删除就行。
  2. 仅有左子结点 让父节点继承。
  3. 仅有右子结点 让父节点继承。
  4. 左右子结点都有 中序遍历遍历的前继或者后继结点代替自己。

如何找前继或者后继结点
如图,我们看下中序遍历结点的位置情况。

  • 前继结点位置 左结点或者左节点的最右结点。
  • 后继结点位置 右结点或者右节点的最左结点。

删除结点的过程中我们不能忘了把路过的结点进行更新更新树高检测失衡和调整失衡
删除比较特殊的是删除左右子结点都有的结点时,会进行前继结点位置或者后继结点位置位置的深入,这个路径的结点都需要更新树高

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
local function del_node(parent,node,k,v)
local function del(p,n,next)
if p.root then
if next then
p.root = next
else
p.root = nil
end
elseif p.left == n then
if next then
p.left = next
else
p.left = nil
end
else
if next then
p.right = next
else
p.right = nil
end
end
end
local res = false
if node.k == k then
if not node.left and not node.right then
del(parent,node)
elseif not node.left then
del(parent,node,node.right)
node.right = nil
elseif not node.right then
del(parent,node,node.left)
node.left = nil
else
--找node在中序遍历的前继节点或者后继节点
--我这里找前继,前继结点是左节点或者左节点的最右节点
local pp_node = nil
local pre_node = node.left
while pre_node.right do
pp_node = pre_node
pre_node = pre_node.right
end
del(parent,node,pre_node)
pre_node.right = node.right
node.right = nil
if pp_node then
pp_node.right = pre_node.left
pre_node.left = node.left
end

node.left = nil

local uplist = {pre_node}
local unode = pre_node.left
while unode do
table.insert(uplist,unode)
unode = unode.right
end
for i = #uplist,1,-1 do
update_depth(uplist[i])
end
end
return true
end

if node.k > k then
if node.left then
res = del_node(node,node.left,k,v)
end
else
if node.right then
res = del_node(node,node.right,k,v)
end
end

if res then
update_depth(node)
avl_node(parent,node)
end

return res
end

范围查询

范围查询需要注意的是要用中序遍历,保证结果的有序性就行了。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
local function find_by_range(node,b_key,e_key,res_list)
if not node then return end

if node.left and b_key < node.k then
find_by_range(node.left,b_key,e_key,res_list)
end

if node.k >= b_key and node.k <= e_key then
table.insert(res_list,node.k)
table.insert(res_list,node.v)
end

if node.right and e_key > node.k then
find_by_range(node.right,b_key,e_key,res_list)
end
end

完整代码

觉得写得不错,给个星星,非常感谢O(∩_∩)O哈哈~


平衡二叉树的lua实现
https://huahua132.github.io/2023/05/21/data_struct/avltree/
作者
huahua132
发布于
2023年5月21日
许可协议