# Gist lvsl-deactivated/f6ed83bcf987d1cc58a2, created May 12, 2015 21:50.
# Note from the GitHub page: this file contains hidden or bidirectional
# Unicode text that may be interpreted or compiled differently than what
# appears below. To review, open the file in an editor that reveals hidden
# Unicode characters.
import unittest
class _Node(object):
    '''
    Represents a node in a BST.

    value is any object that can be ordered with "<" (i.e. supports
    __lt__()), which is enough for sorting in Python 3.
    left and right are the child nodes, or None when absent.
    '''

    def __init__(self, value, left=None, right=None):
        '''
        Create a node holding value with optional left/right children.

        Raises ValueError when value cannot be ordered with "<".
        '''
        # hasattr(value, '__lt__') is always True in Python 3 because
        # object itself defines __lt__ (returning NotImplemented), so the
        # attribute lookup validates nothing. Probe the comparison instead:
        # unorderable values (None, dict, plain object(), ...) raise
        # TypeError here.
        try:
            value < value
        except TypeError:
            raise ValueError(
                'value must implement __lt__()') from None
        self.value = value
        self.left = left
        self.right = right
class BST(object):
    '''
    Simple (unbalanced) binary search tree.

    Invariant: for every node, all values in its left subtree are strictly
    smaller than node.value and all values in its right subtree are greater
    than or equal to it. Duplicates are therefore kept, in the right subtree.
    '''

    def __init__(self, items=None):
        '''
        Create a tree, optionally filled with the values of iterable items.
        '''
        self._root = None
        if items:
            for item in items:
                self.insert(item)

    def insert(self, value):
        '''
        Insert value into a BST

        Raises ValueError if value is None.
        '''
        if value is None:
            raise ValueError('Value can not be empty')
        new_node = _Node(value)
        if self._root is None:  # first element
            self._root = new_node
            return
        # Walk down to the first empty slot. Equal values descend to the
        # right, so inserting a duplicate never overwrites an existing
        # subtree.
        node = self._root
        while True:
            if value < node.value:
                if node.left is None:
                    node.left = new_node
                    return
                node = node.left
            else:
                if node.right is None:
                    node.right = new_node
                    return
                node = node.right

    def __contains__(self, value):
        '''
        Find value in a BST, returns boolean

        Raises ValueError if value is None.
        '''
        if value is None:
            raise ValueError('Value can not be empty')
        if self._root is None:
            return False
        return bool(self._find(self._root, value).value == value)

    def _find(self, root, value):
        '''
        Search the subtree rooted at root for value.

        Returns the node holding value if there is one, otherwise the last
        node visited, i.e. the node under which value would be inserted.
        Raises ValueError if value is None or the subtree is empty.
        '''
        if value is None:
            raise ValueError('Value can not be empty')
        if root is None:
            raise ValueError('Can not find inserion point in an empty tree')
        node = root
        while True:
            if node.value == value:
                return node
            child_node = node.left if value < node.value else node.right
            if child_node is None:
                return node
            node = child_node

    def remove(self, value):
        '''
        Remove one occurrence of value from BST.

        Returns True if a node was removed, False if value is not there.
        Raises ValueError if the tree is empty.
        '''
        if self._root is None:
            raise ValueError('Can not delete from empty tree')
        parent, node = None, self._root
        while node is not None:
            if value == node.value:
                self._relink(parent, node)
                return True
            parent = node
            node = node.left if value < node.value else node.right
        return False

    def _relink(self, parent, node):
        '''
        Remove node from the tree and re-link
        its children. parent is node's parent, or None when node is the
        root. Always returns True.
        '''
        if node.left is None:
            # covers both "left is missing" and the leaf case (None)
            replacement = node.right
        elif node.right is None:
            replacement = node.left
        else:
            # have both children: the right child takes node's place and
            # the left subtree is hung under the left most node of the
            # right subtree (everything there is larger than the left
            # subtree).
            replacement = node.right
            leftmost = replacement
            while leftmost.left is not None:
                leftmost = leftmost.left
            leftmost.left = node.left
        self._flip_with_parent(parent, node, replacement)
        return True

    def _flip_with_parent(self, parent, node, child):
        '''
        Make child occupy the place node has under parent.

        When parent is None, node is the root and child becomes the root.
        '''
        if parent is None:
            self._root = child
        elif parent.left is node:
            # Pick the side by identity rather than by comparing values,
            # so the result does not depend on how values compare.
            parent.left = child
        else:
            parent.right = child

    # support iteration protocol
    def __iter__(self):
        '''
        Yield the stored values in ascending order.
        '''
        # A generator finishes by returning. Raising StopIteration inside
        # a generator body is turned into RuntimeError (PEP 479).
        if self._root is None:
            return
        yield from self._dfs(self._root)

    def _dfs(self, node):
        '''
        In-order traversal of the subtree rooted at node, yielding values.

        Uses an explicit stack so that a degenerate (list-like) tree does
        not hit the interpreter recursion limit.
        '''
        stack = []
        while stack or node is not None:
            while node is not None:
                stack.append(node)
                node = node.left
            node = stack.pop()
            yield node.value
            node = node.right
class Test(unittest.TestCase):
    '''
    Unit tests for BST: construction, iteration, membership and removal.
    '''

    def test_iteration(self):
        '''
        Test creation and iteration
        '''
        bst = BST([3, 2, -3, 1])
        self.assertEqual([-3, 1, 2, 3], list(bst))

    def test_empty_iteration(self):
        '''
        Test empty list
        '''
        bst = BST()
        self.assertEqual([], list(bst))

    def test_signle_node_iter(self):
        '''
        Test single node iteration
        '''
        bst = BST([1])
        self.assertEqual([1], list(bst))

    def test_contains(self):
        '''
        Test membership for a present and an absent value
        '''
        bst = BST([4, -1, 2])
        self.assertIn(-1, bst)
        self.assertNotIn(5, bst)

    def test_remove(self):
        '''
        Test removing a leaf and removing a value that is not stored
        '''
        bst = BST([4, -1, 2])
        bst.insert(3)
        self.assertTrue(bst.remove(3))
        self.assertFalse(bst.remove(100))
        self.assertEqual([-1, 2, 4], list(bst))

    def test_add_remove(self):
        '''
        Test removing every value, starting with a root that has
        two children
        '''
        bst = BST([3, 4, 2])
        self.assertTrue(bst.remove(3))
        self.assertEqual([2, 4], list(bst))
        self.assertTrue(bst.remove(2))
        self.assertEqual([4], list(bst))
        self.assertTrue(bst.remove(4))
        self.assertNotIn(4, bst)
# Run the unit tests when this file is executed as a script.
if __name__ == '__main__':
    unittest.main()
# Sign up for free to join this conversation on GitHub.
# Already have an account? Sign in to comment.