Skip to content

Instantly share code, notes, and snippets.

@mrdomino
Last active September 15, 2023 01:28
Show Gist options
  • Save mrdomino/c366dddfc32133209e21 to your computer and use it in GitHub Desktop.
Save mrdomino/c366dddfc32133209e21 to your computer and use it in GitHub Desktop.
#!/usr/bin/env python
"""
Example demonstrating O(1)-space iteration over binary trees.
"""
import weakref
class Node:
    """
    A binary tree node with left-child, right-child, and parent pointers.

    Children are held by ordinary (strong) references, while the parent
    pointer is stored as a ``weakref.proxy`` so that parent/child links do not
    form strong reference cycles. A consequence is that a subtree does not
    keep its ancestors alive: once the rest of the tree has been
    garbage-collected, following a ``parent`` pointer raises
    ``ReferenceError``.
    """

    def __init__(self, val, left=None, right=None, parent=None):
        """
        Construct a new Node.

        If left and right are specified, their parent pointers are modified to
        point to self. If parent is specified, only self's parent pointer is
        set; the parent's own left/right pointers are not changed.

        >>> a=Node(3)
        >>> b=Node(5)
        >>> c=Node(4, a, b)
        >>> all([a.parent == c, a == c.left, b.parent == c, b == c.right])
        True
        """
        self._left = self._right = self._parent = None
        self.val = val
        self.left = left
        self.right = right
        self.parent = parent

    def __get_left(self):
        return self._left

    def __set_left(self, left):
        # Keep the child's back-pointer in sync with the new link.
        if left is not None:
            left.parent = self
        self._left = left

    left = property(__get_left, __set_left)

    def __get_right(self):
        return self._right

    def __set_right(self, right):
        # Keep the child's back-pointer in sync with the new link.
        if right is not None:
            right.parent = self
        self._right = right

    right = property(__get_right, __set_right)

    def __get_parent(self):
        return self._parent

    def __set_parent(self, parent):
        if parent is None:
            self._parent = None
        elif isinstance(parent, weakref.ProxyTypes):
            # Already a proxy (e.g. another node's ``.parent``). Proxies cannot
            # themselves be weakly referenced, so store it unchanged.
            self._parent = parent
        else:
            # Weak back-pointer: children must not keep their parent alive.
            self._parent = weakref.proxy(parent)

    parent = property(__get_parent, __set_parent)

    def insert(self, val):
        """
        Insert an item, preserving the BST property.

        Inserts a new value into the tree. If the value to be inserted is less
        than the current node, it is inserted into the node's left subtree. If
        it is greater, it is inserted into the node's right subtree. Duplicate
        values (neither less nor greater) are ignored.

        The descent is iterative, so degenerate (e.g. sorted-input) trees of
        any depth can be built without hitting the recursion limit.

        >>> a = Node(3)
        >>> a.insert(4)
        >>> a.insert(2)
        >>> a.insert(3)
        >>> print(list(iter(a)))
        [2, 3, 4]
        """
        node = self
        while True:
            if val < node.val:
                if node.left is None:
                    node.left = Node(val)
                    return
                node = node.left
            elif val > node.val:
                if node.right is None:
                    node.right = Node(val)
                    return
                node = node.right
            else:
                # Equal to an existing value: duplicates are ignored.
                return

    def __iter__(self):
        """Return an in-order iterator over this subtree's values."""
        return NodeIter(self)

    def __str__(self):
        """Render as ``Node(val, left=..., right=...)``, omitting empty children."""
        params = [f'{self.val}']
        if self.left is not None:
            params.append(f'left={self.left}')
        if self.right is not None:
            params.append(f'right={self.right}')
        return f"Node({', '.join(params)})"

    def pretty(self, depth=0):
        """
        Pretty-print a node.

        One element is printed per line, indented one space farther than its
        parent. Each node's left-hand subtree is above it, and its right-hand
        subtree is below. This method is recursive in the tree's depth.

        >>> print(Node(5, left=Node(3), right=Node(6)).pretty())
         3
        5
         6
        """
        lines = []
        if self.left is not None:
            lines.append(self.left.pretty(depth + 1))
        lines.append((" " * depth) + str(self.val))
        if self.right is not None:
            lines.append(self.right.pretty(depth + 1))
        return '\n'.join(lines)
class NodeIter:  # pylint: disable=too-few-public-methods
    """
    In-order iteration over a BST with parent, left, and right pointers.

    Only the current node is stored, so iteration uses O(1) extra space. The
    iterator relies solely on each node's ``val``, ``left``, ``right`` and
    ``parent`` attributes.
    """

    def __init__(self, root=None, curr=None):
        """
        Construct a NodeIter over the given tree.

        At most one of root and curr may be passed; passing both raises
        ValueError, and passing neither produces an empty iterator. If root is
        passed, the iterator starts at the least (i.e. leftmost) element in
        the tree. If curr is passed, the iterator starts at that node exactly.

        >>> a=Node(3)
        >>> b=Node(5)
        >>> c=Node(4, a, b)
        >>> d=Node(8)
        >>> e=Node(9, d)
        >>> f=Node(6, c, e)
        >>> list(NodeIter(f))
        [3, 4, 5, 6, 8, 9]
        >>> list(NodeIter(curr=f))
        [6, 8, 9]
        >>> list(NodeIter())
        []
        """
        # Explicit check instead of ``assert`` so it still fires under -O.
        if root is not None and curr is not None:
            raise ValueError('pass at most one of root and curr')
        if root is not None:
            curr = root
            while curr.left is not None:
                curr = curr.left
        self.curr = curr

    def __iter__(self):
        return self

    def __str__(self):
        return f'NodeIter(curr={self.curr})'

    def __advance(self):
        """
        Advance to the next element in the tree.

        If there is a right subtree, then the next element is the least element
        in that subtree, i.e. right, then left as far as we can go.

        Otherwise, we are as far right in the current subtree as we can go, and
        the next element is the closest ancestor whose left subtree we are in.
        We find it by walking up as long as we are the right-hand child of our
        parent, then taking one more step up. When we walk off the top of the
        tree, the root's parent (None) ends the iteration.

        This uses O(1) space and O(d) time where d is the depth of the tree.
        """
        node = self.curr
        assert node is not None  # internal invariant: __next__ checks first
        if node.right is not None:
            node = node.right
            while node.left is not None:
                node = node.left
            self.curr = node
            return
        parent = node.parent
        # ``==`` rather than ``is``: with Node, parent pointers are weakref
        # proxies, so ``node`` may be a proxy while ``parent.right`` is the real
        # child object. A proxy compares equal to its referent but is not
        # identical to it.
        while parent is not None and parent.right == node:
            node, parent = parent, parent.parent
        self.curr = parent

    def __next__(self):
        """
        Return the next element in the tree.

        Raises StopIteration once the traversal has walked past the last node.
        """
        if self.curr is None:
            raise StopIteration
        ret = self.curr.val
        self.__advance()
        return ret
def _test(): # pylint: disable=missing-docstring
import doctest # pylint: disable=import-outside-toplevel
return doctest.testmod()
def _example():  # pylint: disable=missing-docstring
    # Assemble the demo BST from its two halves, then hang both under the
    # root value 11.
    lower = Node(9,
                 Node(3,
                      Node(1),
                      Node(4, None, Node(5, Node(4.5, Node(4.25))))),
                 Node(10, right=Node(10.5, right=Node(10.75, right=Node(10.9)))))
    upper = Node(13,
                 Node(12, right=Node(12.5, Node(12.25, right=Node(12.375)))),
                 Node(15))
    tree = Node(11, lower, upper)

    # Echo each "command" before showing its output, REPL-style.
    print(f'>>> tree = {tree}')
    print('>>> print(tree.pretty())')
    print(tree.pretty())
    print('>>> print(list(NodeIter(tree)))')
    print(list(NodeIter(tree)))
def _run():  # pylint: disable=missing-docstring
    # Only show the worked example when every doctest passed.
    results = _test()
    if not results.failed:
        _example()


# Script entry point: self-test, then demonstrate.
if __name__ == '__main__':
    _run()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment