-
-
Save DeaconDesperado/6675981 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class Node(object):
    """One node of a binary search tree.

    Stores a ``key`` plus links to the ``parent``, ``left`` and ``right``
    nodes.  The rich comparisons compare this node's key against a raw key,
    so ``node < key`` means ``node.key < key``.
    """

    def __init__(self, key, parent=None):
        self.key = key
        self.parent = parent
        self.left = None
        self.right = None

    def has_children(self):
        """Return True if this node has a left or a right child."""
        # Check the links directly: len(self) would walk the whole subtree.
        return self.left is not None or self.right is not None

    def delete(self, node):
        """Clear whichever of this node's child links points at ``node``."""
        assert isinstance(node, Node)
        if node is self.right:
            self.right = None
        if node is self.left:
            self.left = None

    def replace_on_parent(self, node):
        """Put ``node`` into the slot this node occupies on its parent.

        The parent's ``left``/``right`` link that points at ``self`` is
        redirected to ``node``, and ``node.parent`` is set accordingly.
        """
        assert isinstance(node, Node)
        assert self.parent is not None
        # Look up *self* on its parent (the old code looked up ``node`` on
        # ``node.parent`` and so never actually replaced anything).
        if self is self.parent.left:
            self.parent.left = node
        else:
            self.parent.right = node
        node.parent = self.parent

    def __len__(self):
        """Return the number of descendants of this node (itself excluded)."""
        # Iterative walk so long, list-shaped subtrees cannot overflow the
        # interpreter's recursion limit.
        count = 0
        pending = [self.left, self.right]
        while pending:
            current = pending.pop()
            if current is not None:
                count += 1
                pending.append(current.left)
                pending.append(current.right)
        return count

    def __nonzero__(self):
        """Python 2 truth hook: a node is truthy whenever it holds a key."""
        return self.key is not None

    # Python 3 consults __bool__ instead; without this alias truthiness would
    # fall back to __len__, making every leaf node falsy.
    __bool__ = __nonzero__

    def __lt__(self, key):
        return self.key < key

    def __gt__(self, key):
        return self.key > key

    def __eq__(self, key):
        return self.key == key

    def __ne__(self, key):
        # Python 2 does not derive != from __eq__.
        return not self.__eq__(key)

    def __str__(self):
        return '%s' % self.key

    def __repr__(self):
        """Show this node's key and the keys (not objects) of its links."""
        # Explicit None checks: truthiness of a Node is not a presence test.
        parent = self.parent.key if self.parent is not None else None
        left = self.left.key if self.left is not None else None
        right = self.right.key if self.right is not None else None
        return '%r' % {
            'key': self.key,
            'parent': parent,
            'left': left,
            'right': right
        }


class Tree(object):
    """An unbalanced binary search tree holding unique, comparable keys."""

    def __init__(self):
        self.root = None
        self.n = 0  # number of keys currently stored

    def __len__(self):
        return self.n

    def __nonzero__(self):
        # Keep an empty tree truthy instead of letting __len__ decide.
        return True

    __bool__ = __nonzero__  # Python 3 spelling of the same hook

    def insert(self, key):
        """Add ``key``; return True if it was new, False if already present."""
        assert key is not None
        if self.root is None:
            self.root = Node(key)
            self.n = 1
            return True
        inserted = self._insert(self.root, key)
        if inserted:
            self.n += 1
        return inserted

    def _insert(self, node, key):
        """Attach ``key`` as a new leaf below ``node``; False on a duplicate."""
        # Iterative descent: sorted input builds a list-shaped tree whose
        # depth would exceed the recursion limit with a recursive version.
        while True:
            if node.key == key:
                return False
            if node.key > key:
                if node.left is None:
                    node.left = Node(key, node)
                    return True
                node = node.left
            else:
                if node.right is None:
                    node.right = Node(key, node)
                    return True
                node = node.right

    def get(self, key):
        """Return the stored key equal to ``key``, or None if absent."""
        if key is None or self.root is None:
            return None
        node = self._get(self.root, key)
        return None if node is None else node.key

    def _get(self, node, key):
        """Return the node holding ``key`` in the subtree at ``node``, or None."""
        while node is not None:
            if node.key == key:
                return node
            node = node.left if node.key > key else node.right
        return None

    def max(self):
        """Return the largest key, or None for an empty tree."""
        node = self._max(self.root)
        return None if node is None else node.key

    def _max(self, node):
        """Return the right-most node of the subtree at ``node`` (None if empty)."""
        if node is None:
            return None
        while node.right is not None:
            node = node.right
        return node

    def min(self):
        """Return the smallest key, or None for an empty tree."""
        node = self._min(self.root)
        return None if node is None else node.key

    def _min(self, node):
        """Return the left-most node of the subtree at ``node`` (None if empty)."""
        if node is None:
            return None
        while node.left is not None:
            node = node.left
        return node

    def sorted(self):
        """Return every key in ascending order as a new list."""
        # The size is known, so preallocate and fill by index.
        arr = [None] * len(self)
        self._sorted(self.root, arr, 0)
        return arr

    def _sorted(self, node, arr, i):
        """
        @param node : root of the subtree to walk (may be None)
        @param arr  : list the keys are written into
        @param i    : index at which to place the subtree's smallest key
        @return int : the index of the next value for placement
        """
        for current in self._in_order(node):
            arr[i] = current.key
            i += 1
        return i

    def _in_order(self, node):
        """Yield the nodes of the subtree at ``node`` in ascending key order."""
        # Explicit stack instead of recursion (see _insert).
        stack = []
        while stack or node is not None:
            if node is not None:
                stack.append(node)
                node = node.left
            else:
                node = stack.pop()
                yield node
                node = node.right

    def pred(self, key):
        """Return the largest key smaller than ``key``.

        Returns None if ``key`` is None, is not in the tree, or is the minimum.
        """
        if key is None:
            return None
        node = self._get(self.root, key)
        if node is None:
            return None
        pred = self._pred(node, key)
        return None if pred is None else pred.key

    def _pred(self, node, key):
        """Return the in-order predecessor node of ``node`` (``key`` is unused)."""
        if node is None:
            return None
        # With a left subtree, the predecessor is that subtree's maximum.
        if node.left is not None:
            return self._max(node.left)
        # Otherwise climb while we are a left child; the first ancestor we
        # reach from its right side is the predecessor (None past the root).
        while node.parent is not None and node is node.parent.left:
            node = node.parent
        return node.parent

    def successor(self, key):
        """Return the smallest key larger than ``key``.

        Returns None if ``key`` is None, is not in the tree, or is the maximum.
        """
        if key is None:
            return None
        node = self._get(self.root, key)
        if node is None:
            return None
        suc = self._successor(node, key)
        return None if suc is None else suc.key

    def _successor(self, node, key):
        """Return the in-order successor node of ``node`` (``key`` is unused)."""
        assert isinstance(node, Node)
        # With a right subtree, the successor is that subtree's minimum
        # (the old code returned the maximum of node's own subtree).
        if node.right is not None:
            return self._min(node.right)
        # Otherwise climb while we are a right child; the first ancestor we
        # reach from its left side is the successor (None past the root).
        while node.parent is not None and node is node.parent.right:
            node = node.parent
        return node.parent

    def delete(self, key):
        """Remove ``key``; return True if it was present, False otherwise."""
        assert key is not None
        node = self._get(self.root, key)
        if node is None:
            # Only shrink the size when something is actually removed.
            return False
        self.n -= 1
        return self._delete(node)

    def _delete(self, node):
        """Unlink ``node`` from the tree while keeping BST order; return True."""
        assert isinstance(node, Node)
        if node.left is not None and node.right is not None:
            # Two children: move the in-order successor's key up into this
            # node, then remove the successor, which has no left child.
            successor = self._min(node.right)
            node.key = successor.key
            node = successor
        # ``node`` now has at most one child: splice that child into its place.
        child = node.left if node.left is not None else node.right
        self._replace(node, child)
        return True

    def _replace(self, node, child):
        """Put ``child`` (a Node or None) where ``node`` hangs in the tree."""
        parent = node.parent
        if child is not None:
            child.parent = parent
        if parent is None:
            self.root = child
        elif parent.left is node:
            parent.left = child
        else:
            parent.right = child
        # Detach the removed node so stale references cannot walk the tree.
        node.parent = node.left = node.right = None

    def __repr__(self):
        """Return the repr of the list of nodes in ascending key order."""
        # Build a real list: map() is lazy on Python 3 and would repr as
        # "<map object ...>".
        nodes = list(self._in_order(self.root))
        return '%r' % (nodes,)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from bin_tree import Tree | |
import unittest | |
from random import shuffle, seed | |
def is_sorted(arr):
    """Return True when no element of the list ``arr`` is smaller than its predecessor."""
    assert isinstance(arr, list)
    # Compare each adjacent (earlier, later) pair; stop at the first descent.
    return all(not later < earlier for earlier, later in zip(arr, arr[1:]))
class Test(unittest.TestCase):
    """Unit tests for bin_tree.Tree."""

    def setUp(self):
        # range() is lazy on Python 3; shuffle() needs a mutable list.
        self.keys = list(range(100))
        seed(len(self.keys))
        shuffle(self.keys)
        # Insertion order that yields a perfectly balanced 7-node tree.
        self.balanced = [5, 3, 7, 2, 4, 6, 8]

    def _build_tree(self, keys):
        """Return a new Tree with ``keys`` inserted in the given order."""
        tree = Tree()
        for val in keys:
            tree.insert(val)
        return tree

    def _assert_links(self, tree):
        """Check parent back-links and child ordering for every node of ``tree``."""
        if tree.root is None:
            return
        self.assertIsNone(tree.root.parent)
        stack = [tree.root]
        while stack:
            node = stack.pop()
            if node.left is not None:
                self.assertIs(node.left.parent, node)
                self.assertLess(node.left.key, node.key)
                stack.append(node.left)
            if node.right is not None:
                self.assertIs(node.right.parent, node)
                self.assertGreater(node.right.key, node.key)
                stack.append(node.right)

    def test_insert(self):
        tree = Tree()
        for val in self.keys:
            self.assertTrue(tree.insert(val))
        self.assertEqual(len(self.keys), len(tree))
        for val in self.keys:
            self.assertFalse(tree.insert(val))
        # Rejected duplicates must not change the size.
        self.assertEqual(len(self.keys), len(tree))
        self._assert_links(tree)

    def test_get(self):
        tree = self._build_tree(self.keys)
        for val in self.keys:
            self.assertEqual(tree.get(val), val)
        self.assertIsNone(tree.get(-1))
        self.assertIsNone(tree.get(len(self.keys)))

    def test_max(self):
        tree = self._build_tree(self.keys)
        self.assertEqual(tree.max(), len(self.keys) - 1)

    def test_min(self):
        tree = self._build_tree(self.keys)
        self.assertEqual(tree.min(), 0)

    def test_empty_tree(self):
        tree = Tree()
        self.assertIsNone(tree.max())
        self.assertIsNone(tree.min())
        self.assertEqual(tree.sorted(), [])
        self.assertIsNone(tree.get(1))
        self.assertFalse(tree.delete(1))
        self.assertEqual(len(tree), 0)

    def test_is_sorted_test(self):
        self.assertTrue(is_sorted(list(range(5))))
        self.assertFalse(is_sorted([1, 0]))

    def test_sorted(self):
        tree = self._build_tree(self.keys)
        result = tree.sorted()
        self.assertTrue(is_sorted(result))
        self.assertEqual(result, sorted(self.keys))

    def test_pred(self):
        tree = self._build_tree(self.balanced)
        low, high = min(self.balanced), max(self.balanced)
        self.assertIsNone(tree.pred(low))
        # Cover every key up to and including the maximum.
        for i in range(low + 1, high + 1):
            self.assertEqual(tree.pred(i), i - 1)

    def test_successor(self):
        tree = self._build_tree(self.balanced)
        low, high = min(self.balanced), max(self.balanced)
        self.assertIsNone(tree.successor(high))
        # The old while-loop condition was false on entry, so nothing was checked.
        for i in range(low, high):
            self.assertEqual(tree.successor(i), i + 1)

    def test_delete_by_index(self):
        # Deleting any single key from a full tree must leave the rest intact.
        arr = self.keys
        expected = sorted(arr)
        for i in arr:
            tree = self._build_tree(arr)
            self.assertTrue(tree.delete(i), 'val %s should have been deleted' % i)
            self.assertIsNone(tree.get(i))
            self.assertEqual(len(tree), len(arr) - 1)
            self.assertEqual(tree.sorted(), [k for k in expected if k != i])
            self._assert_links(tree)

    def test_delete_falsehood(self):
        # Deleting values that don't exist returns False and keeps the size.
        arr = self.keys
        tree = self._build_tree(arr)
        missing_low, missing_high = min(arr) - 1, max(arr) + 1
        self.assertFalse(tree.delete(missing_low), 'should not find %s' % missing_low)
        self.assertFalse(tree.delete(missing_high), 'should not find %s' % missing_high)
        self.assertEqual(len(tree), len(arr))

    def test_all_delete(self):
        # Deleting all values should leave an empty tree.
        arr = self.keys
        tree = self._build_tree(arr)
        for val in arr:
            self.assertTrue(tree.delete(val))
        self.assertEqual(len(tree), 0)
        self.assertIsNone(tree.root)
        self.assertEqual(tree.sorted(), [])
# Run the suite only when executed as a script, not when imported.
if __name__ == "__main__":
    unittest.main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment