Skip to content

Instantly share code, notes, and snippets.

@whiler
Last active April 23, 2018 10:13
Show Gist options
  • Save whiler/c92e3e1c4d1e809a24811ffe46a8a8db to your computer and use it in GitHub Desktop.
Save whiler/c92e3e1c4d1e809a24811ffe46a8a8db to your computer and use it in GitHub Desktop.
Binary heap implementation in Python
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
class BaseBinaryHeap(object):
    """Binary min-heap whose nodes can also be addressed by key.

    Nodes are ordered by ``value(node)`` (smallest first).  ``key(node)`` is
    stored in a ``key -> position`` index so that a node anywhere in the heap
    can be re-ordered (``updateByKey``) or removed (``removeByKey``).

    Keys must be hashable and are expected to be unique.  Nodes sharing a key
    are still pushed and popped in correct value order, but the index holds a
    single entry per key, so ``get`` and the ``*ByKey`` methods can reach at
    most one of them.
    """

    def __init__(self, key=lambda x: x, value=lambda x: x):
        """Create an empty heap.

        :param key: callable mapping a node to its (hashable) index key.
        :param value: callable mapping a node to the value it is ordered by.
        """
        self.queue = list()   # heap-ordered array of nodes
        self.index = dict()   # key(node) -> position of that node in queue
        self.key = key
        self.value = value

    def __len__(self):
        """Return the number of nodes currently stored."""
        return len(self.queue)

    def head(self, default=None):
        """Return the smallest node without removing it, or *default*."""
        return self.queue[0] if self.queue else default

    def push(self, node):
        """Insert *node* and return ``True``."""
        key = self.key(node)
        self.queue.append(node)
        # The new node starts at the tail and rises to its final position.
        self.index[key] = self._bubbleUp(len(self.queue) - 1, node,
                                         key, self.value(node))
        return True

    def pop(self, default=None):
        """Remove and return the smallest node, or *default* when empty."""
        if not self.queue:
            return default
        head = self.queue[0]
        last = self.queue.pop()
        # Drop the head's index entry *before* re-indexing the moved node and
        # tolerate a missing entry: when several nodes share a key, the entry
        # may already be gone, or may be the one the moved node is about to
        # claim.  Doing it the other way round raised KeyError on a later pop.
        self.index.pop(self.key(head), None)
        if self.queue:
            # The former tail replaces the head and sinks to its position.
            key = self.key(last)
            self.queue[0] = last
            self.index[key] = self._sinkDown(0, last, key, self.value(last))
        return head

    def updateByKey(self, key):
        """Restore heap order after the value of the node at *key* changed.

        :return: ``True`` if the key is indexed, ``False`` otherwise.
        """
        idx = self.index.get(key, -1)
        if idx == -1:
            return False
        node = self.queue[idx]
        value = self.value(node)
        # The value may have moved in either direction; at most one of the two
        # passes actually relocates the node.
        idx = self._bubbleUp(idx, node, key, value)
        self.index[key] = self._sinkDown(idx, node, key, value)
        return True

    def removeByKey(self, key):
        """Remove the node indexed under *key*.

        :return: ``True`` if a node was removed, ``False`` if the key is not
                 indexed.
        """
        idx = self.index.pop(key, -1)
        if idx == -1:
            return False
        last = self.queue.pop()
        if idx != len(self.queue):
            # The removed node was not the tail: the tail fills the hole and
            # is then moved up or down to restore heap order.
            self.queue[idx] = last
            lastKey = self.key(last)
            lastValue = self.value(last)
            idx = self._bubbleUp(idx, last, lastKey, lastValue)
            self.index[lastKey] = self._sinkDown(idx, last, lastKey, lastValue)
        return True

    def keys(self):
        """Return the keys currently held by the index."""
        return self.index.keys()

    def get(self, key, default=None):
        """Return the node indexed under *key*, or *default*."""
        idx = self.index.get(key, -1)
        return self.queue[idx] if idx != -1 else default

    def _bubbleUp(self, idx, node, key, value):
        """Move *node* from position *idx* towards the root.

        Swaps the node with its parent while the node's *value* is strictly
        smaller, re-indexing every displaced parent.  The index entry of
        *node* itself is left to the caller.

        :return: the final position of *node*.
        """
        while idx > 0:
            parentIdx = (idx - 1) >> 1
            parent = self.queue[parentIdx]
            if value >= self.value(parent):
                break
            self.queue[parentIdx] = node
            self.queue[idx] = parent
            self.index[self.key(parent)] = idx
            idx = parentIdx
        return idx

    def _sinkDown(self, idx, node, key, value):
        """Move *node* from position *idx* towards the leaves.

        Swaps the node with its smallest child while that child's value is
        strictly smaller than *value*, re-indexing every displaced child.  The
        index entry of *node* itself is left to the caller.

        :return: the final position of *node*.
        """
        count = len(self.queue)
        while True:
            rightIdx = (idx + 1) << 1
            leftIdx = rightIdx - 1
            swap = None
            smallest = value
            if leftIdx < count:
                leftValue = self.value(self.queue[leftIdx])
                if leftValue < smallest:
                    swap, smallest = leftIdx, leftValue
            # The right child only wins if it beats the best candidate so far.
            if rightIdx < count and self.value(self.queue[rightIdx]) < smallest:
                swap = rightIdx
            if swap is None:
                break
            child = self.queue[swap]
            self.queue[idx] = child
            self.queue[swap] = node
            self.index[self.key(child)] = idx
            idx = swap
        return idx
if __name__ == '__main__':
    # Randomised self-test of BaseBinaryHeap against a sorted list, followed
    # by a small benchmark against the standard heapq module.
    import heapq
    import random
    import timeit

    # Values are drawn without replacement: the heap's index is keyed by
    # strings derived from these values, and duplicate keys cannot be tracked.
    VALUE_RANGE = range((1 << 31) + 1)

    # update: change the value of a random node, then check the minimum.
    lis = list()
    heap = BaseBinaryHeap(key=lambda o: o['key'], value=lambda o: o['value'])
    for n in random.sample(VALUE_RANGE, random.randint(1024, 4096)):
        o = {'key': 'i-' + str(n), 'value': n}
        lis.append(o)
        heap.push(o)
    while lis:
        for key in heap.keys():
            assert key == heap.get(key)['key']
        key = random.choice(list(heap.keys()))
        o = heap.get(key)
        # `o` is the same dict object held by `lis`, so both views change.
        o['value'] = random.randint(0, 1 << 31)
        heap.updateByKey(key)
        lis.sort(key=lambda o: o['value'])
        assert len(lis) == len(heap), (len(lis), len(heap))
        # The re-drawn value may tie with another node's value, in which case
        # either node is a valid minimum: compare values, not node identity.
        assert lis[0]['value'] == heap.head()['value'], (lis[0], heap.head())
        b = heap.pop()
        assert b['value'] == lis[0]['value'], (lis[0], b)
        lis.remove(b)

    # remove: pop the minimum, then remove a random node by key.
    lis = list()
    heap = BaseBinaryHeap(key=lambda o: o['key'], value=lambda o: o['value'])
    for n in random.sample(VALUE_RANGE, random.randint(1024, 4096)):
        o = {'key': str(n), 'value': n}
        lis.append(o)
        heap.push(o)
    while lis:
        for key in heap.keys():
            assert key == heap.get(key)['key']
        lis.sort(key=lambda o: o['value'])
        assert len(lis) == len(heap), (len(lis), len(heap))
        assert lis[0] == heap.head(), (lis[0], heap.head())
        a, b = lis.pop(0), heap.pop()
        assert a == b, (a, b)
        if len(heap) and lis:
            key = random.choice(list(heap.keys()))
            heap.removeByKey(key)
            lis.remove({'key': key, 'value': int(key)})

    # benchmark
    # NOTE: each Timer re-runs its statement COUNT times against the same
    # container passed through `globals`, so the push-only benchmarks push
    # into a container that keeps growing from run to run.
    COUNT = 1000
    sample = random.sample(VALUE_RANGE, 10240)

    timer = timeit.Timer('list(map(lambda x: heapq.heappush(lis, x), sample))',
                         globals={'heapq': heapq, 'sample': sample,
                                  'lis': list()})
    print('heapq.heappush', timer.timeit(COUNT))

    timer = timeit.Timer('list(map(lambda x: heapq.heappush(lis, x), sample))\n'
                         'while lis:\n'
                         '    heapq.heappop(lis)',
                         globals={'heapq': heapq, 'sample': sample,
                                  'lis': list()})
    print('heapq', timer.timeit(COUNT))

    timer = timeit.Timer('list(map(lambda x: heap.push(x), sample))',
                         globals={'heap': BaseBinaryHeap(), 'sample': sample})
    print('BaseBinaryHeap.push', timer.timeit(COUNT))

    timer = timeit.Timer('list(map(lambda x: heap.push(x), sample))\n'
                         'while len(heap):\n'
                         '    heap.pop()',
                         globals={'heap': BaseBinaryHeap(), 'sample': sample})
    print('BaseBinaryHeap', timer.timeit(COUNT))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment