Skip to content

Instantly share code, notes, and snippets.

@DeaconDesperado
Forked from landau/bin_tree.py
Last active December 23, 2015 18:29
Show Gist options
  • Save DeaconDesperado/6675981 to your computer and use it in GitHub Desktop.
Save DeaconDesperado/6675981 to your computer and use it in GitHub Desktop.
class Node(object):
    """A node of an unbalanced binary search tree.

    Holds ``key`` plus links to ``parent``, ``left`` and ``right`` (each a
    Node or None).  The rich comparisons compare this node's key against a
    plain key value (not another Node); ``Tree`` relies on that to navigate.
    """

    def __init__(self, key, parent=None):
        self.key = key
        self.parent = parent
        self.left = None
        self.right = None

    def has_children(self):
        """Return True if this node has at least one child."""
        return self.left is not None or self.right is not None

    def delete(self, node):
        """Unlink ``node`` if it is one of this node's direct children."""
        assert isinstance(node, Node)
        if node is self.right:
            self.right = None
        if node is self.left:
            self.left = None

    def replace_on_parent(self, node):
        """Put ``node`` into the child slot this node occupies on its parent.

        ``node.parent`` is re-pointed at that parent.  The slot is chosen by
        where ``self`` sits; the previous version inspected ``node`` instead,
        which made the call a no-op.
        """
        assert isinstance(node, Node)
        assert self.parent is not None
        if self is self.parent.left:
            self.parent.left = node
        else:
            self.parent.right = node
        node.parent = self.parent

    def __len__(self):
        """Return the number of descendants, not counting this node."""
        count = 0
        for child in (self.left, self.right):
            if child is not None:
                count += 1 + len(child)
        return count

    def __nonzero__(self):
        return self.key is not None

    # Python 3 ignores __nonzero__; without __bool__ truthiness would fall
    # back to __len__, making every leaf node falsy.
    __bool__ = __nonzero__

    def __lt__(self, key):
        return self.key < key

    def __gt__(self, key):
        return self.key > key

    def __eq__(self, key):
        return self.key == key

    def __str__(self):
        # str() rather than '%s' % key, which raises when key is a tuple.
        return str(self.key)

    def __repr__(self):
        def key_of(node):
            return None if node is None else node.key

        return repr({
            'key': self.key,
            'parent': key_of(self.parent),
            'left': key_of(self.left),
            'right': key_of(self.right),
        })


class Tree(object):
    """An unbalanced binary search tree of unique, non-None keys.

    ``n`` counts the stored keys.  All walks are iterative, so a degenerate
    (list-shaped) tree, e.g. built from sorted inserts, does not run into
    Python's recursion limit.  The root node always has ``parent`` None.
    """

    def __init__(self):
        self.root = None
        self.n = 0

    def __len__(self):
        return self.n

    def __nonzero__(self):
        # A tree object is truthy even when it holds no keys.
        return True

    # Python 3 spelling; otherwise an empty tree would be falsy via __len__.
    __bool__ = __nonzero__

    def insert(self, key):
        """Insert ``key``; return True if it was added, False if present."""
        assert key is not None
        if self.root is None:
            self.root = Node(key)
            self.n = 1
            return True
        added = self._insert(self.root, key)
        if added:
            self.n += 1
        return added

    def _insert(self, node, key):
        """Insert ``key`` below ``node``; return False if it already exists."""
        while True:
            if node == key:
                return False
            if node > key:
                if node.left is None:
                    node.left = Node(key, node)
                    return True
                node = node.left
            else:
                if node.right is None:
                    node.right = Node(key, node)
                    return True
                node = node.right

    def get(self, key):
        """Return the stored key equal to ``key``, or None if absent."""
        if key is None or self.root is None:
            return None
        node = self._get(self.root, key)
        return None if node is None else node.key

    def _get(self, node, key):
        """Return the node under ``node`` whose key equals ``key``, or None."""
        while node is not None:
            if node == key:
                return node
            node = node.left if node > key else node.right
        return None

    def max(self):
        """Return the largest key, or None if the tree is empty."""
        if self.root is None:
            return None
        return self._max(self.root).key

    def _max(self, node):
        """Return the rightmost (largest) node of the subtree at ``node``."""
        while node.right is not None:
            node = node.right
        return node

    def min(self):
        """Return the smallest key, or None if the tree is empty."""
        if self.root is None:
            return None
        return self._min(self.root).key

    def _min(self, node):
        """Return the leftmost (smallest) node of the subtree at ``node``."""
        while node.left is not None:
            node = node.left
        return node

    def sorted(self):
        """Return all keys as a new list in ascending order."""
        return [node.key for node in self._iter_nodes(self.root)]

    def _sorted(self, node, arr, i):
        """Write the keys of the subtree at ``node`` into ``arr`` in order.

        @param node : subtree root (may be None)
        @param arr  : list to write keys into (must be large enough)
        @param i    : index at which to place the first key
        @return int : the index just past the last key written
        """
        for current in self._iter_nodes(node):
            arr[i] = current.key
            i += 1
        return i

    def _iter_nodes(self, node):
        """Yield the nodes of the subtree at ``node`` in ascending key order."""
        stack = []
        while stack or node is not None:
            # Descend as far left as possible, remembering the path.
            while node is not None:
                stack.append(node)
                node = node.left
            node = stack.pop()
            yield node
            node = node.right

    def pred(self, key):
        """Return the largest key smaller than ``key``.

        Returns None if ``key`` is None, not in the tree, or the minimum.
        """
        if key is None:
            return None
        node = self._get(self.root, key)
        if node is None:
            return None
        pred = self._pred(node, key)
        return None if pred is None else pred.key

    def _pred(self, node, key):
        """Return the in-order predecessor node of ``node`` (``key`` unused)."""
        if node is None:
            return None
        if node.left is not None:
            return self._max(node.left)
        # Climb while we are a left child; the first ancestor reached from
        # its right side is the predecessor (None once we pass the root).
        while node.parent is not None and node is node.parent.left:
            node = node.parent
        return node.parent

    def successor(self, key):
        """Return the smallest key larger than ``key``.

        Returns None if ``key`` is None, not in the tree, or the maximum.
        """
        if key is None:
            return None
        node = self._get(self.root, key)
        if node is None:
            return None
        suc = self._successor(node, key)
        return None if suc is None else suc.key

    def _successor(self, node, key):
        """Return the in-order successor node of ``node`` (``key`` unused)."""
        if node is None:
            return None
        if node.right is not None:
            # Smallest node of the right subtree (was wrongly _max(node)).
            return self._min(node.right)
        # Climb while we are a right child; the first ancestor reached from
        # its left side is the successor (None once we pass the root).
        while node.parent is not None and node is node.parent.right:
            node = node.parent
        return node.parent

    def delete(self, key):
        """Remove ``key``; return True if it was present, False otherwise."""
        assert key is not None
        node = self._get(self.root, key)
        if node is None:
            # The size only changes when a key is actually removed.
            return False
        self._delete(node)
        self.n -= 1
        return True

    def _delete(self, node):
        """Unlink ``node`` from the tree, keeping the BST ordering; return True."""
        assert isinstance(node, Node)
        if node.left is not None and node.right is not None:
            # Two children: move the in-order successor's key into this node,
            # then remove the successor node, which has no left child.
            successor = self._min(node.right)
            node.key = successor.key
            node = successor
        # ``node`` now has at most one child; splice that child into its slot.
        child = node.left if node.left is not None else node.right
        if child is not None:
            child.parent = node.parent
        if node.parent is None:
            self.root = child
        elif node is node.parent.left:
            node.parent.left = child
        else:
            node.parent.right = child
        return True

    def __repr__(self):
        return repr(list(self._iter_nodes(self.root)))
from bin_tree import Tree
import unittest
from random import shuffle, seed
def is_sorted(arr):
    """Return True if ``arr`` is in non-decreasing order.

    Accepts any finite iterable, not only lists: Python 3 ``range`` objects
    used to trip the old ``list``-only assertion.  Neighbours are compared
    with ``<`` only, so equal adjacent elements count as sorted.
    """
    items = list(arr)
    return all(not later < earlier for earlier, later in zip(items, items[1:]))
class Test(unittest.TestCase):
    """Unit tests for bin_tree.Tree."""

    def setUp(self):
        # list(...) so shuffle() can reorder in place; Python 3 range
        # objects are immutable.
        self.keys = list(range(100))
        seed(len(self.keys))
        shuffle(self.keys)
        # This insertion order yields a perfectly balanced 7-node tree.
        self.balanced = [5, 3, 7, 2, 4, 6, 8]

    def _build(self, keys):
        """Return a new Tree with ``keys`` inserted in the given order."""
        tree = Tree()
        for key in keys:
            tree.insert(key)
        return tree

    def test_insert(self):
        tree = Tree()
        for val in self.keys:
            self.assertTrue(tree.insert(val))
        self.assertEqual(len(self.keys), len(tree))
        for val in self.keys:
            self.assertFalse(tree.insert(val))
        # Re-inserting duplicates must not change the size.
        self.assertEqual(len(self.keys), len(tree))
        # todo ensure proper values get set for parent, left, right

    def test_get(self):
        tree = self._build(self.keys)
        for val in self.keys:
            self.assertEqual(tree.get(val), val)
        self.assertIsNone(tree.get(-1))
        self.assertIsNone(tree.get(len(self.keys)))

    def test_max(self):
        tree = self._build(self.keys)
        self.assertEqual(tree.max(), len(self.keys) - 1)

    def test_min(self):
        tree = self._build(self.keys)
        self.assertEqual(tree.min(), 0)

    def test_is_sorted_test(self):
        self.assertTrue(is_sorted(list(range(5))))
        self.assertFalse(is_sorted([1, 0]))

    def test_sorted(self):
        tree = self._build(self.keys)
        result = tree.sorted()
        self.assertTrue(is_sorted(result))
        self.assertEqual(result, sorted(self.keys))

    def test_pred(self):
        tree = self._build(self.balanced)
        lo, hi = min(self.balanced), max(self.balanced)
        self.assertIsNone(tree.pred(lo))
        # Every key above the minimum, up to and including the maximum.
        for key in range(lo + 1, hi + 1):
            self.assertEqual(tree.pred(key), key - 1)

    def test_successor(self):
        tree = self._build(self.balanced)
        lo, hi = min(self.balanced), max(self.balanced)
        self.assertIsNone(tree.successor(hi))
        # Every key from the minimum up to, but excluding, the maximum.
        # (The old loop condition was false on entry, so it checked nothing.)
        for key in range(lo, hi):
            self.assertEqual(tree.successor(key), key + 1)

    def test_delete_by_index(self):
        # Deleting any single key from a full tree removes exactly that key.
        expected_all = sorted(self.keys)
        for key in self.keys:
            tree = self._build(self.keys)
            self.assertTrue(tree.delete(key),
                            'val %s should have been deleted' % key)
            self.assertIsNone(tree.get(key))
            self.assertEqual(tree.sorted(),
                             [k for k in expected_all if k != key])
            self.assertEqual(len(tree), len(self.keys) - 1)

    def test_delete_falsehood(self):
        # Keys that are not present return False and leave the size alone.
        tree = self._build(self.keys)
        missing_low = min(self.keys) - 1
        missing_high = max(self.keys) + 1
        self.assertFalse(tree.delete(missing_low),
                         'should not find %s' % missing_low)
        self.assertFalse(tree.delete(missing_high),
                         'should not find %s' % missing_high)
        self.assertEqual(len(tree), len(self.keys))

    def test_all_delete(self):
        # Deleting all values should leave an empty tree.
        tree = self._build(self.keys)
        for val in self.keys:
            self.assertTrue(tree.delete(val))
        self.assertEqual(len(tree), 0)
        self.assertEqual(tree.sorted(), [])
# Run the suite only when executed as a script, so importing this module
# (e.g. from a test collector) does not trigger unittest.main()'s sys.exit.
if __name__ == '__main__':
    unittest.main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment