Last active
February 26, 2022 21:51
-
-
Save nat-n/02e6882ef8b6aa3e8efe143795a4eae0 to your computer and use it in GitHub Desktop.
A gold-plated implementation of an immutable bit set in Python, including full test coverage.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from collections.abc import Set | |
from numpy import packbits | |
import math | |
import random | |
from typing import Any, Callable, Collection, Iterable, Iterator | |
class ImmutableBitSet(Set):
    """
    An immutable set of non-negative integers stored as a packed bit array.

    Value ``n`` is represented by bit ``0b10000000 >> (n % 8)`` of byte ``n // 8``
    (most significant bit first). Trailing NUL bytes are always stripped, so equal
    sets have identical content, an empty set has content ``b""``, and
    ``bytes(bitset)`` round-trips through the constructor.
    """

    _content: bytes

    @property
    def size(self) -> int:
        """
        The number of allocated bits in this bitset, i.e. one more than the largest
        value that fits in the currently allocated bytes. Always a multiple of 8.
        """
        return len(self._content) * 8

    def __init__(self, values: Iterable[int] = b""):
        """
        Create a bitset.

        :param values: Either the packed ``bytes`` of a bitset (as returned by
            ``bytes(bitset)``), or any iterable of non-negative ints. Duplicates are
            ignored. One-shot iterables such as generators are accepted, which the
            ``collections.abc.Set`` mixins rely on when they build results through
            ``_from_iterable`` (e.g. ``{1, 2} & bitset``).
        :raises ValueError: If a value is not a non-negative int.
        """
        if isinstance(values, bytes):
            # Normalise so that equal sets always have byte-identical content.
            self._content = values.rstrip(b"\x00")
            return

        values_set = set(values)
        for value in values_set:
            if not isinstance(value, int) or value < 0:
                raise ValueError(
                    f"ImmutableBitSet only accepts non-negative integers, not {value!r}"
                )
        if not values_set:
            self._content = b""
            return

        # The largest value lives in the last byte, so the packed result never has
        # trailing NUL bytes.
        buffer = bytearray(max(values_set) // 8 + 1)
        for value in values_set:
            buffer[value // 8] |= 0b10000000 >> (value % 8)
        self._content = bytes(buffer)

    def __contains__(self, value: object) -> bool:
        """Return True if ``value`` is a non-negative int whose bit is set."""
        # Negative ints must be rejected explicitly: floor division and negative
        # indexing would otherwise map them onto the last byte of the content.
        if not isinstance(value, int) or value < 0:
            return False
        byte_index, bit_offset = divmod(value, 8)
        if byte_index >= len(self._content):
            return False
        return bool(self._content[byte_index] & (0b10000000 >> bit_offset))

    def __iter__(self) -> Iterator[int]:
        """Yield the members in ascending order."""
        for byte_index, byte in enumerate(self._content):
            for bit_offset in range(8):
                if byte & (0b10000000 >> bit_offset):
                    yield byte_index * 8 + bit_offset

    def __bytes__(self) -> bytes:
        """Return the packed content (without trailing NUL bytes)."""
        return self._content

    def __len__(self) -> int:
        """
        Return the number of elements in the ImmutableBitSet.
        That is the number of bits set to 1 in self._content (aka the Hamming weight).
        """
        return sum(bin(byte).count("1") for byte in self._content)

    def __hash__(self) -> int:
        # The content is canonical (no trailing NULs), so bitsets that compare
        # equal via __eq__ always hash equally.
        return hash(self._content)

    def __repr__(self) -> str:
        return f"{type(self).__name__}({list(self)!r})"

    def _padded_pairs(self, other: "ImmutableBitSet") -> Iterator[tuple]:
        """
        Pair up the bytes of self and other, padding the shorter content with NUL
        bytes so that no bits of the longer operand are silently dropped.
        """
        other_content = bytes(other)
        width = max(len(self._content), len(other_content))
        return zip(
            self._content.ljust(width, b"\x00"), other_content.ljust(width, b"\x00")
        )

    def union(self, other: Iterable[Any]) -> "ImmutableBitSet":
        """
        Create a new ImmutableBitSet as the union of self and other.

        :raises ValueError: If other contains anything but non-negative ints.
        """
        if isinstance(other, ImmutableBitSet):
            return ImmutableBitSet(
                bytes(b1 | b2 for b1, b2 in self._padded_pairs(other))
            )
        return ImmutableBitSet(set(self) | set(other))

    def __or__(self, other: Set) -> "ImmutableBitSet":
        return self.union(other)

    def intersection(self, other: Iterable[Any]) -> "ImmutableBitSet":
        """
        Create a new ImmutableBitSet as the intersection of self and other.
        """
        if isinstance(other, ImmutableBitSet):
            # Plain zip is correct here: bytes beyond the shorter content would be
            # ANDed with zero anyway.
            return ImmutableBitSet(
                bytes(b1 & b2 for b1, b2 in zip(self._content, bytes(other)))
            )
        return ImmutableBitSet(set(self) & set(other))

    def __and__(self, other: Set) -> "ImmutableBitSet":
        return self.intersection(other)

    def isdisjoint(self, other: Iterable[Any]) -> bool:
        """
        Return True if the set has no elements in common with other. Sets are disjoint
        if and only if their intersection is the empty set.
        """
        return not self.intersection(other)

    def issubset(self, other: Iterable[Any]) -> bool:
        """
        Test whether every element in the set is in other.
        """
        if isinstance(other, ImmutableBitSet):
            # Every bit set in self must also be set in other, including bits that
            # lie beyond the end of other's content.
            return all(b1 == (b1 & b2) for b1, b2 in self._padded_pairs(other))
        return set(self) <= set(other)

    def __le__(self, other: Set) -> bool:
        return self.issubset(other)

    def __lt__(self, other: Set) -> bool:
        """
        Test whether the set is a proper subset of other, that is, set <= other and set
        != other.
        """
        if isinstance(other, ImmutableBitSet):
            return self != other and self.issubset(other)
        return set(self) < set(other)

    def issuperset(self, other: Set) -> bool:
        """
        Test whether every element in other is in the set.
        """
        if isinstance(other, ImmutableBitSet):
            return other.issubset(self)
        return set(self) >= set(other)

    def __ge__(self, other: Set) -> bool:
        return self.issuperset(other)

    def __gt__(self, other: Set) -> bool:
        """
        Test whether the set is a proper superset of other, that is, set >= other and
        set != other.
        """
        if isinstance(other, ImmutableBitSet):
            return self != other and other.issubset(self)
        return set(self) > set(other)

    def __eq__(self, other: object) -> bool:
        """Only another ImmutableBitSet with the same members compares equal."""
        return isinstance(other, ImmutableBitSet) and self._content == bytes(other)

    def __ne__(self, other: object) -> bool:
        return not isinstance(other, ImmutableBitSet) or self._content != bytes(other)

    def difference(self, other: Iterable[Any]) -> "ImmutableBitSet":
        """Return a new ImmutableBitSet with the elements of the set that are not in other."""
        if isinstance(other, ImmutableBitSet):
            # b1 is in 0..255, so b1 & ~b2 stays in 0..255 despite ~b2 being negative.
            return ImmutableBitSet(
                bytes(b1 & ~b2 for b1, b2 in self._padded_pairs(other))
            )
        return ImmutableBitSet(set(self) - set(other))

    def __sub__(self, other: Set) -> "ImmutableBitSet":
        return self.difference(other)

    def symmetric_difference(self, other: Iterable[Any]) -> "ImmutableBitSet":
        """
        Return a new ImmutableBitSet with elements in either the set or other but not both.

        :raises ValueError: If other contains anything but non-negative ints.
        """
        if isinstance(other, ImmutableBitSet):
            return ImmutableBitSet(
                bytes(b1 ^ b2 for b1, b2 in self._padded_pairs(other))
            )
        return ImmutableBitSet(set(self) ^ set(other))

    def __xor__(self, other: Set) -> "ImmutableBitSet":
        return self.symmetric_difference(other)

    def __bool__(self) -> bool:
        # Content never has trailing NULs, so it is empty iff the set is empty.
        return bool(self._content)

    def copy(self) -> "ImmutableBitSet":
        """Return a new, equal ImmutableBitSet (a distinct object)."""
        return ImmutableBitSet(self._content)
# | |
# Tests | |
# usage: pytest -v ./immutable_bitset.py | |
# | |
def test_immutable_bit_set_with_arbitrary_bits():
    """
    Generate 50 random samples of 20 integers in [0, 256) and check that each
    resulting bitset has the right length, size, membership and contents.

    A dedicated, fixed-seed RNG keeps the test reproducible: any failure can be
    replayed exactly, and the global ``random`` state is left untouched.
    """
    rng = random.Random(0)
    population = range(256)
    for _ in range(50):
        values = rng.choices(population, k=20)
        expected = set(values)
        bs = ImmutableBitSet(values)
        assert len(bs) == len(expected), "Bitset should know how many items it contains"
        assert (
            bs.size == math.ceil((max(values) + 1) / 8) * 8
        ), "Bitset should use the optimal number of bytes"
        for num in population:
            assert (num in bs) == (
                num in expected
            ), f"Expected Bitset to only contain given values (mismatch on {num})"
        assert set(bs) == expected, "Expected Bitset to contain the original values"
def test_immutable_bit_set_with_single_bits():
    """Every value in [0, 256) on its own yields a one-element, minimally sized bitset."""
    for value in range(256):
        bitset = ImmutableBitSet([value])
        expected_size = 8 * math.ceil((value + 1) / 8)
        assert len(bitset) == 1, "Bitset should know how many items it contains"
        assert bitset.size == expected_size, "Bitset should use the optimal number of bytes"
        assert value in bitset, f"Expected Bitset to contain {value}"
        assert set(bitset) == {value}, "Expected Bitset to contain the original values "
def test_boolean_cast():
    """An empty bitset has zero length and is falsy; a non-empty one is truthy."""
    empty = ImmutableBitSet([])
    answer = ImmutableBitSet([42])
    assert len(empty) == 0
    assert bool(empty) is False
    assert bool(answer) is True
def test_set_union_and_intersection():
    """
    Check bitset union and intersection against plain sets, using the multiples
    of 3 ("fizz"), 5 ("buzz") and 15 ("fizzbuzz") in [1, 256).
    """
    numbers = range(1, 256)
    fizz = {num for num in numbers if num % 3 == 0}
    buzz = {num for num in numbers if num % 5 == 0}
    # Divisible by both 3 and 5 means divisible by 15.
    fizzbuzz = {num for num in numbers if num % 15 == 0}
    fizz_bs = ImmutableBitSet(fizz)
    buzz_bs = ImmutableBitSet(buzz)
    fizzbuzz_bs = ImmutableBitSet(fizzbuzz)
    # sanity checks
    assert fizz == set(fizz_bs)
    assert buzz == set(buzz_bs)
    assert fizzbuzz == set(fizzbuzz_bs)
    assert fizz_bs != buzz_bs
    assert fizz_bs != fizzbuzz_bs
    assert buzz_bs != fizzbuzz_bs
    assert fizz & buzz == fizzbuzz, "Set intersection should work as expected"
    assert fizz_bs & buzz_bs == fizzbuzz_bs, "Set intersection should work as expected"
    # Compare the bitset union against the plain-set union (not against itself).
    assert fizz | buzz == set(fizz_bs | buzz_bs), "Set union should work as expected"
    assert fizz_bs | buzz_bs == ImmutableBitSet(
        fizz | buzz
    ), "Set union should work as expected"
def test_set_comparisons():
    """
    Exercise the set-algebra operations and comparisons of ImmutableBitSet against
    other bitsets, plain sets and strings.
    """
    low = {1, 2, 3}
    mid = {3, 4, 5}
    high = {4, 5, 6}
    # Sets with a non-int member can never be turned into a bitset.
    with_cow = {3, 4, "cow"}
    with_cow_disjoint = {33, 4, "cow"}

    low_bs = ImmutableBitSet(low)
    mid_bs = ImmutableBitSet(mid)
    high_bs = ImmutableBitSet(high)

    for bad in (with_cow, with_cow_disjoint):
        assert_value_error_on_init(lambda bad=bad: ImmutableBitSet(bad))

    # union
    assert low_bs.union(low) == low_bs
    assert low_bs.union(low_bs) == low_bs
    for plain, as_bitset in ((mid, mid_bs), (high, high_bs)):
        expected_union = ImmutableBitSet(low | plain)
        assert low_bs.union(plain) == expected_union
        assert low_bs.union(as_bitset) == expected_union
    for bad in (with_cow, with_cow_disjoint):
        assert_value_error_on_init(lambda bad=bad: low_bs.union(bad))

    # intersection
    assert low_bs.intersection(low) == low_bs
    assert low_bs.intersection(low_bs) == low_bs
    for plain, as_bitset in ((mid, mid_bs), (high, high_bs)):
        expected_common = ImmutableBitSet(low & plain)
        assert low_bs.intersection(plain) == expected_common
        assert low_bs.intersection(as_bitset) == expected_common
    assert low_bs.intersection(with_cow) == ImmutableBitSet(low & with_cow)
    assert low_bs.intersection(with_cow) == ImmutableBitSet([3])
    assert low_bs.intersection("lol") == ImmutableBitSet(tuple())

    # isdisjoint
    for other in (high_bs, high, with_cow_disjoint, "lol"):
        assert low_bs.isdisjoint(other)
    for other in (mid_bs, mid, with_cow):
        assert not low_bs.isdisjoint(other)

    low_or_high_bs = low_bs | high_bs

    # subset: <= and issubset
    assert not low_bs <= ImmutableBitSet()
    assert low_bs <= low_bs
    assert not low_bs.issubset(ImmutableBitSet())
    assert not low_bs.issubset({1})
    assert low_bs.issubset(low_bs)
    for other in (low_or_high_bs, {1, 2, 3, 4, 5, 6, "goose!"}):
        assert mid_bs.issubset(other)
    for other in (low_bs, high_bs, "lol"):
        assert not mid_bs.issubset(other)

    # proper subset: <
    for other in (ImmutableBitSet(), ImmutableBitSet({0, 2, 3, 4}), {0, 2, 3, 4}, low_bs):
        assert not low_bs < other
    assert low_bs < low_or_high_bs
    assert mid_bs < low_or_high_bs

    # superset: >= and issuperset
    assert low_bs >= ImmutableBitSet()
    assert low_bs >= low_bs
    for other in (ImmutableBitSet(), {1}, {1, 3}, low_bs):
        assert low_bs.issuperset(other)
    for other in (low_bs, low_or_high_bs, high_bs, "lol"):
        assert not mid_bs.issuperset(other)

    # proper superset: >
    for other in (ImmutableBitSet(), ImmutableBitSet({1, 3}), {1, 3}):
        assert low_bs > other
    for other in (ImmutableBitSet({0, 2, 3, 4}), low_bs, low_or_high_bs):
        assert not low_bs > other
    assert not mid_bs > low_or_high_bs

    # equality
    assert low_bs == ImmutableBitSet({1, 2, 3})
    for other in (ImmutableBitSet({1, 3}), mid_bs, high_bs, "lol", low):
        assert low_bs != other

    # symmetric difference
    expected_xor = ImmutableBitSet((1, 2, 4, 5))
    assert not low_bs ^ low_bs
    assert low_bs ^ mid_bs == expected_xor
    assert not low_bs.symmetric_difference(low_bs)
    assert low_bs.symmetric_difference(mid) == expected_xor
    assert low_bs.symmetric_difference(mid_bs) == expected_xor
    assert_value_error_on_init(lambda: low_bs.symmetric_difference("lol"))

    # difference
    assert not low_bs - low_bs
    assert not mid_bs - low_or_high_bs
    assert low_bs - ImmutableBitSet({2}) == ImmutableBitSet({1, 3})
    assert not low_bs.difference(low_bs)
    assert not mid_bs - (low | high)
    assert low_bs - {2} == ImmutableBitSet({1, 3})
def test_contains_non_int():
    """Membership holds only for ints that were added; other types are never members."""
    bitset = ImmutableBitSet([0])
    assert 0 in bitset
    for absent in (30, "0", "30", (0,)):
        assert absent not in bitset
def test_copy():
    """copy() must return an equal but distinct ImmutableBitSet."""
    original = ImmutableBitSet([1, 2, 3])
    duplicate = original.copy()
    assert original == duplicate
    assert original is not duplicate
def assert_value_error_on_init(fun: Callable):  # pragma: no cover
    """
    Assert that calling ``fun()`` raises ``ValueError``.

    The tests use this to check that building an ImmutableBitSet from a non-int
    value is rejected, either directly or via an operation that constructs one.
    An explicit ``raise`` is used instead of ``assert False`` so the check still
    fires when Python runs with ``-O`` (which strips assert statements).

    :raises AssertionError: If ``fun()`` returns without raising ``ValueError``.
    """
    try:
        fun()
    except ValueError:
        return
    raise AssertionError(
        "Instantiating an ImmutableBitSet with a non-int value should raise ValueError"
    )
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment