Created
January 1, 2021 04:44
-
-
Save lordpretzel/55bfee9a0655ea8d12b0705411ba68be to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# Balanced BS Tree: AVL Tree\n", | |
"\n", | |
"## Agenda\n", | |
"\n", | |
"1. Motives\n", | |
"2. \"Balanced\" binary trees\n", | |
"3. Essential mechanic: rotation\n", | |
"4. Out-of-balance scenarios & rotation recipes\n", | |
"5. Generalized AVL rebalancing (insertion)\n", | |
"6. Rebalancing on removal" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## 1. Motives" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
class BSTree:
    """Plain (unbalanced) binary search tree of unique, mutually comparable values.

    Supports ``add``, ``in``, ``len``, ``del tree[val]``, ``height`` and a
    best-effort ``pprint``. Insertion, lookup, removal and height are all
    recursive, so their call depth grows with the length of the longest branch.
    """

    class Node:
        """One tree node: a value and optional left/right child nodes."""

        def __init__(self, val, left=None, right=None):
            self.val, self.left, self.right = val, left, right

    def __init__(self):
        self.root = None
        self.size = 0

    def add(self, val):
        """Insert ``val`` as a new leaf; asserts that it is not already present."""
        assert(val not in self)

        def add_rec(node):
            # An empty spot is where the new leaf goes.
            if not node:
                return BSTree.Node(val)
            if val < node.val:
                node.left = add_rec(node.left)
            else:
                node.right = add_rec(node.right)
            return node

        self.root = add_rec(self.root)
        self.size += 1

    def __contains__(self, val):
        """Return True if ``val`` is stored in the tree."""
        def contains_rec(node):
            if not node:
                return False
            if val < node.val:
                return contains_rec(node.left)
            if val > node.val:
                return contains_rec(node.right)
            return True

        return contains_rec(self.root)

    def __len__(self):
        """Number of values currently stored."""
        return self.size

    def __delitem__(self, val):
        """Remove ``val`` from the tree; asserts that it is present."""
        assert(val in self)

        def delitem_rec(node):
            if val < node.val:
                node.left = delitem_rec(node.left)
                return node
            if val > node.val:
                node.right = delitem_rec(node.right)
                return node

            # node holds the value to remove.
            if not (node.left and node.right):
                # Zero or one child: splice the (possibly missing) child in.
                return node.left or node.right

            # Two children: overwrite this node's value with the largest value
            # of the left subtree and unlink the node that used to hold it.
            parent, largest = node, node.left
            while largest.right:
                parent, largest = largest, largest.right
            if parent is node:
                node.left = largest.left
            else:
                parent.right = largest.left
            node.val = largest.val
            return node

        self.root = delitem_rec(self.root)
        self.size -= 1

    def pprint(self, width=64):
        """Attempts to pretty-print this tree's contents.

        Nodes are printed level by level, each centred in a field of
        ``width // 2**level`` characters; missing nodes above the last level
        are shown as ``-`` so that columns line up.
        """
        height = self.height()
        pending = [(self.root, 0)]
        current_level = 0
        pieces = []
        while pending:
            node, level = pending.pop(0)
            if level != current_level:
                current_level = level
                pieces.append('\n')
            # Placeholders are only queued above the deepest level.
            above_last_level = level < height - 1
            if node:
                if node.left or above_last_level:
                    pending.append((node.left, level + 1))
                if node.right or above_last_level:
                    pending.append((node.right, level + 1))
                label = node.val
            else:
                if above_last_level:
                    pending.extend([(None, level + 1), (None, level + 1)])
                label = '-'
            pieces.append('{val:^{width}}'.format(val=label, width=width // 2**level))
        print(''.join(pieces))

    def height(self):
        """Returns the height of the longest branch of the tree."""
        def height_rec(node):
            if not node:
                return 0
            return 1 + max(height_rec(node.left), height_rec(node.right))

        return height_rec(self.root)
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
# Insert 0..5 in ascending order. Each new value is larger than everything
# already stored, so add() always descends to the right and the tree
# degenerates into a single right-leaning chain.
t = BSTree()
for x in range(6):
    t.add(x)
t.pprint()
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
import sys

# Motivation demo: with ascending inserts the tree is one long chain, so the
# recursive add()/__contains__ need one stack frame per stored value. A small
# recursion limit makes that cost visible as a RecursionError.
#
# The limit is restored afterwards so later cells (and the kernel itself) are
# not left running with an artificially tiny stack budget.
saved_limit = sys.getrecursionlimit()
failure = None
t = BSTree()
sys.setrecursionlimit(100)
try:
    for x in range(100):
        t.add(x)
except RecursionError as err:
    # Only remember the error here; report it once the limit is back to normal.
    failure = err
finally:
    sys.setrecursionlimit(saved_limit)

if failure is not None:
    print('RecursionError after', len(t), 'insertions:', failure)
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## 2. \"Balanced\" binary trees" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## 3. Essential mechanic: rotation" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
class AVLTree(BSTree):
    """BSTree variant whose nodes can be rotated in place.

    This version only introduces the rotation primitive; ``add`` still inserts
    exactly like ``BSTree.add`` and performs no rebalancing.
    """

    class Node:
        """Tree node with an in-place right rotation."""

        def __init__(self, val, left=None, right=None):
            self.val = val
            self.left = left
            self.right = right

        def rotate_right(self):
            """Rotate the subtree rooted at this node to the right, in place.

            The node object itself stays the root of the subtree (so parents
            need no relinking): its value is swapped with its left child's,
            and the child object is re-hung as the new right child.
            Requires a left child; without one the attribute access on
            ``None`` raises AttributeError.
            """
            n = self.left
            self.val, n.val = n.val, self.val
            # The right-hand side is evaluated first, so every name on the
            # right refers to the links as they were before the rotation.
            self.left, n.left, self.right, n.right = n.left, n.right, n, self.right

    def add(self, val):
        """Insert ``val`` as a new leaf (no rebalancing); asserts it is not present."""
        assert(val not in self)

        def add_rec(node):
            if not node:
                return AVLTree.Node(val)
            elif val < node.val:
                node.left = add_rec(node.left)
                return node
            else:
                node.right = add_rec(node.right)
                return node

        self.root = add_rec(self.root)
        self.size += 1
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
# Insert 6..1 in descending order. Every value goes to the left of the
# previous one, giving a left-leaning chain to practise rotations on.
t = AVLTree()
for x in range(6, 0, -1):
    t.add(x)
t.pprint()
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
# Rotate right about the root of the tree built in the previous cell and show the result.
t.root.rotate_right()
t.pprint()
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
# Rotate right about the root once more (operates on the tree as left by the cell above).
t.root.rotate_right()
t.pprint()
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
# Rotate right about the root's left child, leaving the root itself in place.
t.root.left.rotate_right()
t.pprint()
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
class AVLTree(BSTree):
    """BSTree variant with an in-place right rotation and a node-height helper.

    ``add`` inserts like ``BSTree.add``; imbalance detection is still a
    placeholder in this version.
    """

    class Node:
        """Tree node with an in-place right rotation and a height helper."""

        def __init__(self, val, left=None, right=None):
            self.val = val
            self.left = left
            self.right = right

        def rotate_right(self):
            """Rotate the subtree rooted at this node to the right, in place.

            The node object stays the subtree root: its value is swapped with
            its left child's and the child object becomes the right child.
            Requires a left child.
            """
            n = self.left
            self.val, n.val = n.val, self.val
            # Right-hand side is evaluated before any assignment happens.
            self.left, n.left, self.right, n.right = n.left, n.right, n, self.right

        @staticmethod
        def height(n):
            """Return the number of nodes on the longest path from ``n`` down (0 for None)."""
            if not n:
                return 0
            else:
                return max(1+AVLTree.Node.height(n.left), 1+AVLTree.Node.height(n.right))

    def add(self, val):
        """Insert ``val`` as a new leaf; asserts that it is not already present."""
        assert(val not in self)

        def add_rec(node):
            if not node:
                return AVLTree.Node(val)
            elif val < node.val:
                node.left = add_rec(node.left)
            else:
                node.right = add_rec(node.right)
            # TODO: detect and fix imbalance
            # The subtree root must be handed back to the caller; without this
            # the function returns None and the caller overwrites its link
            # (ultimately self.root) with None, discarding the tree.
            return node

        self.root = add_rec(self.root)
        self.size += 1
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
# Reset state for the descending-insertion demo: the next cell inserts `val`
# into `t` and then decrements it.
val = 50
t = AVLTree()
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
# (evaluate multiple times with ctrl-enter)
# Deliberately stateful: each run inserts the current `val` and then lowers it
# by one, so repeated runs insert 50, 49, 48, ... Re-run the setup cell above
# to start over.
t.add(val)
val -= 1
t.pprint()
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
# Reset state for the ascending-insertion demo: the next cell inserts `val`
# into `t` and then increments it.
val = 0
t = AVLTree()
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
# (evaluate multiple times with ctrl-enter)
# Deliberately stateful: each run inserts the current `val` and then raises it
# by one, so repeated runs insert 0, 1, 2, ... Re-run the setup cell above to
# start over.
t.add(val)
val += 1
t.pprint()
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## 4. \"Out-of-balance\" scenarios & rotation recipes" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
# "left-left" scenario
# 3 is the root, 2 its left child, and 1 the left child of 2.
# A plain BSTree is used so the unbalanced shape is shown as-is.
t = BSTree()
for x in [3, 2, 1]:
    t.add(x)
t.pprint()
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
# "left-right" scenario
# 3 is the root, 1 its left child, and 2 the right child of 1.
t = BSTree()
for x in [3, 1, 2]:
    t.add(x)
t.pprint()
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
# "right-right" scenario
# 1 is the root, 2 its right child, and 3 the right child of 2.
t = BSTree()
for x in [1, 2, 3]:
    t.add(x)
t.pprint()
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
# "right-left" scenario
# 1 is the root, 3 its right child, and 2 the left child of 3.
t = BSTree()
for x in [1, 3, 2]:
    t.add(x)
t.pprint()
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## 5. Generalized AVL rebalancing (insertion)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
class AVLTree(BSTree):
    """Self-balancing binary search tree (AVL) with rebalancing on insertion.

    After every insertion, each node on the path back to the root is checked;
    if its subtrees' heights differ by two or more, ``rebalance`` restores the
    balance with one or two rotations. Removal is inherited from ``BSTree``
    and does not rebalance.
    """

    class Node:
        """Tree node with in-place rotations and a height helper."""

        def __init__(self, val, left=None, right=None):
            self.val = val
            self.left = left
            self.right = right

        def rotate_right(self):
            """Rotate the subtree rooted at this node to the right, in place.

            The node object stays the subtree root: its value is swapped with
            its left child's and the child object becomes the right child.
            Requires a left child.
            """
            n = self.left
            self.val, n.val = n.val, self.val
            # Right-hand side is evaluated before any assignment happens.
            self.left, n.left, self.right, n.right = n.left, n.right, n, self.right

        def rotate_left(self):
            """Rotate the subtree rooted at this node to the left, in place.

            Mirror image of ``rotate_right``. Requires a right child.
            """
            n = self.right
            self.val, n.val = n.val, self.val
            self.right, n.right, self.left, n.left = n.right, n.left, n, self.left

        @staticmethod
        def height(n):
            """Return the number of nodes on the longest path from ``n`` down (0 for None).

            Recomputed recursively on every call; nothing is cached on the nodes.
            """
            if not n:
                return 0
            else:
                return max(1+AVLTree.Node.height(n.left), 1+AVLTree.Node.height(n.right))

    @staticmethod
    def rebalance(t):
        """Fix the imbalance at node ``t`` using the matching rotation recipe.

        Intended to be called when the subtree heights of ``t`` differ by two
        or more. Does nothing when both subtrees have equal height.
        """
        left_height = AVLTree.Node.height(t.left)
        right_height = AVLTree.Node.height(t.right)
        if left_height > right_height:
            if AVLTree.Node.height(t.left.left) >= AVLTree.Node.height(t.left.right):
                # left-left
                print('left-left imbalance detected')
                t.rotate_right()
            else:
                # left-right: first turn it into a left-left case
                print('left-right imbalance detected')
                t.left.rotate_left()
                t.rotate_right()
        elif right_height > left_height:
            if AVLTree.Node.height(t.right.right) >= AVLTree.Node.height(t.right.left):
                # right-right
                print('right-right imbalance detected')
                t.rotate_left()
            else:
                # right-left: first turn it into a right-right case
                print('right-left imbalance detected')
                t.right.rotate_right()
                t.rotate_left()

    def add(self, val):
        """Insert ``val`` and rebalance the insertion path; asserts it is not present."""
        assert(val not in self)

        def add_rec(node):
            if not node:
                return AVLTree.Node(val)
            elif val < node.val:
                node.left = add_rec(node.left)
            else:
                node.right = add_rec(node.right)
            # Rotations are in place, so `node` is still this subtree's root.
            if abs(AVLTree.Node.height(node.left) - AVLTree.Node.height(node.right)) >= 2:
                AVLTree.rebalance(node)
            return node

        self.root = add_rec(self.root)
        self.size += 1
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
# Descending inserts 10, 5, 1: the third value lands in the left subtree of
# the root's left child (the "left-left" shape).
t = AVLTree()
for x in [10, 5, 1]:
    t.add(x)
t.pprint()
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
# Insertion order that needs more than the left-side fixes: 10, 5, 1 form a
# left-left shape, and 2 and 3 then extend a right-leaning chain below 1.
# The printed tree is only balanced if right-heavy subtrees are rebalanced too.
t = AVLTree()
for x in [10, 5, 1, 2, 3]:
    t.add(x)
t.pprint()
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## 6. Rebalancing on removal" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
class AVLTree(BSTree):
    """Self-balancing binary search tree (AVL) with rebalancing on insert and removal.

    After every insertion or removal, each node on the path back to the root
    is checked; if its subtrees' heights differ by two or more, ``rebalance``
    restores the balance with one or two rotations.
    """

    class Node:
        """Tree node with in-place rotations and a height helper."""

        def __init__(self, val, left=None, right=None):
            self.val = val
            self.left = left
            self.right = right

        def rotate_right(self):
            """Rotate the subtree rooted at this node to the right, in place.

            The node object stays the subtree root: its value is swapped with
            its left child's and the child object becomes the right child.
            Requires a left child.
            """
            n = self.left
            self.val, n.val = n.val, self.val
            # Right-hand side is evaluated before any assignment happens.
            self.left, n.left, self.right, n.right = n.left, n.right, n, self.right

        def rotate_left(self):
            """Rotate the subtree rooted at this node to the left, in place.

            Mirror image of ``rotate_right``. Requires a right child.
            """
            n = self.right
            self.val, n.val = n.val, self.val
            self.right, n.right, self.left, n.left = n.right, n.left, n, self.left

        @staticmethod
        def height(n):
            """Return the number of nodes on the longest path from ``n`` down (0 for None).

            Recomputed recursively on every call; nothing is cached on the nodes.
            """
            if not n:
                return 0
            else:
                return max(1+AVLTree.Node.height(n.left), 1+AVLTree.Node.height(n.right))

    @staticmethod
    def rebalance(t):
        """Fix the imbalance at node ``t`` using the matching rotation recipe.

        Intended to be called when the subtree heights of ``t`` differ by two
        or more. Does nothing when both subtrees have equal height. The ``>=``
        comparisons matter for removal, where the taller child's own subtrees
        can have equal height; a single rotation handles that case.
        """
        left_height = AVLTree.Node.height(t.left)
        right_height = AVLTree.Node.height(t.right)
        if left_height > right_height:
            if AVLTree.Node.height(t.left.left) >= AVLTree.Node.height(t.left.right):
                # left-left
                print('left-left imbalance detected')
                t.rotate_right()
            else:
                # left-right: first turn it into a left-left case
                print('left-right imbalance detected')
                t.left.rotate_left()
                t.rotate_right()
        elif right_height > left_height:
            if AVLTree.Node.height(t.right.right) >= AVLTree.Node.height(t.right.left):
                # right-right
                print('right-right imbalance detected')
                t.rotate_left()
            else:
                # right-left: first turn it into a right-right case
                print('right-left imbalance detected')
                t.right.rotate_right()
                t.rotate_left()

    @staticmethod
    def _rebalance_if_needed(node):
        """Rebalance ``node`` when its subtree heights differ by two or more."""
        if abs(AVLTree.Node.height(node.left) - AVLTree.Node.height(node.right)) >= 2:
            AVLTree.rebalance(node)

    def add(self, val):
        """Insert ``val`` and rebalance the insertion path; asserts it is not present."""
        assert(val not in self)

        def add_rec(node):
            if not node:
                return AVLTree.Node(val)
            elif val < node.val:
                node.left = add_rec(node.left)
            else:
                node.right = add_rec(node.right)
            # Rotations are in place, so `node` is still this subtree's root.
            AVLTree._rebalance_if_needed(node)
            return node

        self.root = add_rec(self.root)
        self.size += 1

    def __delitem__(self, val):
        """Remove ``val`` and rebalance the affected path; asserts it is present."""
        assert(val in self)

        def delitem_rec(node, target):
            if target < node.val:
                node.left = delitem_rec(node.left, target)
            elif target > node.val:
                node.right = delitem_rec(node.right, target)
            elif not (node.left and node.right):
                # Zero or one child: splice the (possibly missing) child in.
                # That child's subtree was not modified, so it needs no check.
                return node.left or node.right
            else:
                # Two children: take over the largest value of the left
                # subtree, then remove that value from the left subtree
                # recursively so every node on the way down to it is
                # re-checked for balance on the way back up.
                largest = node.left
                while largest.right:
                    largest = largest.right
                node.val = largest.val
                node.left = delitem_rec(node.left, largest.val)
            AVLTree._rebalance_if_needed(node)
            return node

        self.root = delitem_rec(self.root, val)
        self.size -= 1
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
# Build a small tree for the removal demo: 10 at the root, 5 and 15 as its
# children, and 2 as the left child of 5. No insertion here causes a height
# difference of two.
t = AVLTree()
for x in [10, 5, 15, 2]:
    t.add(x)
t.pprint()
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
# Removing 15 empties the root's right subtree while the left one (5 -> 2)
# still has height 2, so the root needs rebalancing after this removal.
del t[15]
t.pprint()
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
# Larger removal demo: insert the 31 values 31, 30, ..., 1 in descending order.
t = AVLTree()
for x in range(31, 0, -1):
    t.add(x)
t.pprint()
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
# Remove two values from the tree built in the previous cell and show the result.
del t[15]
del t[14]
t.pprint()
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
# Remove one more value (continues from the state left by the cell above).
del t[16]
t.pprint()
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.7.4" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 1 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment