--- Height of a subtree.
-- @param node a tree node, or nil for an empty subtree
-- @return 0 for an empty subtree, otherwise the node's cached `depth` field
local function get_depth(node)
  if node then
    return node.depth
  end
  return 0
end

--- Recompute `node.depth` as one more than its taller child's height.
-- Children must already hold correct `depth` values.
-- @param node a non-nil tree node
local function update_depth(node)
  node.depth = math.max(get_depth(node.left), get_depth(node.right)) + 1
end

--- Balance factor of a node: left height minus right height.
-- Positive means left-heavy, negative means right-heavy.
-- @param node a non-nil tree node
-- @return number
local function get_node_balance(node)
  return get_depth(node.left) - get_depth(node.right)
end

--- Point whichever slot of `parent` held `old_child` at `new_child`.
-- `parent` is treated as the tree container when it has a truthy `root`
-- field; otherwise it is a node and its `left` or `right` link is rewritten.
local function relink_child(parent, old_child, new_child)
  if parent.root then
    parent.root = new_child
  elseif parent.left == old_child then
    parent.left = new_child
  else
    parent.right = new_child
  end
end

--- Left rotation: `node.right` is lifted into `node`'s position.
-- @param parent the tree table (with `root`) or the parent node of `node`
-- @param node the subtree root to rotate; `node.right` must be non-nil
local function ll_rotate(parent, node)
  local pivot = node.right
  relink_child(parent, node, pivot)
  -- The pivot's inner (left) subtree moves across to become node's right.
  node.right = pivot.left
  pivot.left = node
  -- node now sits below pivot, so its height must be refreshed first.
  update_depth(node)
  update_depth(pivot)
end

--- Right rotation: `node.left` is lifted into `node`'s position.
-- @param parent the tree table (with `root`) or the parent node of `node`
-- @param node the subtree root to rotate; `node.left` must be non-nil
local function rr_rotate(parent, node)
  local pivot = node.left
  relink_child(parent, node, pivot)
  -- The pivot's inner (right) subtree moves across to become node's left.
  node.left = pivot.right
  pivot.right = node
  -- node now sits below pivot, so its height must be refreshed first.
  update_depth(node)
  update_depth(pivot)
end
--- Left-right double rotation: rotate `node.left` left, then rotate `node` right.
-- `node` is passed as the parent of the inner rotation. ll_rotate treats its
-- parent as the tree container only when it has a truthy `root` field.
-- NOTE(review): presumably tree nodes never carry a `root` field — confirm
-- against new_node.
-- @param parent the tree table (with `root`) or the parent node of `node`
-- @param node the unbalanced subtree root
local function lr_rotate(parent, node)
  ll_rotate(node, node.left)
  rr_rotate(parent, node)
end

--- Right-left double rotation: rotate `node.right` right, then rotate `node` left.
-- The same `root`-field caveat as lr_rotate applies to the inner rotation.
-- @param parent the tree table (with `root`) or the parent node of `node`
-- @param node the unbalanced subtree root
local function rl_rotate(parent, node)
  rr_rotate(node, node.right)
  ll_rotate(parent, node)
end

--- Fix a node whose left subtree is too tall (balance > 1).
-- If the left child is left-heavy OR evenly balanced, a single right rotation
-- restores balance. Only a right-heavy left child needs the double rotation.
-- The even case (balance 0) can occur after a deletion. Sending it to the
-- double rotation, as `> 0` used to, can leave the subtree unbalanced.
-- @param parent the tree table (with `root`) or the parent node of `node`
-- @param node the unbalanced subtree root
local function balance_left(parent, node)
  if get_node_balance(node.left) >= 0 then
    rr_rotate(parent, node)
  else
    lr_rotate(parent, node)
  end
end

--- Fix a node whose right subtree is too tall (balance < -1).
-- If the right child is left-heavy, a double rotation is needed. Otherwise
-- (right-heavy or evenly balanced) a single left rotation suffices.
-- @param parent the tree table (with `root`) or the parent node of `node`
-- @param node the unbalanced subtree root
local function balance_right(parent, node)
  if get_node_balance(node.right) > 0 then
    rl_rotate(parent, node)
  else
    ll_rotate(parent, node)
  end
end

--- Check `node`'s balance factor and rotate if it lies outside [-1, 1].
-- @param parent the tree table (with `root`) or the parent node of `node`
-- @param node the subtree root to check; its `depth` must be current
local function avl_node(parent, node)
  local balance = get_node_balance(node)
  if balance > 1 then
    balance_left(parent, node)
  elseif balance < -1 then
    balance_right(parent, node)
  end
end
--- Insert key `k` with value `v` into the subtree rooted at `node`.
-- On the way back up it refreshes cached heights and restores the AVL balance.
-- If `k` is already present, nothing is inserted and the existing node
-- (including its value) is left untouched.
-- NOTE(review): assumes new_node(k, v), defined elsewhere, initialises
-- `depth` to 1. Otherwise get_depth returns nil for the new leaf and
-- update_depth fails.
-- @param parent the tree table (with `root`) when `node` is the root,
--   otherwise `node`'s parent node
-- @param node a non-nil subtree root
-- @param k key, compared with `==` and `>`
-- @param v value stored in the newly created node
local function add_node(parent, node, k, v)
  if node.k == k then
    -- Duplicate key: the tree shape is unchanged, so no rebalancing is needed.
    return
  end

  if node.k > k then
    if node.left then
      add_node(node, node.left, k, v)
    else
      node.left = new_node(k, v)
    end
  else
    if node.right then
      add_node(node, node.right, k, v)
    else
      node.right = new_node(k, v)
    end
  end

  -- This subtree may have grown by one level. Update its cached height and
  -- rotate here if the insertion pushed the balance outside [-1, 1].
  -- Without this step the tree never rebalances on insert.
  update_depth(node)
  avl_node(parent, node)
end
localfunctiondel_node(parent,node,k,v) localfunctiondel(p,n,next) if p.root then ifnextthen p.root = next else p.root = nil end elseif p.left == n then ifnextthen p.left = next else p.left = nil end else ifnextthen p.right = next else p.right = nil end end end local res = false if node.k == k then ifnot node.left andnot node.right then del(parent,node) elseifnot node.left then del(parent,node,node.right) node.right = nil elseifnot node.right then del(parent,node,node.left) node.left = nil else --找node在中序遍历的前继节点或者后继节点 --我这里找前继,前继结点是左节点或者左节点的最右节点 local pp_node = nil local pre_node = node.left while pre_node.right do pp_node = pre_node pre_node = pre_node.right end del(parent,node,pre_node) pre_node.right = node.right node.right = nil if pp_node then pp_node.right = pre_node.left pre_node.left = node.left end
node.left = nil
local uplist = {pre_node} local unode = pre_node.left while unode do table.insert(uplist,unode) unode = unode.right end for i = #uplist,1,-1do update_depth(uplist[i]) end end returntrue end
if node.k > k then if node.left then res = del_node(node,node.left,k,v) end else if node.right then res = del_node(node,node.right,k,v) end end
if res then update_depth(node) avl_node(parent,node) end