Created
April 3, 2017 10:24
-
-
Save jakab922/37ac59e360b231c3c6e9fa31e28657a5 to your computer and use it in GitHub Desktop.
Half-baked red-black tree
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Node colours.
RED, BLACK = 0, 1
# Which child slot of a parent a node occupies.
LEFT, RIGHT = 0, 1
# Three-way comparison outcomes (not referenced by the code shown here).
SMALLER, EQUAL, BIGGER = -1, 0, 1
def _set_relation(child, parent, rtype): | |
if child is not None: | |
child.parent = parent | |
if parent is not None: | |
if rtype == LEFT: | |
parent.left = child | |
else: | |
parent.right = child | |
def _set_side(one, other, side):
    """Install *other* as the *side* (LEFT/RIGHT) child of *one*.

    When *other* is not None its parent pointer is set back to *one*; a None
    *other* simply empties that child slot.
    """
    attr = "left" if side == LEFT else "right"
    setattr(one, attr, other)
    if other is not None:
        other.parent = one
def _is_red(node): | |
return node is not None and node.color == RED | |
class Node(object):
    """A red-black tree node carrying a *key* and a *color* (RED or BLACK).

    Passing *left*, *right* or *parent* links the new node to them in both
    directions. A node whose key is None is used as a dummy placeholder
    (see ``_delete``).
    """

    def __init__(self, key, color,
                 left=None, right=None, parent=None):
        self.key = key
        self.color = color
        self.left = left
        if left is not None:
            _set_relation(left, self, LEFT)
        self.right = right
        if right is not None:
            _set_relation(right, self, RIGHT)
        self.parent = parent
        if parent is not None:
            # Attach on the side given by key order (equal keys go right).
            rtype = LEFT if key < parent.key else RIGHT
            _set_relation(self, parent, rtype)

    def __eq__(self, other):
        """Nodes compare equal when their keys are equal.

        Comparing with None gives False. Comparing with an object that has no
        ``key`` attribute returns NotImplemented, so e.g. ``node == 5`` is
        False instead of raising AttributeError.
        """
        if other is None:
            return False
        try:
            other_key = other.key
        except AttributeError:
            return NotImplemented
        return self.key == other_key

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__; keep the two consistent.
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    def __str__(self):
        return "%s - %s" % (
            "B" if self.color == BLACK else "R", self.key)

    def __repr__(self):
        return str(self.key)

    def _delete(self):
        """Detach this dummy node (key None) from its parent."""
        assert self.key is None  # Means this is a dummy node
        if self.parent is not None:
            if self.left_child:
                self.parent.left = None
            else:
                self.parent.right = None
        self.parent = None

    @property
    def left_child(self):
        """True if this exact node object is its parent's left child."""
        # Identity rather than key equality: two dummy siblings (both with
        # key None) would otherwise be indistinguishable.
        return self.parent is not None and self.parent.left is self

    @property
    def right_child(self):
        """True if this exact node object is its parent's right child."""
        return self.parent is not None and self.parent.right is self

    @property
    def root(self):
        """True if this node has no parent."""
        return self.parent is None
class RBT(object):
    """A red-black tree of ``Node`` objects (work in progress).

    Besides the root, the tree keeps its element count and its first
    (smallest-key) and last (largest-key) nodes. This class currently holds
    the lookup, navigation and rebalancing helpers; it defines no public
    insert or delete method.
    """

    _nclass = Node  # node class used by this tree

    def __init__(self):
        self._count = 0           # number of stored elements
        self._root = None
        self._first_elem = None   # node with the smallest key
        self._last_elem = None    # node with the largest key
        # Retained for compatibility only: iteration now keeps its cursor in
        # a local variable so nested iterations cannot clobber each other.
        self._iter_state = None

    def _find(self, key):
        """Finds an element with the given key.

        Returns ``(True, node)`` when a node with an equal key exists,
        otherwise ``(False, parent)`` where *parent* is the last node visited
        (None for an empty tree).
        """
        curr, par = self._root, None
        while curr is not None:
            # Direct comparisons instead of cmp(), which Python 3 removed.
            if key < curr.key:
                par, curr = curr, curr.left
            elif key == curr.key:
                return True, curr
            else:
                par, curr = curr, curr.right
        return False, par

    def _first(self, cnode):
        """Returns the first node in the subtree rooted at *cnode* (or None)."""
        cand = cnode
        while cnode is not None:
            cand, cnode = cnode, cnode.left
        return cand

    def _last(self, cnode):
        """Returns the last node in the subtree rooted at *cnode* (or None)."""
        cand = cnode
        while cnode is not None:
            cand, cnode = cnode, cnode.right
        return cand

    def _next(self, node):
        """Return the in-order successor of *node*, or None after the last."""
        if node == self._last_elem:
            return None
        ret = node
        if ret.right is not None:
            # Successor is the leftmost node of the right subtree.
            ret = ret.right
            while ret.left is not None:
                ret = ret.left
            return ret
        # Otherwise climb while we are a right child; the parent we reach
        # from a left subtree is the successor.
        while ret.right_child:
            ret = ret.parent
        return ret.parent

    def _prev(self, node):
        """Return the in-order predecessor of *node*, or None before the first."""
        if node == self._first_elem:
            return None
        ret = node
        if ret.left is not None:
            # Predecessor is the rightmost node of the left subtree.
            ret = ret.left
            while ret.right is not None:
                ret = ret.right
            return ret
        while ret.left_child:
            ret = ret.parent
        return ret.parent

    def _left_rotate(self, cnode):
        """Rotate left around *cnode*; its right child takes its place.

        Updates ``self._root`` when *cnode* was the root.
        """
        onode = cnode.right
        # onode's left subtree becomes cnode's right subtree.
        _set_relation(onode.left, cnode, RIGHT)
        if cnode.parent is not None:
            top_type = LEFT if cnode.left_child else RIGHT
            _set_relation(onode, cnode.parent, top_type)
        else:
            onode.parent = None
            self._root = onode
        _set_relation(cnode, onode, LEFT)

    def _right_rotate(self, cnode):
        """Rotate right around *cnode*; its left child takes its place.

        Updates ``self._root`` when *cnode* was the root.
        """
        onode = cnode.left
        # onode's right subtree becomes cnode's left subtree.
        _set_relation(onode.right, cnode, LEFT)
        if cnode.parent is not None:
            top_type = LEFT if cnode.left_child else RIGHT
            _set_relation(onode, cnode.parent, top_type)
        else:
            onode.parent = None
            self._root = onode
        _set_relation(cnode, onode, RIGHT)

    def _restore_after_insert(self, cnode):
        """Restores property 4 after an insert.

        *cnode* is the node just linked into the tree as a leaf. First the
        first/last bookkeeping is updated, then red-red violations are fixed
        by recolouring and rotating upward, and finally the root is made black.
        """
        curr = cnode
        if curr.parent is not None:
            if curr.parent == self._first_elem and curr.left_child:
                self._first_elem = curr
            if curr.parent == self._last_elem and curr.right_child:
                self._last_elem = curr
        else:
            # cnode is the only node in the tree.
            self._first_elem = self._last_elem = curr
        while curr.parent is not None and curr.parent.color == RED:
            par, gpar = curr.parent, curr.parent.parent
            if par == gpar.left:
                unc = gpar.right
                if unc is not None and unc.color == RED:
                    # Red uncle: recolour and continue from the grandparent.
                    par.color, unc.color, gpar.color = (
                        BLACK, BLACK, RED)
                    curr = gpar
                else:
                    if curr == par.right:
                        # Inner grandchild: rotate into the outer case.
                        curr = par
                        self._left_rotate(curr)
                    curr.parent.color = BLACK
                    curr.parent.parent.color = RED
                    self._right_rotate(curr.parent.parent)
            else:
                unc = gpar.left
                if unc is not None and unc.color == RED:
                    par.color, unc.color, gpar.color = (
                        BLACK, BLACK, RED)
                    curr = gpar
                else:
                    if curr == par.left:
                        curr = par
                        self._right_rotate(curr)
                    curr.parent.color = BLACK
                    curr.parent.parent.color = RED
                    self._left_rotate(curr.parent.parent)
        if curr.parent is None:
            # curr is the root: it must be black.
            curr.color = BLACK
            self._root = curr

    def _fix_misc_on_delete(self, rnode):
        """Update count and first/last bookkeeping for removing *rnode*.

        Must be called while *rnode* is still linked into the tree, because
        its neighbour is found with ``_next``/``_prev``.
        """
        if self._count == 1:  # The removed element was the only one
            self._first_elem = None
            self._last_elem = None
        else:
            if rnode == self._first_elem:
                self._first_elem = self._next(rnode)
            elif rnode == self._last_elem:
                self._last_elem = self._prev(rnode)
        self._count -= 1

    def _transplant(self, one, other):
        """Put the subtree rooted at *other* in place of the one at *one*.

        *other* may be None or a dummy node (key None); when *one* is the
        root and *other* is not a real node, the tree root becomes None.
        """
        if one.parent is None:
            # Check for None first: other.key would raise on a None other.
            if other is not None and other.key is not None:
                self._root = other
            else:
                self._root = None
        elif one.parent.left == one:
            one.parent.left = other
        else:  # one.parent.right == one
            one.parent.right = other
        if other is not None:
            other.parent = one.parent

    def _restore_after_delete(self, x):
        """Rebalance after a delete left *x* short by one black.

        Walks upward from *x*, rotating and recolouring. Once the deficit is
        repaired, *x* is set to the root so the loop terminates.
        NOTE(review): the loop never looks at x's own colour, so handling a
        red *x* is presumably the caller's responsibility -- confirm.
        """
        print("_restore_after_delete got called")
        while x.parent is not None:
            xp = x.parent
            if x.left_child:
                w = xp.right  # sibling; assumed non-None while x is short
                w_l = w.left
                w_r = w.right
                if _is_red(w):  # case 7
                    print("left case 7")
                    xp.color = RED
                    w.color = BLACK
                    self._left_rotate(xp)
                else:
                    xp_r, w_l_r, w_r_r = map(_is_red, (xp, w_l, w_r))
                    if not w_l_r and w_r_r:
                        if not xp_r:  # case 1
                            print("left case 1")
                            w_r.color = BLACK
                            self._left_rotate(xp)
                        else:  # case 2
                            print("left case 2")
                            xp.color = BLACK
                            w.color = RED
                            w_r.color = BLACK
                            self._left_rotate(xp)
                        # The rotation repaired the deficit in both cases.
                        x = self._root
                    elif w_l_r and not w_r_r:  # case 3
                        print("left case 3")
                        w.color = RED
                        w_l.color = BLACK
                        self._right_rotate(w)
                    elif not w_l_r and not w_r_r:  # case 4
                        print("left case 4")
                        w.color = RED
                        if xp.color == RED:
                            xp.color = BLACK
                            x = self._root
                        else:  # xp.color == BLACK
                            x = xp
                    else:  # w_l_r and w_r_r
                        if xp.color == BLACK:  # case 5
                            print("left case 5")
                            w_r.color = BLACK
                            self._left_rotate(xp)
                        else:  # case 6
                            print("left case 6")
                            xp.color = BLACK
                            w.color = RED
                            w_r.color = BLACK
                            self._left_rotate(xp)
                        # The rotation repaired the deficit in both cases.
                        x = self._root
            else:  # mirror image: swap left with right
                w = xp.left
                w_l = w.left
                w_r = w.right
                if _is_red(w):  # case 7
                    print("right case 7")
                    xp.color = RED
                    w.color = BLACK
                    self._right_rotate(xp)
                else:
                    xp_r, w_l_r, w_r_r = map(_is_red, (xp, w_l, w_r))
                    if not w_r_r and w_l_r:
                        if not xp_r:  # case 1
                            print("right case 1")
                            w_l.color = BLACK
                            self._right_rotate(xp)
                        else:  # case 2
                            print("right case 2")
                            xp.color = BLACK
                            w.color = RED
                            w_l.color = BLACK
                            self._right_rotate(xp)
                        x = self._root
                    elif w_r_r and not w_l_r:  # case 3
                        print("right case 3")
                        w.color = RED
                        w_r.color = BLACK
                        self._left_rotate(w)
                    elif not w_r_r and not w_l_r:  # case 4
                        print("right case 4")
                        w.color = RED
                        if xp.color == RED:
                            xp.color = BLACK
                            x = self._root
                        else:  # xp.color == BLACK
                            x = xp
                    else:  # w_r_r and w_l_r
                        if xp.color == BLACK:  # case 5
                            print("right case 5")
                            w_l.color = BLACK
                            self._right_rotate(xp)
                        else:  # case 6
                            print("right case 6")
                            xp.color = BLACK
                            w.color = RED
                            w_l.color = BLACK
                            self._right_rotate(xp)
                        x = self._root

    # Utility functions
    @property
    def first(self):
        """Node with the smallest key, or None for an empty tree."""
        return self._first_elem

    @property
    def last(self):
        """Node with the largest key, or None for an empty tree."""
        return self._last_elem

    def next(self, node):
        """Public wrapper around ``_next``."""
        return self._next(node)

    def prev(self, node):
        """Public wrapper around ``_prev``."""
        return self._prev(node)

    # Special functions
    def __len__(self):
        return self._count

    def __iter__(self):
        """Yield the nodes in ascending key order.

        The cursor is local to this generator, so nested or interleaved
        iterations over the same tree do not interfere with each other.
        """
        node = self._first(self._root)
        while node is not None:
            yield node
            node = self._next(node)

    def __reversed__(self):
        """Yield the nodes in descending key order (local cursor)."""
        node = self._last(self._root)
        while node is not None:
            yield node
            node = self._prev(node)

    def __contains__(self, key):
        return self._find(key)[0]

    def __repr__(self):
        """pre-order traversal of the tree is the representation."""
        if self._root is None:
            return ""
        rpr, ident, stack = [], "|", [(self._root, 0)]
        while stack:
            curr, ilvl = stack.pop()
            rpr.append(ident * ilvl + str(curr))
            # Push right before left so the left subtree is printed first.
            if curr.right is not None:
                stack.append((curr.right, ilvl + 1))
            if curr.left is not None:
                stack.append((curr.left, ilvl + 1))
        return "\n".join(rpr)

    # debugging functions
    def _loop_detect(self):
        """Return True if the root has a parent or a key is reachable twice."""
        if self._root is None:
            return False
        if self._root.parent is not None:
            return True
        stack = [self._root]
        keys = set([self._root.key])
        while stack:
            curr = stack.pop()
            for c in (curr.left, curr.right):
                if c is None:
                    continue
                if c.key in keys:
                    return True
                stack.append(c)
                keys.add(c.key)
        return False

    def _check_tree(self):
        """ We have to check the following rules:
        - rule 2: The root is black
        - rule 4: All children of a red node is black.
        - rule 5: The black height of every node is well defined
        It also asserts that the node count matches ``self._count``, that
        ``first``/``last`` hold the minimum/maximum keys, and that iteration
        yields the keys in sorted order.
        """
        print("The representation of the tree is:")
        print(repr(self))
        if self._root is not None:
            assert self._root.color == BLACK  # rule 2
            stack = [self._root]
        else:
            stack = []
        was = set()  # keys whose children have already been pushed
        bh = {}      # key -> black height of its subtree (None leaf counts 1)
        nn = lambda x: x is not None
        mi, ma = None, None
        gmi = lambda x, y: y if x is None else min(x, y)
        gma = lambda x, y: y if x is None else max(x, y)
        is_black = lambda x: x is None or x.color == BLACK
        while stack:
            curr = stack.pop()
            mi = gmi(mi, curr.key)
            ma = gma(ma, curr.key)
            l, r = curr.left, curr.right
            if curr.key in was:  # rule 5 (second visit: children are done)
                mod = 1 if curr.color == BLACK else 0
                if nn(l) and nn(r):
                    emsg = (
                        "The black height of %s should be the same as "
                        "%s while it is %s and that is %s")
                    emsg = emsg % (l.key, r.key, bh[l.key], bh[r.key])
                    assert bh[l.key] == bh[r.key], emsg
                    val = bh[l.key]
                elif nn(l):
                    emsg = "The black height of %s should be 1" % l.key
                    print("bh: %s" % (bh,))
                    assert bh[l.key] == 1, emsg
                    val = 1
                elif nn(r):
                    emsg = "The black height of %s should be 1" % r.key
                    print("bh: %s" % (bh,))
                    assert bh[r.key] == 1, emsg
                    val = 1
                else:
                    val = 1
                bh[curr.key] = mod + val
                continue
            was.add(curr.key)
            stack.append(curr)  # revisit after the children
            if curr.color == RED:  # rule 4
                assert all(map(is_black, (l, r)))
            if nn(r):
                stack.append(r)
            if nn(l):
                stack.append(l)
        print("bh: %s" % (bh,))
        emsg = (
            "The length of 'was' is: %s while the length of "
            "self._count is %s")
        emsg = emsg % (len(was), self._count)
        assert len(was) == self._count, emsg
        assert mi == getattr(self._first_elem, "key", None)
        assert ma == getattr(self._last_elem, "key", None)
        # Iterate this tree (previously referenced an undefined global).
        coll = [el.key for el in iter(self)]
        assert sorted(coll) == coll
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment