Skip to content

Instantly share code, notes, and snippets.

@rohit-jamuar
Last active August 29, 2015 14:18
Show Gist options
  • Select an option

  • Save rohit-jamuar/5790915417bdd557745c to your computer and use it in GitHub Desktop.

Select an option

Save rohit-jamuar/5790915417bdd557745c to your computer and use it in GitHub Desktop.
Binary Tree with a print generator
class BiNode(object):
    '''
    One node of a binary tree: a payload plus links to two children.

    ``value`` holds whatever was passed to the constructor; ``left`` and
    ``right`` both start out as None, meaning "no child".
    '''

    def __init__(self, val):
        self.value = val
        self.left = None
        self.right = None
class BinaryTree(object):
    '''
    A binary tree whose insertions fill the first free child slot found by
    walking down the left spine from the root (left slot before right).

    Values are compared with ``==`` and duplicates are allowed. The tree also
    caches a traversal generator so that print_tree() can print one value
    per call.
    '''

    def __init__(self):
        self.root = None
        # Cached traversal generator used by print_tree(), and the mode
        # ('bfs' / 'dfs') it was created for.
        self.node_gen = None
        self.chosen_print_mode = None

    def insert(self, k):
        '''
        Insert a new node holding value ``k``.

        The first value becomes the root; later values are attached by
        _insert(). Inserting None into a non-empty tree is a no-op.
        '''
        if self.root is None:
            self.root = BiNode(k)
        else:
            self._insert(self.root, k)
        # The tree changed, so any cached print_tree() traversal is stale.
        self.node_gen = None

    def _insert(self, node, val):
        '''
        Attach ``val`` below ``node`` and return True.

        ``val`` may be a raw value (wrapped in a new BiNode) or an existing
        BiNode subtree; the latter is how delete() re-links the children of
        a removed node with its parent. ``None`` means "nothing to attach"
        and is a no-op. Other falsy values such as 0 or '' are real values
        and are inserted.

        Starting at ``node``, the walk follows left children until it finds
        a node with a free slot; the left slot is preferred over the right.
        '''
        if val is None:
            return True
        new_node = val if isinstance(val, BiNode) else BiNode(val)
        # Iterative on purpose: insertion grows the left spine, so a
        # recursive walk would hit the recursion limit on large trees.
        while node.left is not None and node.right is not None:
            node = node.left
        if node.left is None:
            node.left = new_node
        else:
            node.right = new_node
        return True

    def search(self, k):
        '''
        Return the first node, in pre-order (node, left subtree, right
        subtree), whose value equals ``k``, or None if there is none.
        '''
        return self._search(self.root, k)

    def _search(self, node, val):
        '''
        Pre-order search for ``val`` in the subtree rooted at ``node``.

        Uses an explicit stack so deep (left-spine) trees cannot exhaust
        the interpreter's recursion limit.
        '''
        stack = [node] if node is not None else []
        while stack:
            cur = stack.pop()
            if cur.value == val:
                return cur
            # Push right first so the whole left subtree is examined first.
            if cur.right is not None:
                stack.append(cur.right)
            if cur.left is not None:
                stack.append(cur.left)
        return None

    def search_parent(self, k):
        '''
        Return the first node, in pre-order, that has a child whose value
        equals ``k``, or None if there is no such node. The root has no
        parent, so a value found only at the root yields None.
        '''
        return self._search_parent(self.root, k)

    def _search_parent(self, node, k):
        '''
        Pre-order scan of the subtree at ``node`` for a node with a child
        valued ``k``. It is iterative for the same reason as _search().
        '''
        stack = [node] if node is not None else []
        while stack:
            cur = stack.pop()
            for child in (cur.left, cur.right):
                if child is not None and child.value == k:
                    return cur
            if cur.right is not None:
                stack.append(cur.right)
            if cur.left is not None:
                stack.append(cur.left)
        return None

    def _find_with_parent(self, k):
        '''
        Return ``(parent, node)`` for the first node, in pre-order, whose
        value equals ``k``. ``parent`` is None when that node is the root,
        and ``(None, None)`` is returned when nothing matches.
        '''
        stack = [(None, self.root)] if self.root is not None else []
        while stack:
            parent, cur = stack.pop()
            if cur.value == k:
                return parent, cur
            if cur.right is not None:
                stack.append((cur, cur.right))
            if cur.left is not None:
                stack.append((cur, cur.left))
        return None, None

    def _detach(self, node):
        '''
        Return the subtree that should take ``node``'s place when it is
        removed. That is its left child, with its right subtree re-attached
        inside it via _insert(), or else its right child, or None for a leaf.
        '''
        if node.left is not None:
            self._insert(node.left, node.right)
            return node.left
        return node.right

    def delete(self, k):
        '''
        Delete the node with value ``k``: the same node search(k) returns,
        i.e. the first match in pre-order. Does nothing if no node matches.
        '''
        # Locate node and parent in one pass. Looking them up separately can
        # pick unrelated nodes when values repeat.
        parent, target = self._find_with_parent(k)
        if target is None:
            return
        replacement = self._detach(target)
        if parent is None:
            self.root = replacement
        elif parent.left is target:
            parent.left = replacement
        else:
            parent.right = replacement
        # The tree changed, so any cached print_tree() traversal is stale.
        self.node_gen = None

    def print_tree(self, mode='bfs', refresh=False):
        '''
        Print the next value of a traversal, followed by a space.

        This is 'non-exhausting': a cached generator remembers the position
        between calls. Once a traversal runs out it starts over from the
        root, so repeated calls cycle through the tree.

        mode    -- 'bfs' for level order or 'dfs' for in-order. Any other
                   value prints 'Invalid mode requested!' and does nothing
                   else.
        refresh -- if true, restart the traversal from the root.

        The traversal also restarts when ``mode`` differs from the previous
        call's mode, or after insert()/delete() changed the tree. An empty
        tree prints nothing.
        '''
        permissible_modes = {'bfs': self._get_node_bfs,
                             'dfs': self._get_node_dfs}
        if mode not in permissible_modes:
            print('Invalid mode requested!')
            return
        exhausted = object()  # sentinel: node values may legitimately be None
        restart = (refresh or self.node_gen is None
                   or mode != self.chosen_print_mode)
        value = exhausted if restart else next(self.node_gen, exhausted)
        if value is exhausted:
            self.chosen_print_mode = mode
            self.node_gen = permissible_modes[mode](self.root)
            value = next(self.node_gen, exhausted)
            if value is exhausted:
                return  # empty tree: there is nothing to print
        print(value, end=' ')

    def _get_node_dfs(self, node):
        '''
        Yield the values of the subtree at ``node`` in in-order (left
        subtree, node, right subtree), using an explicit stack.
        '''
        nodes, cur = [], node
        while nodes or cur:
            if cur:
                # Descend as far left as possible, remembering the path.
                nodes.append(cur)
                cur = cur.left
            else:
                cur = nodes.pop()
                yield cur.value
                cur = cur.right

    def _get_node_bfs(self, node):
        '''
        Yield the values of the subtree at ``node`` level by level, left to
        right. Yields nothing if ``node`` is None.
        '''
        from collections import deque
        pending = deque([node] if node is not None else [])
        while pending:
            cur = pending.popleft()
            if cur.left is not None:
                pending.append(cur.left)
            if cur.right is not None:
                pending.append(cur.right)
            yield cur.value
if __name__ == '__main__':
    # Demo: build a 15-node tree, remove a few values, then print it.
    b = BinaryTree()
    for i in range(15):
        b.insert(i)
    for value in (14, 2, 0):
        b.delete(value)
    # print_tree() prints one value per call and starts over once a
    # traversal is exhausted, so these loops cycle through the tree.
    for _ in range(22):
        b.print_tree('dfs')
    print()
    for _ in range(12):
        b.print_tree()
    print()  # finish the last line of output
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment