Last active
July 11, 2023 12:15
-
-
Save mumbleskates/0ef75bf3f25d0faeecc73ddb9373ea75 to your computer and use it in GitHub Desktop.
Pure Python N-ary Heap implementation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# coding=utf-8
class NaryHeap(object):
    """Implements an n-ary heap.

    Items live in a flat list where the children of the node at index ``i``
    occupy indices ``i * n + 1`` through ``i * n + n``.  ``direction`` is the
    builtin ``min`` or ``max`` and decides which item is served first.
    """

    def __init__(self, items=(), *, n=2, direction=min):
        """
        Create a new heap.

        items must be an iterable, n must be a positive integer, and direction must be
        min or max.

        Raises ValueError when n is not an integer >= 1 or when direction is
        neither the builtin min nor the builtin max.
        """
        if not isinstance(n, int) or n < 1:
            raise ValueError("Degree of N-ary heap (n) must be an integer >= 1")
        if direction not in (min, max):
            raise ValueError("Direction must be min or max")
        self._n = n
        self._direction = direction
        self._list = list(items)
        # Bottom-up heapify: leaves are already valid one-item heaps, so only
        # nodes that have at least one child need sifting, deepest first.
        # This replaces pushing the items one at a time.
        last_parent = self._parent_ix(len(self._list) - 1)
        for ix in range(last_parent, -1, -1):
            self._sift_towards_leaf(ix)

    def _parent_ix(self, ix):
        """Return the parent index of a given node."""
        return (ix - 1) // self._n

    def _children_ix(self, ix):
        """Return a range of all valid child indices of a given index."""
        first_child = ix * self._n + 1
        return range(first_child, min(first_child + self._n, len(self._list)))

    def _sift_towards_root(self, ix):
        """Move the item at ix up until its parent is preferred over it."""
        items = self._list
        current = items[ix]
        while ix:  # stop when ix is 0, the root
            parent_ix = self._parent_ix(ix)
            parent = items[parent_ix]
            # The parent goes first: min() and max() return the earliest
            # argument when all are equal, so ties cost no swap.
            if self._direction(parent, current) is parent:
                return  # we are done sorting this item
            # continue sorting towards root
            items[parent_ix], items[ix] = current, parent
            ix = parent_ix

    def _sift_towards_leaf(self, ix):
        """Move the item at ix down until it is preferred over its children."""
        items = self._list
        current = items[ix]
        while True:
            children = self._children_ix(ix)
            if not children:
                return  # we are done, item is now a leaf
            # Pick the most parental child by looking each index up in the
            # list; only the ordering comparison of the items is used, and no
            # intermediate (item, index) pairs are built.
            candidate_ix = self._direction(children, key=items.__getitem__)
            candidate = items[candidate_ix]
            # again, the option representing the least work comes first
            if self._direction(current, candidate) is current:
                return  # we are done sorting this item
            # continue sorting towards leaf
            items[candidate_ix], items[ix] = current, candidate
            ix = candidate_ix

    def __len__(self):
        """Return the number of items currently in the heap."""
        return len(self._list)

    def push(self, item):
        """Insert an item into the heap."""
        self._list.append(item)
        self._sift_towards_root(len(self._list) - 1)

    def peek(self):
        """Show the next item in the heap without removing it.

        Raises IndexError when the heap is empty.
        """
        if not self._list:
            raise IndexError("peek at an empty heap")
        return self._list[0]

    def pop(self):
        """Remove and return the next item in the heap.

        Raises IndexError when the heap is empty.
        """
        if not self._list:
            raise IndexError("pop from an empty heap")
        last = self._list.pop()
        if not self._list:
            return last  # the heap held a single item
        # Put the former last item at the root and restore the invariant.
        result, self._list[0] = self._list[0], last
        self._sift_towards_leaf(0)
        return result
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# coding=utf-8
from functools import partial | |
from itertools import product | |
from random import choice | |
import pytest | |
from heap import NaryHeap | |
# Item collections fed to every parametrized test below.
TEST_ITEMS = [
    [],  # nothing at all
    [1],  # a single item
    list(range(10)),  # short ascending run
    list(range(123)),  # longer ascending run
    list(range(122, -1, -1)),  # the same run, descending
    [choice(range(123)) for _ in range(123)],  # random picks, repeats possible
    [1] * 50,  # every item equal
    "this string is also an iterable it turns out",
    (34, 0, 77, 95, 21, 8009, 788324),
]
# Heap degrees (n) and directions combined by the factory fixture.
TEST_DEGREES = range(1, 11)
TEST_DIRECTIONS = (min, max)
@pytest.fixture(scope='session', params=product(TEST_DEGREES, TEST_DIRECTIONS))
def factory(request):
    """Return a NaryHeap constructor bound to one (degree, direction) pair.

    The fixture is parametrized over every combination of TEST_DEGREES and
    TEST_DIRECTIONS; tests call the result with an optional iterable of items.
    """
    degree, ordering = request.param
    heap_options = {'n': degree, 'direction': ordering}
    return partial(NaryHeap, **heap_options)
def assert_invariant(heap):
    """Assert that every non-root item is ordered correctly against its parent.

    Reads the heap's backing list, degree and direction directly, and
    recomputes each parent index as ``(i - 1) // n``.
    """
    backing, degree, prefer = heap._list, heap._n, heap._direction
    # index 0 is the root and has no parent, so start at 1
    for child_ix in range(1, len(backing)):
        parent = backing[(child_ix - 1) // degree]
        assert prefer(parent, backing[child_ix]) is parent
@pytest.mark.parametrize('items', TEST_ITEMS)
def test_init(items, factory):
    """Constructing a heap from items keeps every item and a valid heap order."""
    heap = factory(items)
    # The invariant alone would still hold if items were dropped or
    # duplicated, so check the size and contents as well.
    assert len(heap) == len(items)
    assert sorted(heap._list) == sorted(items)
    assert_invariant(heap)
@pytest.mark.parametrize('items', TEST_ITEMS)
def test_push(items, factory):
    """Pushing items one at a time keeps the heap valid after every push."""
    heap = factory()
    for count, item in enumerate(items, start=1):
        heap.push(item)
        # check after each push, not only once everything has been inserted
        assert len(heap) == count
        assert_invariant(heap)
    # final check, which also covers the case where items is empty
    assert len(heap) == len(items)
    assert_invariant(heap)
@pytest.mark.parametrize('items', TEST_ITEMS)
def test_pop(items, factory):
    """Popping everything yields the items in sorted order, then IndexError."""
    heap = factory(items)
    # a max-heap serves the largest item first, so the reference is reversed
    descending = heap._direction is max
    for expected_item in sorted(items, reverse=descending):
        assert heap.pop() == expected_item
    # nothing is left, so one more pop must fail
    with pytest.raises(IndexError):
        heap.pop()
@pytest.mark.parametrize('items', TEST_ITEMS)
def test_peek(items, factory):
    """peek() returns the very object the following pop() removes."""
    # this test is surely overbuilt
    heap = factory(items)
    while len(heap) > 0:
        upcoming = heap.peek()
        assert upcoming is heap.pop()
    # the heap has been drained, so peeking must fail
    with pytest.raises(IndexError):
        heap.peek()
def test_bad_direction():
    """A direction that is neither min nor max is rejected with ValueError."""
    not_a_direction = 5  # not min or max
    with pytest.raises(ValueError):
        NaryHeap(direction=not_a_direction)
# Degrees the constructor must refuse: non-positive integers and non-integers.
BAD_DEGREES = [
    0,
    -5,
    "string",
    4.4,
    2.0,
]


@pytest.mark.parametrize('degree', BAD_DEGREES)
def test_bad_degree(degree):
    """Each invalid degree makes the constructor raise ValueError."""
    with pytest.raises(ValueError):
        NaryHeap(n=degree)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment