""" File: avl_tree.py Created Time: 2022-12-20 Author: a16su (lpluls001@gmail.com) """ import sys from pathlib import Path sys.path.append(str(Path(__file__).parent.parent)) from modules import TreeNode, print_tree class AVLTree: """AVL tree""" def __init__(self): """Constructor""" self._root = None def get_root(self) -> TreeNode | None: """Get binary tree root node""" return self._root def height(self, node: TreeNode | None) -> int: """Get node height""" # Empty node height is -1, leaf node height is 0 if node is not None: return node.height return -1 def update_height(self, node: TreeNode | None): """Update node height""" # Node height equals the height of the tallest subtree + 1 node.height = max([self.height(node.left), self.height(node.right)]) + 1 def balance_factor(self, node: TreeNode | None) -> int: """Get balance factor""" # Empty node balance factor is 0 if node is None: return 0 # Node balance factor = left subtree height - right subtree height return self.height(node.left) - self.height(node.right) def right_rotate(self, node: TreeNode | None) -> TreeNode | None: """Right rotation operation""" child = node.left grand_child = child.right # Rotate node to the right around child child.right = node node.left = grand_child # Update node height self.update_height(node) self.update_height(child) # Return the root of the subtree after rotation return child def left_rotate(self, node: TreeNode | None) -> TreeNode | None: """Left rotation operation""" child = node.right grand_child = child.left # Rotate node to the left around child child.left = node node.right = grand_child # Update node height self.update_height(node) self.update_height(child) # Return the root of the subtree after rotation return child def rotate(self, node: TreeNode | None) -> TreeNode | None: """Perform rotation operation to restore balance to the subtree""" # Get the balance factor of node balance_factor = self.balance_factor(node) # Left-leaning tree if balance_factor > 1: if self.balance_factor(node.left) >= 0: # Right rotation return self.right_rotate(node) else: # First left rotation then right rotation node.left = self.left_rotate(node.left) return self.right_rotate(node) # Right-leaning tree elif balance_factor < -1: if self.balance_factor(node.right) <= 0: # Left rotation return self.left_rotate(node) else: # First right rotation then left rotation node.right = self.right_rotate(node.right) return self.left_rotate(node) # Balanced tree, no rotation needed, return return node def insert(self, val): """Insert node""" self._root = self.insert_helper(self._root, val) def insert_helper(self, node: TreeNode | None, val: int) -> TreeNode: """Recursively insert node (helper method)""" if node is None: return TreeNode(val) # 1. Find insertion position and insert node if val < node.val: node.left = self.insert_helper(node.left, val) elif val > node.val: node.right = self.insert_helper(node.right, val) else: # Do not insert duplicate nodes, return return node # Update node height self.update_height(node) # 2. Perform rotation operation to restore balance to the subtree return self.rotate(node) def remove(self, val: int): """Remove node""" self._root = self.remove_helper(self._root, val) def remove_helper(self, node: TreeNode | None, val: int) -> TreeNode | None: """Recursively remove node (helper method)""" if node is None: return None # 1. Find and remove the node if val < node.val: node.left = self.remove_helper(node.left, val) elif val > node.val: node.right = self.remove_helper(node.right, val) else: if node.left is None or node.right is None: child = node.left or node.right # Number of child nodes = 0, remove node and return if child is None: return None # Number of child nodes = 1, remove node else: node = child else: # Number of child nodes = 2, remove the next node in in-order traversal and replace the current node with it temp = node.right while temp.left is not None: temp = temp.left node.right = self.remove_helper(node.right, temp.val) node.val = temp.val # Update node height self.update_height(node) # 2. Perform rotation operation to restore balance to the subtree return self.rotate(node) def search(self, val: int) -> TreeNode | None: """Search node""" cur = self._root # Loop find, break after passing leaf nodes while cur is not None: # Target node is in cur's right subtree if cur.val < val: cur = cur.right # Target node is in cur's left subtree elif cur.val > val: cur = cur.left # Found target node, break loop else: break # Return target node return cur """Driver Code""" if __name__ == "__main__": def test_insert(tree: AVLTree, val: int): tree.insert(val) print("\nInsert node {} after, AVL tree is".format(val)) print_tree(tree.get_root()) def test_remove(tree: AVLTree, val: int): tree.remove(val) print("\nRemove node {} after, AVL tree is".format(val)) print_tree(tree.get_root()) # Initialize empty AVL tree avl_tree = AVLTree() # Insert node # Notice how the AVL tree maintains balance after inserting nodes for val in [1, 2, 3, 4, 5, 8, 7, 9, 10, 6]: test_insert(avl_tree, val) # Insert duplicate node test_insert(avl_tree, 7) # Remove node # Notice how the AVL tree maintains balance after removing nodes test_remove(avl_tree, 8) # Remove node with degree 0 test_remove(avl_tree, 5) # Remove node with degree 1 test_remove(avl_tree, 4) # Remove node with degree 2 result_node = avl_tree.search(7) print("\nFound node object is {}, node value = {}".format(result_node, result_node.val))