Created
March 28, 2018 03:01
-
-
Save ajtritt/4785109cb5f08de55bbc92a01d2f2ca6 to your computer and use it in GitHub Desktop.
An interval tree that only supports insertion and point-query
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Node colours used for red-black balancing.
BLACK = 0
RED = 1
class Interval(object):
    """Closed interval ``[start, end]`` that also serves as a binary-tree node.

    ``start`` and ``end`` are read-only.  ``max`` caches an upper bound on the
    ``end`` values of this node and its children, while ``left``, ``right``,
    ``parent`` and ``color`` hold the tree links and the balancing colour.
    """

    def __init__(self, start, end):
        """Create a detached node covering ``[start, end]``."""
        self.__start = start
        self.__end = end
        self.__max = end
        self.__left = None
        self.__right = None
        # for balancing
        self.__parent = None
        self.__color = BLACK

    def __repr__(self):
        return "[%s, %s]" % (self.start, self.end)

    def __update_max(self):
        """Recompute ``max`` from this node's own ``end`` and its children.

        Starting from ``end`` (rather than the previously cached ``max``)
        lets the bound shrink again when a child is replaced or removed.
        """
        best = self.__end
        for child in (self.__left, self.__right):
            if child is not None and child.max > best:
                best = child.max
        self.__max = best

    @property
    def start(self):
        """Lower endpoint of the interval."""
        return self.__start

    @property
    def end(self):
        """Upper endpoint of the interval."""
        return self.__end

    @property
    def max(self):
        """Cached upper bound on ``end`` over this node and its children."""
        return self.__max

    @max.setter
    def max(self, val):
        self.__max = val

    @property
    def left(self):
        """Left child, or ``None`` if never assigned."""
        return self.__left

    @left.setter
    def left(self, val):
        # Assigning a child also links it back to this node and refreshes
        # the cached maximum; ``None`` simply clears the link.
        self.__left = val
        if val is not None:
            val.parent = self
        self.__update_max()

    @property
    def right(self):
        """Right child, or ``None`` if never assigned."""
        return self.__right

    @right.setter
    def right(self, val):
        # Mirror image of the ``left`` setter.
        self.__right = val
        if val is not None:
            val.parent = self
        self.__update_max()

    @property
    def parent(self):
        """Parent node, or ``None`` for a detached node or the root."""
        return self.__parent

    @parent.setter
    def parent(self, val):
        self.__parent = val

    @property
    def color(self):
        """Balancing colour (``BLACK`` or ``RED``)."""
        return self.__color

    @color.setter
    def color(self, val):
        self.__color = val

    def compare(self, other):
        """Order two intervals by ``start``, breaking ties on ``end``.

        Returns -1 if ``self`` sorts before ``other``, 1 if it sorts after,
        and 0 when both endpoints are equal.
        """
        if self.start != other.start:
            return -1 if self.start < other.start else 1
        if self.end != other.end:
            return -1 if self.end < other.end else 1
        return 0
# Shared sentinel node: IntervalTree assigns it as the left/right child of
# every freshly inserted node in place of a missing subtree.
LEAF = Interval(-1, -1)
class IntervalTree(object):
    """Red-black tree of closed intervals with insertion and point queries.

    Nodes are ``Interval`` objects ordered by ``Interval.compare``.  Each
    node's ``max`` is an upper bound on the ``end`` values in its subtree,
    which ``query`` uses to skip subtrees that cannot contain the point.
    Missing children of inserted nodes are the module-level ``LEAF``
    sentinel; ``None`` children are tolerated as well.
    """

    def __init__(self):
        self.__root = None

    @property
    def root(self):
        """Root node of the tree, or ``None`` while the tree is empty."""
        return self.__root

    @staticmethod
    def _is_nil(node):
        """Return True if ``node`` stands for an empty subtree."""
        return node is None or node is LEAF

    @classmethod
    def _refresh_max(cls, node):
        """Set ``node.max`` from its own ``end`` and its real children."""
        best = node.end
        for child in (node.left, node.right):
            if not cls._is_nil(child) and child.max > best:
                best = child.max
        node.max = best

    def insert_old(self, new_node, tmp=None):
        """Attach ``new_node`` without rebalancing (legacy insertion path).

        ``tmp`` is the node to start descending from and defaults to the
        root.  On an empty tree ``new_node`` becomes the root.  Returns
        ``new_node``.  No recolouring or rotation is performed here.
        """
        if tmp is None:
            tmp = self.__root
        if tmp is None:
            self.__root = new_node
            return new_node
        while True:
            # Keep the cached maximum valid along the descent path.
            if new_node.end > tmp.max:
                tmp.max = new_node.end
            if tmp.compare(new_node) <= 0:
                if self._is_nil(tmp.right):
                    tmp.right = new_node
                    break
                tmp = tmp.right
            else:
                if self._is_nil(tmp.left):
                    tmp.left = new_node
                    break
                tmp = tmp.left
        return new_node

    @classmethod
    def grandparent(cls, node):
        """Return the parent of ``node``'s parent, or ``None`` if absent."""
        p = node.parent
        if p is None:
            return None
        return p.parent

    @classmethod
    def sibling(cls, node):
        """Return the other child of ``node``'s parent (``None`` for a root)."""
        p = node.parent
        if p is None:
            return None
        if p.left is node:
            return p.right
        return p.left

    @classmethod
    def uncle(cls, node):
        """Return the sibling of ``node``'s parent, or ``None`` without a grandparent."""
        if cls.grandparent(node) is None:
            return None
        return cls.sibling(node.parent)

    @classmethod
    def swap_child(cls, parent, child, new_child):
        """Make ``new_child`` occupy the slot ``child`` held under ``parent``.

        When ``parent`` is ``None`` (``child`` was the root) ``new_child``
        is simply detached so that it becomes the new root.
        """
        if parent is None:
            new_child.parent = None
        elif parent.left is child:
            parent.left = new_child
        else:
            parent.right = new_child

    @classmethod
    def rotate_left(cls, node):
        """Rotate ``node`` down to the left; its right child takes its place."""
        old_parent = node.parent
        nnew = node.right
        assert not cls._is_nil(nnew)
        inner = nnew.left
        node.right = LEAF if inner is None else inner
        nnew.left = node
        # Children changed: rebuild the cached maxima bottom-up before
        # re-linking under the old parent.
        cls._refresh_max(node)
        cls._refresh_max(nnew)
        cls.swap_child(old_parent, node, nnew)

    @classmethod
    def rotate_right(cls, node):
        """Rotate ``node`` down to the right; its left child takes its place."""
        old_parent = node.parent
        nnew = node.left
        assert not cls._is_nil(nnew)
        inner = nnew.right
        node.left = LEAF if inner is None else inner
        nnew.right = node
        cls._refresh_max(node)
        cls._refresh_max(nnew)
        cls.swap_child(old_parent, node, nnew)

    @classmethod
    def insert_recurse(cls, root, n):
        """Attach ``n`` as a red leaf below ``root`` in sorted position.

        Descends from ``root`` using ``Interval.compare`` (nodes that sort
        after ``n`` send it left, all others right), raising ``max`` on
        every visited node.  With an empty ``root`` the node is left
        detached (``parent`` is ``None``).
        """
        parent = None
        go_left = False
        cur = root
        while not cls._is_nil(cur):
            parent = cur
            if n.end > cur.max:
                cur.max = n.end
            go_left = cur.compare(n) > 0
            cur = cur.left if go_left else cur.right
        n.parent = parent
        n.left = LEAF
        n.right = LEAF
        n.color = RED
        # The sentinel must not influence the cached maximum.
        cls._refresh_max(n)
        if parent is not None:
            if go_left:
                parent.left = n
            else:
                parent.right = n

    @classmethod
    def insert_repair(cls, n):
        """Restore the red-black colouring after ``n`` was inserted as red."""
        p = n.parent
        if p is None:
            # n is the root: the root is always black.
            n.color = BLACK
            return
        if p.color == BLACK:
            return
        g = p.parent
        if g is None:
            # Red parent without a grandparent: blacken the root.
            p.color = BLACK
            return
        u = cls.uncle(n)
        if u is not None and u.color == RED:
            # Red uncle: push the blackness down and continue at g.
            p.color = BLACK
            u.color = BLACK
            g.color = RED
            cls.insert_repair(g)
            return
        # Black uncle.  First straighten an inner grandchild so that n, p
        # and g lie on one line ...
        if n is p.right and p is g.left:
            cls.rotate_left(p)
            n = p
        elif n is p.left and p is g.right:
            cls.rotate_right(p)
            n = p
        # ... then rotate the grandparent over the parent and recolour.
        p = n.parent
        g = p.parent
        if n is p.left:
            cls.rotate_right(g)
        else:
            cls.rotate_left(g)
        p.color = BLACK
        g.color = RED

    def insert(self, start, end):
        """Insert the interval ``[start, end]`` and return the tree's root."""
        n = Interval(start, end)
        self.insert_recurse(self.root, n)
        self.insert_repair(n)
        # Rotations may have changed the root; find it by walking upwards.
        top = n
        while top.parent is not None:
            top = top.parent
        self.__root = top
        return self.__root

    def query(self, point):
        """Return a list of all stored intervals containing ``point``.

        Endpoints are inclusive.  An empty tree yields an empty list.
        """
        ret = []
        self.query_recurse(point, self.root, ret)
        return ret

    def query_recurse(self, point, root, ret):
        """Append to ``ret`` every interval under ``root`` containing ``point``.

        Nodes are reported in pre-order (node, left subtree, right subtree).
        """
        if self._is_nil(root):
            return
        if root.max < point:
            # No interval in this subtree reaches as far as the point.
            return
        if root.start <= point <= root.end:
            ret.append(root)
        self.query_recurse(point, root.left, ret)
        # Everything to the right starts at or after root.start, so it can
        # only contain the point if root.start does not lie beyond it.
        if point >= root.start:
            self.query_recurse(point, root.right, ret)

    def __repr__(self):
        l = []
        self.walk(self.root, l)
        return ", ".join(repr(i) for i in l)

    def walk(self, root, l):
        """Append the nodes under ``root`` to ``l`` in sorted (in-order) order."""
        if self._is_nil(root):
            return
        self.walk(root.left, l)
        l.append(root)
        self.walk(root.right, l)
# Small demonstration: build a two-interval tree and run point queries.
t = IntervalTree()
t.insert(2, 10)
t.insert(6, 15)
print(t)
print(t.query(8))
print(t.query(4))
print(t.query(14))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
This is a red-black tree keyed on the lower end (start) of each interval. Additionally, each node stores the maximum interval end found in the subtree rooted at that node. This allows for O(log(n) + k) point-query lookups, where k is the number of intervals that contain the query point.