Skip to content

Instantly share code, notes, and snippets.

@jakab922
Created March 2, 2017 17:11
Show Gist options
  • Save jakab922/4be804fe76724bddadb7d6acc5da16d3 to your computer and use it in GitHub Desktop.
Save jakab922/4be804fe76724bddadb7d6acc5da16d3 to your computer and use it in GitHub Desktop.
Red-black tree with the insert operation implemented. It is an order of magnitude slower than the STL set, so there is room for improvement.
#!python
#cython: boundscheck=False, wraparound=False, initializedcheck=False
from cpython cimport bool
# Side of its parent on which a node hangs.
LEFT = 0
RIGHT = 1
# Node colours.
RED = 0
BLACK = 1
# One-letter labels used when printing a node; indexed by colour and by side.
cocodes = ["R", "B"]
chcodes = ["l", "r"]
cdef void _set_relation(Node child, Node parent, int rtype):
    """Make ``child`` the ``rtype`` (LEFT or RIGHT) child of ``parent``.

    Either argument may be None. With a None ``child`` the selected slot of
    ``parent`` is cleared. With a None ``parent`` only ``child.parent`` is
    reset. Any ``rtype`` other than LEFT selects the right slot.
    """
    if parent is not None:
        if rtype != LEFT:
            parent.right = child
        else:
            parent.left = child
    if child is not None:
        child.parent = parent
cdef bint _left_child(Node node):
    """Return True if ``node`` is the left child of its parent.

    A node without a parent (the root) is never a left child.
    """
    # Declared ``bint``: the result is a truth value, not a Node.
    # ``is`` checks identity directly; Node has no __eq__, so this is what
    # ``==`` meant, without the rich-comparison dispatch.
    return node.parent is not None and node.parent.left is node
cdef bint _right_child(Node node):
    """Return True if ``node`` is the right child of its parent.

    A node without a parent (the root) is never a right child.
    """
    # Declared ``bint``: the result is a truth value, not a Node.
    # ``is`` checks identity directly; Node has no __eq__, so this is what
    # ``==`` meant, without the rich-comparison dispatch.
    return node.parent is not None and node.parent.right is node
cdef class Node:
    """One node of the red-black tree.

    Attributes (all readable/writable from Python via ``cdef public``):
      key    -- the stored key
      color  -- RED or BLACK
      left, right, parent -- neighbouring nodes, or None
    """
    cdef public object key
    cdef public int color
    cdef public Node left
    cdef public Node right
    cdef public Node parent

    def __init__(self, object key, int color,
                 Node left=None, Node right=None, Node parent=None):
        self.key = key
        self.color = color
        self.left = left
        self.right = right
        # _set_relation also points each given child back at this node.
        if left is not None:
            _set_relation(left, self, LEFT)
        if right is not None:
            _set_relation(right, self, RIGHT)
        self.parent = parent
        if parent is not None:
            # Hang this node under ``parent`` on the side its key sorts to.
            if key < parent.key:
                _set_relation(self, parent, LEFT)
            else:
                _set_relation(self, parent, RIGHT)

    def __str__(self):
        # Anything that is not a left child (including the root) prints "r".
        side_code = chcodes[0] if _left_child(self) else chcodes[1]
        return "%s(%s) - %s" % (cocodes[self.color], side_code, self.key)
cdef class RBT:
    """Red-black tree holding unique keys; insertion is the only update.

    Keys are ordered with ``<`` only. Two keys are treated as equal when
    neither is less than the other.
    """
    cdef int _count        # number of distinct keys inserted
    cdef Node _root
    cdef Node _first_elem  # node holding the smallest key
    cdef Node _last_elem   # node holding the largest key
    cdef Node _iter_state  # NOTE(review): not used by any method shown here

    def __init__(self):
        self._count = 0
        self._root = None
        self._first_elem = None
        self._last_elem = None
        self._iter_state = None

    cdef _find(self, object key):
        """Search for ``key``.

        Returns ``(True, node)`` if a node holding ``key`` exists. Otherwise
        returns ``(False, parent)``, where ``parent`` is the node a new key
        would be attached under (None for an empty tree).
        """
        cdef Node curr = self._root
        cdef Node par = None
        while curr is not None:
            # Only ``<`` is used. Node.__init__ picks the attach side with
            # the same operator, so search and attach always agree. (The
            # ``cmp`` builtin used previously does not exist on Python 3.)
            if key < curr.key:
                par, curr = curr, curr.left
            elif curr.key < key:
                par, curr = curr, curr.right
            else:
                return True, curr
        return False, par

    cdef void _left_rotate(self, Node cnode):
        """Rotate left around ``cnode``: its right child takes its place.

        ``cnode.right`` must not be None (it is dereferenced).
        """
        cdef Node onode = cnode.right
        cdef Node top = cnode.parent
        # onode's left subtree becomes cnode's right subtree.
        _set_relation(onode.left, cnode, RIGHT)
        if top is None:
            onode.parent = None
            self._root = onode
        elif top.left is cnode:
            _set_relation(onode, top, LEFT)
        else:
            _set_relation(onode, top, RIGHT)
        _set_relation(cnode, onode, LEFT)

    cdef void _right_rotate(self, Node cnode):
        """Rotate right around ``cnode``: its left child takes its place.

        ``cnode.left`` must not be None (it is dereferenced).
        """
        cdef Node onode = cnode.left
        cdef Node top = cnode.parent
        # onode's right subtree becomes cnode's left subtree.
        _set_relation(onode.right, cnode, LEFT)
        if top is None:
            onode.parent = None
            self._root = onode
        elif top.left is cnode:
            _set_relation(onode, top, LEFT)
        else:
            _set_relation(onode, top, RIGHT)
        _set_relation(cnode, onode, RIGHT)

    def add(self, object key):
        """Insert ``key`` unless an equal key is already stored."""
        cdef bint is_in
        cdef Node parent
        is_in, parent = self._find(key)
        if is_in:
            return
        self._count += 1
        # Node.__init__ links the new red node under ``parent``.
        self._restore_after_insert(Node(key, RED, None, None, parent))

    cdef void _restore_after_insert(self, Node cnode):
        """Fix up after ``cnode`` (a freshly attached red node) was inserted.

        First updates the min/max trackers. Then restores the red-black
        colouring: recolour while the uncle is red, otherwise rotate once or
        twice around the grandparent.
        """
        cdef Node curr = cnode
        cdef Node par, gpar, unc

        par = curr.parent
        if par is None:
            # First key: it is both the minimum and the maximum.
            self._first_elem = self._last_elem = curr
        else:
            # A new minimum can only be a left child of the old minimum, and
            # a new maximum only a right child of the old maximum.
            if par is self._first_elem and par.left is curr:
                self._first_elem = curr
            if par is self._last_elem and par.right is curr:
                self._last_elem = curr

        while curr.parent is not None and curr.parent.color == RED:
            par = curr.parent
            # Relies on the root being black: a red parent always has a
            # parent of its own.
            gpar = par.parent
            if par is gpar.left:
                unc = gpar.right
                if unc is not None and unc.color == RED:
                    par.color = BLACK
                    unc.color = BLACK
                    gpar.color = RED
                    curr = gpar
                else:
                    if curr is par.right:
                        # Inner grandchild: rotate it to the outside first.
                        curr = par
                        self._left_rotate(curr)
                    curr.parent.color = BLACK
                    curr.parent.parent.color = RED
                    self._right_rotate(curr.parent.parent)
            else:
                unc = gpar.left
                if unc is not None and unc.color == RED:
                    par.color = BLACK
                    unc.color = BLACK
                    gpar.color = RED
                    curr = gpar
                else:
                    if curr is par.left:
                        # Inner grandchild: rotate it to the outside first.
                        curr = par
                        self._right_rotate(curr)
                    curr.parent.color = BLACK
                    curr.parent.parent.color = RED
                    self._left_rotate(curr.parent.parent)

        if curr.parent is None:
            # curr ended at the top: it is the root, and the root is black.
            curr.color = BLACK
            self._root = curr

    def __repr__(self):
        """Pre-order listing of the tree, one node per line.

        Each line is prefixed by one "|" per level of depth. An empty tree
        gives "".
        """
        if self._root is None:
            return ""
        lines, indent, stack = [], "|", [(self._root, 0)]
        while stack:
            curr, depth = stack.pop()
            lines.append(indent * depth + str(curr))
            # Push right before left so the left subtree is emitted first.
            if curr.right is not None:
                stack.append((curr.right, depth + 1))
            if curr.left is not None:
                stack.append((curr.left, depth + 1))
        return "\n".join(lines)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment