Skip to content

Instantly share code, notes, and snippets.

@jakab922
Created April 3, 2017 10:24
Show Gist options
  • Save jakab922/37ac59e360b231c3c6e9fa31e28657a5 to your computer and use it in GitHub Desktop.
Save jakab922/37ac59e360b231c3c6e9fa31e28657a5 to your computer and use it in GitHub Desktop.
Half-baked red-black tree
# Node colours.
RED, BLACK = 0, 1
# Which child slot of a parent a node occupies.
LEFT, RIGHT = 0, 1
# Three-way comparison outcomes, using the sign convention of cmp().
SMALLER, EQUAL, BIGGER = -1, 0, 1
def _set_relation(child, parent, rtype):
if child is not None:
child.parent = parent
if parent is not None:
if rtype == LEFT:
parent.left = child
else:
parent.right = child
def _set_side(one, other, side):
    """Store *other* as the *side* child of *one* and back-link it.

    *side* equal to LEFT selects ``one.left``; any other value selects
    ``one.right``. *one* must be a node; *other* may be None, which just
    empties that child slot.
    """
    slot = "left" if side == LEFT else "right"
    setattr(one, slot, other)
    if other is None:
        return
    other.parent = one
def _is_red(node):
return node is not None and node.color == RED
class Node(object):
    """A red-black tree node: a key, a colour and left/right/parent links.

    Passing *left*, *right* or *parent* wires the links in both
    directions. The slot under *parent* is chosen by comparing keys
    (a smaller key goes left). Two nodes compare equal when their keys
    are equal.
    """

    def __init__(self, key, color,
                 left=None, right=None, parent=None):
        self.key = key
        self.color = color
        self.left = left
        if left is not None:
            _set_relation(left, self, LEFT)
        self.right = right
        if right is not None:
            _set_relation(right, self, RIGHT)
        self.parent = parent
        if parent is not None:
            rtype = LEFT if key < parent.key else RIGHT
            _set_relation(self, parent, rtype)

    def __eq__(self, other):
        # Only nodes are comparable by key. Anything else (including None)
        # defers to Python's default identity comparison, i.e. "not equal",
        # instead of raising AttributeError on ``other.key``.
        if not isinstance(other, Node):
            return NotImplemented
        return self.key == other.key

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__, so spell it out.
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    def __hash__(self):
        # Keep hashing consistent with key-based equality; without this,
        # Python 3 makes the class unhashable because __eq__ is overridden.
        return hash(self.key)

    def __str__(self):
        return "%s - %s" % (
            "B" if self.color == BLACK else "R", self.key)

    def __repr__(self):
        return str(self.key)

    def _delete(self):
        """Detach this key-less dummy node from its parent.

        Clears whichever child slot of the parent points at this node and
        then drops the upward link.
        """
        assert self.key is None  # Means this is a dummy node
        if self.parent is not None:
            if self.left_child:
                self.parent.left = None
            else:
                self.parent.right = None
            self.parent = None

    @property
    def left_child(self):
        """True when this node is its parent's left child."""
        return self.parent is not None and self.parent.left == self

    @property
    def right_child(self):
        """True when this node is its parent's right child."""
        return self.parent is not None and self.parent.right == self

    @property
    def root(self):
        """True when this node has no parent."""
        return self.parent is None
class RBT(object):
    """Red-black tree made of :class:`Node` objects, ordered by key.

    This class holds the balancing machinery: rotations, the insert and
    delete fix-ups, transplant and bookkeeping. It also provides iteration
    and debugging helpers. It defines no public insert or delete method, so
    callers presumably drive the private helpers directly (TODO: confirm).
    Besides the root it caches the element count and the smallest
    (``first``) and largest (``last``) nodes.
    """

    # Node class for this tree; not referenced by the methods below.
    _nclass = Node

    # When True, _restore_after_delete prints which rebalancing case ran.
    _debug = False

    def __init__(self):
        self._count = 0
        self._root = None
        self._first_elem = None
        self._last_elem = None
        self._iter_state = None

    def _log(self, msg):
        """Print *msg* only when debug tracing is switched on."""
        if self._debug:
            print(msg)

    def _find(self, key):
        """Finds an element with the given key.

        Returns ``(True, node)`` when a node with *key* exists. Otherwise it
        returns ``(False, parent)``, where *parent* is the last node visited
        (the would-be parent of a new node with *key*), or None for an
        empty tree.
        """
        curr, par = self._root, None
        while curr is not None:
            # Rich comparisons instead of the Python 2-only cmp().
            if key < curr.key:
                par, curr = curr, curr.left
            elif curr.key < key:
                par, curr = curr, curr.right
            else:
                return True, curr
        return False, par

    def _first(self, cnode):
        """Returns the first node in the subtree rooted at *cnode*.

        Returns None when *cnode* is None.
        """
        cand = cnode
        while cnode is not None:
            cand, cnode = cnode, cnode.left
        return cand

    def _last(self, cnode):
        """Returns the last node in the subtree rooted at *cnode*.

        Returns None when *cnode* is None.
        """
        cand = cnode
        while cnode is not None:
            cand, cnode = cnode, cnode.right
        return cand

    def _next(self, node):
        """Return the in-order successor of *node*, or None after the last."""
        if node == self._last_elem:
            return None
        ret = node
        if ret.right is not None:
            # The successor is the leftmost node of the right subtree.
            ret = ret.right
            while ret.left is not None:
                ret = ret.left
            return ret
        # Otherwise climb out of right subtrees; the next parent follows.
        while ret.right_child:
            ret = ret.parent
        return ret.parent

    def _prev(self, node):
        """Return the in-order predecessor of *node*, or None before the first."""
        if node == self._first_elem:
            return None
        ret = node
        if ret.left is not None:
            # The predecessor is the rightmost node of the left subtree.
            ret = ret.left
            while ret.right is not None:
                ret = ret.right
            return ret
        # Otherwise climb out of left subtrees; the next parent precedes.
        while ret.left_child:
            ret = ret.parent
        return ret.parent

    def _left_rotate(self, cnode):
        """Rotate left around *cnode*: its right child takes its place.

        Updates ``_root`` when *cnode* was the root.
        """
        onode = cnode.right
        # onode's left subtree becomes cnode's right subtree.
        _set_relation(onode.left, cnode, RIGHT)
        if cnode.parent is not None:
            top_type = LEFT if cnode.left_child else RIGHT
            _set_relation(onode, cnode.parent, top_type)
        else:
            onode.parent = None
            self._root = onode
        _set_relation(cnode, onode, LEFT)

    def _right_rotate(self, cnode):
        """Rotate right around *cnode*: its left child takes its place.

        Updates ``_root`` when *cnode* was the root.
        """
        onode = cnode.left
        # onode's right subtree becomes cnode's left subtree.
        _set_relation(onode.right, cnode, LEFT)
        if cnode.parent is not None:
            top_type = LEFT if cnode.left_child else RIGHT
            _set_relation(onode, cnode.parent, top_type)
        else:
            onode.parent = None
            self._root = onode
        _set_relation(cnode, onode, RIGHT)

    def _restore_after_insert(self, cnode):
        """Restores property 4 after an insert.

        *cnode* is the newly linked node (presumably coloured red). This also
        refreshes the cached first/last nodes. The element count is not
        touched here.
        """
        curr = cnode
        if curr.parent is not None:
            if curr.parent == self._first_elem and curr.left_child:
                self._first_elem = curr
            if curr.parent == self._last_elem and curr.right_child:
                self._last_elem = curr
        else:
            self._first_elem = self._last_elem = curr
        while curr.parent is not None and curr.parent.color == RED:
            par, gpar = curr.parent, curr.parent.parent
            if par == gpar.left:
                unc = gpar.right
                if unc is not None and unc.color == RED:
                    # Red uncle: recolour and continue from the grandparent.
                    par.color, unc.color, gpar.color = (
                        BLACK, BLACK, RED)
                    curr = gpar
                else:
                    if curr == par.right:
                        # Inner child: rotate it into the outer position.
                        curr = par
                        self._left_rotate(curr)
                    curr.parent.color = BLACK
                    curr.parent.parent.color = RED
                    self._right_rotate(curr.parent.parent)
            else:
                unc = gpar.left
                if unc is not None and unc.color == RED:
                    par.color, unc.color, gpar.color = (
                        BLACK, BLACK, RED)
                    curr = gpar
                else:
                    if curr == par.left:
                        curr = par
                        self._right_rotate(curr)
                    curr.parent.color = BLACK
                    curr.parent.parent.color = RED
                    self._left_rotate(curr.parent.parent)
        if curr.parent is None:
            curr.color = BLACK
            self._root = curr

    def _fix_misc_on_delete(self, rnode):
        """Update the count and cached first/last before *rnode* is removed.

        It walks the tree with _next/_prev, so it has to run while *rnode*
        is still linked into the tree.
        """
        if self._count == 1:  # The removed element was the only one
            self._first_elem = None
            self._last_elem = None
        else:
            if rnode == self._first_elem:
                self._first_elem = self._next(rnode)
            elif rnode == self._last_elem:
                self._last_elem = self._prev(rnode)
        self._count -= 1

    def _transplant(self, one, other):
        """Put *other* in *one*'s position under *one*'s parent.

        *other* may be None. When *one* is the root and *other* is None or a
        key-less dummy node, the tree becomes empty.
        """
        if one.parent is None:
            # Guard against None before touching other.key.
            if other is not None and other.key is not None:
                self._root = other
            else:
                self._root = None
        elif one.parent.left == one:
            one.parent.left = other
        else:  # one.parent.right == one
            one.parent.right = other
        if other is not None:
            other.parent = one.parent

    def _restore_after_delete(self, x):
        """Rebalance after removing a black node.

        *x* is the node now occupying the removed node's place. It is
        treated as carrying an extra black, which the loop either pushes
        up the tree or absorbs by recolouring and rotating. The colour of
        *x* itself is never inspected, so callers presumably only call
        this when *x* is black, possibly a key-less dummy (TODO: confirm).
        Case numbers are this file's own numbering.
        """
        self._log("_restore_after_delete got called")
        while x.parent is not None:
            xp = x.parent
            if x.left_child:
                w = xp.right
                w_l = w.left
                w_r = w.right
                if _is_red(w):  # case 7
                    # Red sibling: rotate so x gets a black sibling, retry.
                    self._log("left case 7")
                    xp.color = RED
                    w.color = BLACK
                    self._left_rotate(xp)
                else:
                    xp_r, w_l_r, w_r_r = map(_is_red, (xp, w_l, w_r))
                    if not w_l_r and w_r_r:
                        if not xp_r:  # case 1
                            self._log("left case 1")
                            w_r.color = BLACK
                            self._left_rotate(xp)
                        else:  # case 2
                            self._log("left case 2")
                            xp.color = BLACK
                            w.color = RED
                            w_r.color = BLACK
                            self._left_rotate(xp)
                        # Red far nephew: the rotation absorbs the extra
                        # black in both sub-cases, so stop here.
                        x = self._root
                    elif w_l_r and not w_r_r:  # case 3
                        # Red near nephew: rotate it to the far side and
                        # retry; the next pass ends in a terminal case.
                        self._log("left case 3")
                        w.color = RED
                        w_l.color = BLACK
                        self._right_rotate(w)
                    elif not w_l_r and not w_r_r:  # case 4
                        self._log("left case 4")
                        w.color = RED
                        if xp.color == RED:
                            xp.color = BLACK
                            x = self._root
                        else:  # xp.color == BLACK
                            x = xp
                    else:  # w_l_r and w_r_r
                        if xp.color == BLACK:  # case 5
                            self._log("left case 5")
                            w_r.color = BLACK
                            self._left_rotate(xp)
                        else:  # case 6
                            self._log("left case 6")
                            xp.color = BLACK
                            w.color = RED
                            w_r.color = BLACK
                            self._left_rotate(xp)
                        # Same shape as cases 1/2: terminal in both.
                        x = self._root
            else:  # swapping left with right
                w = xp.left
                w_l = w.left
                w_r = w.right
                if _is_red(w):  # case 7
                    self._log("right case 7")
                    xp.color = RED
                    w.color = BLACK
                    self._right_rotate(xp)
                else:
                    xp_r, w_l_r, w_r_r = map(_is_red, (xp, w_l, w_r))
                    if not w_r_r and w_l_r:
                        if not xp_r:  # case 1
                            self._log("right case 1")
                            w_l.color = BLACK
                            self._right_rotate(xp)
                        else:  # case 2
                            self._log("right case 2")
                            xp.color = BLACK
                            w.color = RED
                            w_l.color = BLACK
                            self._right_rotate(xp)
                        x = self._root
                    elif w_r_r and not w_l_r:  # case 3
                        self._log("right case 3")
                        w.color = RED
                        w_r.color = BLACK
                        self._left_rotate(w)
                    elif not w_r_r and not w_l_r:  # case 4
                        self._log("right case 4")
                        w.color = RED
                        if xp.color == RED:
                            xp.color = BLACK
                            x = self._root
                        else:  # xp.color == BLACK
                            x = xp
                    else:  # w_r_r and w_l_r
                        if xp.color == BLACK:  # case 5
                            self._log("right case 5")
                            w_l.color = BLACK
                            self._right_rotate(xp)
                        else:  # case 6
                            self._log("right case 6")
                            xp.color = BLACK
                            w.color = RED
                            w_l.color = BLACK
                            self._right_rotate(xp)
                        x = self._root

    # Utility functions
    @property
    def first(self):
        """The node with the smallest key, or None when empty."""
        return self._first_elem

    @property
    def last(self):
        """The node with the largest key, or None when empty."""
        return self._last_elem

    def next(self, node):
        """Public alias of :meth:`_next`."""
        return self._next(node)

    def prev(self, node):
        """Public alias of :meth:`_prev`."""
        return self._prev(node)

    # Special functions
    def __len__(self):
        return self._count

    def __iter__(self):
        """Yield nodes in ascending key order.

        The cursor lives in a local variable so that nested or interleaved
        iterations over the same tree do not clobber each other.
        ``_iter_state`` is still updated for backward compatibility.
        """
        node = self._first(self._root)
        while node is not None:
            self._iter_state = node
            yield node
            node = self._next(node)

    def __reversed__(self):
        """Yield nodes in descending key order (see __iter__)."""
        node = self._last(self._root)
        while node is not None:
            self._iter_state = node
            yield node
            node = self._prev(node)

    def __contains__(self, key):
        return self._find(key)[0]

    def __repr__(self):
        """pre-order traversal of the tree is the representation."""
        if self._root is None:
            return ""
        rpr, ident, stack = [], "|", [(self._root, 0)]
        while stack:
            curr, ilvl = stack.pop()
            rpr.append(ident * ilvl + str(curr))
            # Push right first so the left subtree is printed first.
            if curr.right is not None:
                stack.append((curr.right, ilvl + 1))
            if curr.left is not None:
                stack.append((curr.left, ilvl + 1))
        return "\n".join(rpr)

    # debugging functions
    def _loop_detect(self):
        """Return True if the structure has a cycle or a repeated key.

        An empty tree has no loop.
        """
        if self._root is None:
            return False
        if self._root.parent is not None:
            return True
        stack = [self._root]
        keys = set([stack[-1].key])
        while stack:
            curr = stack.pop()
            l, r = curr.left, curr.right
            for c in (l, r):
                if c is not None:
                    if c.key in keys:
                        return True
                    else:
                        stack.append(c)
                        keys.add(c.key)
        return False

    def _check_tree(self):
        """ We have to check the following rules:
        - rule 2: The root is black
        - rule 4: All children of a red node is black.
        - rule 5: The black height of every node is well defined

        It also checks the element count, the cached first/last nodes and
        that in-order iteration yields sorted keys. It prints the tree and
        the intermediate black heights along the way.
        """
        print("The representation of the tree is:")
        print(repr(self))
        if self._root is not None:
            assert self._root.color == BLACK  # rule 2
            stack = [self._root]
        else:
            stack = []
        was = set()
        bh = {}
        nn = lambda x: x is not None
        mi, ma = None, None
        gmi = lambda x, y: y if x is None else min(x, y)
        gma = lambda x, y: y if x is None else max(x, y)
        is_black = lambda x: x is None or x.color == BLACK
        while stack:
            curr = stack.pop()
            mi = gmi(mi, curr.key)
            ma = gma(ma, curr.key)
            l, r = curr.left, curr.right
            # Each node is popped twice: first to schedule its children,
            # then (post-order) to compute its black height.
            if curr.key in was:  # rule 5
                mod = 1 if curr.color == BLACK else 0
                if nn(l) and nn(r):
                    emsg = (
                        "The black height of %s should be the same as "
                        "%s while it is %s and that is %s")
                    emsg = emsg % (l.key, r.key, bh[l.key], bh[r.key])
                    assert bh[l.key] == bh[r.key], emsg
                    val = bh[l.key]
                elif nn(l):
                    emsg = "The black height of %s should be 1" % l.key
                    print("bh: %s" % (bh,))
                    assert bh[l.key] == 1, emsg
                    val = 1
                elif nn(r):
                    emsg = "The black height of %s should be 1" % r.key
                    print("bh: %s" % (bh,))
                    assert bh[r.key] == 1, emsg
                    val = 1
                else:
                    val = 1
                bh[curr.key] = mod + val
                continue
            was.add(curr.key)
            stack.append(curr)
            if curr.color == RED:  # rule 4
                assert all(map(is_black, (l, r)))
            if nn(r):
                stack.append(r)
            if nn(l):
                stack.append(l)
        print("bh: %s" % (bh,))
        emsg = (
            "The length of 'was' is: %s while the length of "
            "self._count is %s")
        emsg = emsg % (len(was), self._count)
        assert len(was) == self._count, emsg
        assert mi == getattr(self._first_elem, "key", None)
        assert ma == getattr(self._last_elem, "key", None)
        # Iterate this tree, not a global named ``tree``.
        coll = [el.key for el in iter(self)]
        assert sorted(coll) == coll
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment