Created
July 14, 2024 20:02
-
-
Save proger/4c4d4fd0eebce88388f976087f27da76 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Randomized Binary Search Trees | |
https://www.cs.upc.edu/~conrado/research/papers/jacm-mr98.pdf | |
""" | |
import math | |
import random | |
from collections import Counter | |
class root:
    """Node of a binary search tree that caches its subtree size.

    Attributes (stored in ``__slots__``):
    x -- the key held by this node
    left, right -- child subtrees, ``None`` when absent
    size -- number of nodes in the subtree rooted here, computed once at
            construction from the children's cached sizes
    """

    __slots__ = 'x', 'left', 'right', 'size'

    def __init__(self, x, left=None, right=None):
        self.x = x
        self.left = left
        self.right = right
        # Each child already knows its own size, so just add them up.
        subtree_total = 1
        for child in (left, right):
            if child:
                subtree_total += child.size
        self.size = subtree_total

    def __repr__(self):
        # Missing children are rendered as "_".
        left_text = self.left or "_"
        right_text = self.right or "_"
        return f'({self.x} {left_text} {right_text})'
def depth(self):
    """Return the number of nodes on the longest path from ``self`` to a leaf.

    An empty (falsy) tree has depth 0, so a single node has depth 1.
    """
    if self:
        deepest_child = max(depth(self.left), depth(self.right))
        return deepest_child + 1
    return 0
def split(x, tree):
    """Partition the search tree ``tree`` around the key ``x``.

    Returns a pair ``(smaller, rest)``. For a valid search tree, ``smaller``
    holds the keys strictly below ``x`` and ``rest`` holds the keys that are
    greater than or equal to ``x``. Fresh nodes are built along the search
    path; the nodes of the input tree are not modified.
    """
    if not tree:
        return None, None
    if x <= tree.x:
        # This node and its right subtree belong to the ">= x" side; only
        # the left subtree can still contain keys below x.
        smaller, carried = split(x, tree.left)
        return smaller, root(tree.x, left=carried, right=tree.right)
    # This node and its left subtree belong to the "< x" side.
    carried, rest = split(x, tree.right)
    return root(tree.x, left=tree.left, right=carried), rest
def insert(x, tree):
    """Return a new tree containing the keys of ``tree`` plus ``x``.

    With probability 1 / (tree.size + 1) the new key becomes the root of the
    current subtree (the old subtree is split around it); otherwise the
    insertion recurses into the right child when ``tree.x <= x`` and into
    the left child when ``tree.x > x``. Nodes on the path are rebuilt
    rather than mutated.
    """
    if not tree:
        return root(x)
    # The new key is one of size + 1 equally likely candidates for root.
    if random.random() < 1 / (tree.size + 1):
        smaller, rest = split(x, tree)
        return root(x, left=smaller, right=rest)
    if tree.x <= x:
        grown = insert(x, tree.right)
        return root(tree.x, left=tree.left, right=grown)
    # tree.x > x
    grown = insert(x, tree.left)
    return root(tree.x, left=grown, right=tree.right)
def join(left, right):
    """Merge two trees into one and return the merged root.

    If either argument is empty the other is returned as is. Otherwise the
    root of the result is taken from ``left`` with probability
    left.size / (left.size + right.size) and from ``right`` otherwise.

    NOTE(review): the ordering between the two trees is not checked here;
    presumably every key in ``left`` must sort before every key in ``right``
    (as when ``delete`` passes a node's two children).
    """
    if not left:
        return right
    if not right:
        return left
    total = left.size + right.size
    pick_left = random.random() < left.size / total
    if pick_left:
        merged = join(left.right, right)
        return root(left.x, left=left.left, right=merged)
    merged = join(left, right.left)
    return root(right.x, left=merged, right=right.right)
def delete(x, tree):
    """Return a tree without the first node found whose key equals ``x``.

    The search descends left when ``x < tree.x`` and right otherwise; the
    matching node is replaced by the join of its two children. Nodes on the
    search path are rebuilt even when ``x`` is not present. An empty tree
    yields ``None``.
    """
    if not tree:
        return None
    key = tree.x
    if key == x:
        return join(tree.left, tree.right)
    if x < key:
        return root(key, left=delete(x, tree.left), right=tree.right)
    return root(key, left=tree.left, right=delete(x, tree.right))
#
# Simulate many trees built from the same input and see how deep they get.
#
def go(xs, tree=None):
    """Insert every element of ``xs`` into ``tree`` and return the result.

    xs -- iterable of keys, inserted in iteration order
    tree -- tree to start from; ``None`` (the default) means an empty tree
    """
    # Start from the caller's tree: the parameter must not be reset to None
    # here, or a supplied starting tree would be silently discarded.
    for x in xs:
        tree = insert(x, tree)
    return tree
random.seed(42)  # fixed seed so the printed distribution is reproducible
N = 1000  # number of independent trials
S = 32  # number of keys inserted per trial
# Build N trees from the same sorted input 1..S and tally their depths.
counter = Counter(
    depth(go(range(1, S + 1))) for _ in range(N)
)
print('depth distribution after', N, 'trials:')
for d, c in sorted(counter.items()):
    # One bar character per percentage point, followed by the fraction.
    print(d, '|' * int(c * 100 / N), c / N)
# depth() counts nodes on the longest path (a single node has depth 1), so
# the shallowest possible tree with S nodes has depth ceil(log2(S + 1)),
# not log2(S).
print('perfect depth:', math.ceil(math.log2(S + 1)))
xs = list(range(1, S + 1))
t = go(xs)
print(f'after inserting {S} elements', t)
for x in xs[::2]:  # xs is 1..S, so every other element from index 0 is odd
    t = delete(x, t)
print('after deleting odds:', t)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment