Last active
April 23, 2018 10:13
-
-
Save whiler/c92e3e1c4d1e809a24811ffe46a8a8db to your computer and use it in GitHub Desktop.
Binary Heap implementation in Python language
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
class BaseBinaryHeap(object):
    """Array-backed binary min-heap with a key -> position index.

    Nodes are arbitrary objects.  Two callables supplied at construction
    time describe them:

    * ``key(node)``   -- a hashable identity, unique per node in the heap;
    * ``value(node)`` -- the priority; the node with the smallest value is
      kept at the head.

    ``self.index`` maps every key to the node's current position in
    ``self.queue``, which lets ``get``, ``updateByKey`` and ``removeByKey``
    locate a node without scanning the queue.
    """

    def __init__(self, key=lambda x: x, value=lambda x: x):
        """Create an empty heap.

        Both ``key`` and ``value`` default to the identity function, so a
        heap of plain comparable, hashable items works out of the box.
        """
        self.queue = list()
        self.index = dict()
        self.key = key
        self.value = value

    def __len__(self):
        """Return the number of nodes currently stored."""
        return len(self.queue)

    def head(self, default=None):
        """Return the smallest node without removing it.

        ``default`` is returned when the heap is empty.
        """
        return self.queue[0] if self.queue else default

    def push(self, node):
        """Insert ``node``, or replace the node already stored under its key.

        Keys must stay unique because ``self.index`` holds one position per
        key.  Blindly appending a second node with the same key would leave
        an untracked duplicate in the queue and make a later ``pop`` fail
        with ``KeyError``, so an existing entry is overwritten in place and
        then moved to wherever its (possibly different) value belongs.

        Always returns True.
        """
        key = self.key(node)
        value = self.value(node)
        idx = self.index.get(key, -1)
        if idx == -1:
            # New key: append at the bottom and let it rise.
            self.queue.append(node)
            self.index[key] = self._bubbleUp(len(self.queue) - 1, node,
                                             key, value)
        else:
            # Known key: swap the node in place; its new value may need to
            # travel in either direction.
            self.queue[idx] = node
            idx = self._bubbleUp(idx, node, key, value)
            self.index[key] = self._sinkDown(idx, node, key, value)
        return True

    def pop(self, default=None):
        """Remove and return the smallest node.

        ``default`` is returned when the heap is empty.
        """
        if not self.queue:
            return default
        head = self.queue[0]
        last = self.queue.pop()
        # Forget the head before re-seating the last node so the index
        # entry written below can never be the one that gets removed.
        self.index.pop(self.key(head))
        if self.queue:
            # The head was not the only node: move the former last node to
            # the root and let it sink to its place.
            key = self.key(last)
            self.queue[0] = last
            self.index[key] = self._sinkDown(0, last, key, self.value(last))
        return head

    def updateByKey(self, key):
        """Restore heap order after the value of the node at ``key`` changed.

        Returns True when the key is present, False otherwise.
        """
        idx = self.index.get(key, -1)
        if idx == -1:
            return False
        node = self.queue[idx]
        value = self.value(node)
        # The new value may be smaller or larger than before; try both
        # directions (at most one of them actually moves the node).
        idx = self._bubbleUp(idx, node, key, value)
        self.index[key] = self._sinkDown(idx, node, key, value)
        return True

    def removeByKey(self, key):
        """Remove the node stored under ``key``.

        Returns True when a node was removed, False when the key is unknown.
        """
        idx = self.index.pop(key, -1)
        if idx == -1:
            return False
        last = self.queue.pop()
        if idx != len(self.queue):
            # The removed node was not the last one: fill the hole with the
            # former last node and restore heap order around it.
            self.queue[idx] = last
            lastKey = self.key(last)
            lastValue = self.value(last)
            idx = self._bubbleUp(idx, last, lastKey, lastValue)
            self.index[lastKey] = self._sinkDown(idx, last, lastKey,
                                                 lastValue)
        return True

    def keys(self):
        """Return a view of the keys of all stored nodes."""
        return self.index.keys()

    def get(self, key, default=None):
        """Return the node stored under ``key``, or ``default`` if absent."""
        idx = self.index.get(key, -1)
        return self.queue[idx] if idx != -1 else default

    def _bubbleUp(self, idx, node, key, value):
        """Move ``node`` from position ``idx`` toward the root.

        The node is swapped with its parent while its value is strictly
        smaller than the parent's.  The index entry of every displaced
        parent is updated here; the caller records the returned final
        position for ``node`` itself.  ``key`` is accepted for signature
        symmetry with ``_sinkDown`` and is not used.
        """
        while idx > 0:
            parentIdx = ((idx + 1) >> 1) - 1
            parent = self.queue[parentIdx]
            if value >= self.value(parent):
                break
            self.queue[parentIdx] = node
            self.queue[idx] = parent
            self.index[self.key(parent)] = idx
            idx = parentIdx
        return idx

    def _sinkDown(self, idx, node, key, value):
        """Move ``node`` from position ``idx`` toward the leaves.

        The node is swapped with its smallest child while that child's value
        is strictly smaller than the node's.  The index entry of every
        displaced child is updated here; the caller records the returned
        final position for ``node`` itself.  ``key`` is accepted for
        signature symmetry with ``_bubbleUp`` and is not used.
        """
        count = len(self.queue)
        while True:
            rightIdx = (idx + 1) << 1
            leftIdx = rightIdx - 1
            swap = None
            swapKey = None
            smallest = value  # smallest value seen among node and children
            if leftIdx < count:
                left = self.queue[leftIdx]
                leftValue = self.value(left)
                if leftValue < smallest:
                    swap = leftIdx
                    swapKey = self.key(left)
                    smallest = leftValue
            if rightIdx < count:
                right = self.queue[rightIdx]
                if self.value(right) < smallest:
                    swap = rightIdx
                    swapKey = self.key(right)
            if swap is None:
                break
            self.queue[idx] = self.queue[swap]
            self.queue[swap] = node
            self.index[swapKey] = idx
            idx = swap
        return idx
if __name__ == '__main__':
    # Self-test followed by a micro-benchmark against the stdlib heapq.
    import heapq
    import random
    import timeit

    # Inclusive upper bound of the random priorities used below.
    MAX_VALUE = 1 << 31

    # update
    # random.sample draws without replacement, so every generated key is
    # unique.  Independent randint draws could collide, and two nodes
    # sharing a key break the heap's key -> position index.
    values = random.sample(range(MAX_VALUE + 1), random.randint(1024, 4096))
    lis = [{'key': 'i-' + str(n), 'value': n} for n in values]
    heap = BaseBinaryHeap(key=lambda o: o['key'], value=lambda o: o['value'])
    for o in lis:
        heap.push(o)
    while lis:
        for key in heap.keys():
            assert key == heap.get(key)['key']
        # Change the priority of a random node; the dict is shared with
        # `lis`, so the reference list sees the new value as well.
        key = random.choice(list(heap.keys()))
        o = heap.get(key)
        o['value'] = random.randint(0, MAX_VALUE)
        heap.updateByKey(key)
        lis.sort(key=lambda o: o['value'])
        assert len(lis) == len(heap), (len(lis), len(heap))
        assert lis[0] == heap.head(), (lis[0], heap.head())
        a, b = lis.pop(0), heap.pop()
        assert a == b, (a, b)

    # remove
    values = random.sample(range(MAX_VALUE + 1), random.randint(1024, 4096))
    lis = [{'key': str(n), 'value': n} for n in values]
    heap = BaseBinaryHeap(key=lambda o: o['key'], value=lambda o: o['value'])
    for o in lis:
        heap.push(o)
    while lis:
        for key in heap.keys():
            assert key == heap.get(key)['key']
        lis.sort(key=lambda o: o['value'])
        assert len(lis) == len(heap), (len(lis), len(heap))
        assert lis[0] == heap.head(), (lis[0], heap.head())
        a, b = lis.pop(0), heap.pop()
        assert a == b, (a, b)
        if len(heap) and lis:
            # Drop a random node from both containers; the key is the
            # decimal form of the value, so the node can be rebuilt for
            # list.remove().
            key = random.choice(list(heap.keys()))
            heap.removeByKey(key)
            lis.remove({'key': key, 'value': int(key)})

    # benchmark
    COUNT = 1000
    # Unique values: BaseBinaryHeap() uses the item itself as its key.
    sample = random.sample(range(MAX_VALUE + 1), 10240)
    # The push-only statements build a fresh container on every iteration;
    # otherwise each round would push into whatever the previous rounds
    # left behind and the container would grow without bound.
    timer = timeit.Timer("""
lis = list()
list(map(lambda x: heapq.heappush(lis, x), sample))""",
                         globals={'heapq': heapq, 'sample': sample})
    print('heapq.heappush', timer.timeit(COUNT))
    # The push+pop statements drain the container, so sharing one instance
    # across iterations is harmless.
    timer = timeit.Timer("""
list(map(lambda x: heapq.heappush(lis, x), sample))
while lis:
    heapq.heappop(lis)""",
                         globals={'heapq': heapq, 'sample': sample,
                                  'lis': list()})
    print('heapq', timer.timeit(COUNT))
    timer = timeit.Timer("""
heap = BaseBinaryHeap()
list(map(lambda x: heap.push(x), sample))""",
                         globals={'BaseBinaryHeap': BaseBinaryHeap,
                                  'sample': sample})
    print('BaseBinaryHeap.push', timer.timeit(COUNT))
    timer = timeit.Timer("""
list(map(lambda x: heap.push(x), sample))
while len(heap):
    heap.pop()""",
                         globals={'heap': BaseBinaryHeap(), 'sample': sample})
    print('BaseBinaryHeap', timer.timeit(COUNT))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment