Created June 15, 2013 00:27
-
-
Save almarklein/5786244 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import time

import numpy as np
from skimage.graph.heap import BinaryHeap as BinaryHeap_sk

# NOTE(review): `nb` (numba) is used throughout this file but is not imported
# in the visible source -- confirm the missing `import numba as nb`.
# numba scalar types used by BinaryHeap: VALUE_T for heap values (float64,
# matching the `_values` array) and REFERENCE_T for the references stored
# alongside them (int32, matching the `_references` array).
VALUE_T = nb.float64
REFERENCE_T = nb.int32
@nb.jit  # jit is slightly faster, compile time for autojit is ~2s
class BinaryHeap:
    """Binary min-tree of float64 values, each paired with an int32 reference.

    Values live in a flat array laid out as a complete binary tree: level
    ``L`` starts at absolute index ``2**L - 1`` and holds ``2**L`` entries.
    Pushed values are stored in the last level (``self.levels``), and unused
    slots hold ``inf``. Entries in levels 1 .. levels-1 hold the minimum of
    their two children. Index 0 (level 0) is not written by the update
    methods. The reference pushed with a value is stored in ``_references``
    at the value's position relative to the start of the last level.

    NOTE(review): this uses the old numba ``@jit``-on-a-class API with
    per-method signature decorators. It presumably needs the numba release
    it was written against -- confirm before running on a current numba.
    """

    @nb.void(nb.int32)
    def __init__(self, initial_capacity=128):
        """Create an empty heap whose last level holds at least
        ``initial_capacity`` values. ``push`` grows the heap by a whole
        level when it is full."""
        # smallest number of levels whose last level fits the capacity
        levels = 0
        while 2**levels < initial_capacity:
            levels += 1
        self.min_levels = nb.int32(levels)
        self.levels = nb.int32(levels)
        # the heap starts out empty
        self.count = nb.int32(0)
        self._popped_ref = REFERENCE_T(0)
        # 2*number value slots cover every level of the tree (the last level
        # has `number` entries); one reference slot per last-level entry.
        number = 2**self.levels
        self._values = np.empty((2*number,), dtype=np.float64)
        self._references = np.empty((number,), dtype=np.int32)
        # np.empty leaves both arrays uninitialised; reset() fills them
        self.reset()

    @nb.void()
    def reset(self):
        """reset()
        Reset the heap to default, empty state.

        Every value slot is set to ``inf``, every reference to ``-1``, and
        the count to zero. The number of levels (capacity) is unchanged.
        """
        number = 2**self.levels
        for i in range(number*2):
            self._values[i] = np.inf
        # same initial reference value as used by _add_or_remove_level
        for i in range(number):
            self._references[i] = -1
        # Clear the count too. Otherwise push() would keep appending after
        # the discarded entries instead of starting from slot 0.
        self.count = nb.int32(0)

    @nb.void(nb.int32)
    def _add_or_remove_level(self, add_or_remove):
        """Change the number of levels by ``add_or_remove``.

        A positive value adds levels and a negative one removes them. New
        arrays are allocated and filled with ``inf`` / ``-1``. If the heap
        holds values, the last level and the references are copied over,
        truncated to the smaller of the old and new level widths. The whole
        tree is then rebuilt with ``_update()``.
        """
        new_levels = self.levels + add_or_remove
        # allocate and initialise the new arrays
        number = 2**new_levels
        values = np.empty((2*number,), dtype=np.float64)
        references = np.empty((number,), dtype=np.int32)
        for i in range(number*2):
            values[i] = np.inf
        for i in range(number):
            references[i] = -1
        # copy the last level and the references
        old_values = self._values
        old_references = self._references
        if self.count:
            i1 = 2**new_levels-1  # start of the new last level
            i2 = 2**self.levels-1  # start of the old last level
            # NOTE(review): when shrinking, entries beyond the new width are
            # dropped; callers presumably shrink only when count fits.
            n = min(2**new_levels, 2**self.levels)
            for i in range(n):
                values[i1+i] = old_values[i2+i]
            for i in range(n):
                references[i] = old_references[i]
        # make the new arrays current
        self._values = values
        self._references = references
        # all internal levels are stale: rebuild them from the last level
        self.levels = new_levels
        self._update()

    @nb.void()
    def _update(self):
        """Update the full tree from the bottom up.

        This should be done after resizing. For each level from
        ``self.levels`` down to 2, every sibling pair writes its minimum into
        its parent. Index 0 is not written.
        """
        values = self._values
        for level in range(self.levels, 1, -1):
            i0 = (1 << level) - 1  # 2**level - 1 = start of this level
            n = i0 + 1  # 2**level = number of entries in this level
            for i in range(i0, i0+n, 2):
                ii = (i-1)//2  # parent index
                if values[i] < values[i+1]:
                    values[ii] = values[i]
                else:
                    values[ii] = values[i+1]

    @nb.void(nb.int32)
    def _update_one(self, i):
        """Update the tree on the path from absolute index ``i`` upward.

        Walks from the sibling pair that contains ``i`` up to level 1. At
        each step the pair's minimum is written into its parent.
        """
        values = self._values
        # Sibling pairs are (odd, even) index pairs because each level
        # starts at an odd index, so begin at the odd (left) member.
        if i % 2 == 0:
            i = i-1
        for level in range(self.levels, 1, -1):
            ii = (i-1)//2  # parent index
            if values[i] < values[i+1]:
                values[ii] = values[i]
            else:
                values[ii] = values[i+1]
            # continue from the left member of the parent's pair
            if ii % 2:
                i = ii
            else:
                i = ii-1

    @nb.int32(VALUE_T, nb.int32)
    def push(self, value, reference):
        """Push ``value`` together with its ``reference``.

        If the last level is full, the heap first grows by one level.
        Returns the index relative to the start of the last level in the
        heap, which is the number of values held before this push.
        """
        levels = self.levels
        count = self.count
        # grow when the last level is exactly full
        if count >= (1 << levels):  # 2**self.levels
            self._add_or_remove_level(+1)
            levels += 1
        # store in the first free slot of the last level
        i = ((1 << levels) - 1) + count  # level start + count
        self._values[i] = value
        self._references[count] = reference
        self.count += 1
        self._update_one(i)
        return count
if __name__ == '__main__':
    # Benchmark: push one million random values into scikit-image's
    # BinaryHeap and into the numba BinaryHeap defined above.
    a = np.random.rand(1000000)
    for HeapClass in BinaryHeap_sk, BinaryHeap:
        # Warm-up: construct a heap and push a few values. For the numba
        # class this is where the compilation cost shows up.
        t0 = time.time()
        heap = HeapClass(8)
        for num in a[:3]:
            heap.push(num, 2)
        etime = time.time() - t0
        # Label with module *and* class name. The numba class is defined in
        # ``__main__``, so the module name alone does not identify it.
        heap_type = heap.__class__
        label = '%s.%s' % (heap_type.__module__, heap_type.__name__)
        print('compiling %s took %1.3f s' % (label, etime))
        # Timed run: push every value (the heap grows as it fills up).
        t0 = time.time()
        for num in a:
            heap.push(num, 2)
        etime = time.time() - t0
        print('pushing to %s took %1.3f s' % (label, etime))
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment