Last active
August 22, 2024 07:02
-
-
Save sohang3112/f9cbd71fcabaf70855b1f5261e7db5e7 to your computer and use it in GitHub Desktop.
Efficient set of integers, maintained using bit shift operations on an integer.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import annotations | |
from typing import Iterable | |
class IntSet:
    """
    Efficiently store positive integers in a set - internally uses bit operations.

    The set is held in a single Python int used as a bitmask: the integer ``i``
    (1-based) is a member exactly when bit ``i - 1`` of the mask is 1.

    NOTE: len() isn't supported - instead cast to list first: len(list(intset))

    >>> IntSet(0b10011)     # Set of all the bits which are set to 1 (using 1-based bit indexing)
    IntSet([1, 2, 5])
    >>> s = IntSet().add_from_iterable([1,2,7,1])
    >>> list(s)             # duplicates in input iterable were removed
    [1, 2, 7]
    >>> s.add(8)
    >>> s
    IntSet([1, 2, 7, 8])
    >>> s.discard(1)
    >>> 1 in s
    False
    >>> s.discard(0)        # values that can never be members are ignored
    >>> s2 = IntSet().add_from_iterable([2,3,5,7])
    >>> s.union(s2)
    IntSet([2, 3, 5, 7, 8])
    >>> s.intersection(s2)
    IntSet([2, 7])
    >>> s.difference(s2)
    IntSet([8])
    >>> IntSet().add_from_iterable(range(1,6)).has_first_n(5)
    True
    >>> IntSet().add_from_iterable(range(1,9)).has_first_n(5)
    True
    >>> IntSet().add_from_iterable([1, 2, 4]).has_first_n(3)
    False
    """

    def __init__(self, bits: int = 0):
        """Create a set whose members are the 1-based positions of the 1 bits in ``bits``."""
        self._added = bits

    def __bool__(self) -> bool:
        """Return True if the bitmask is non-zero, i.e. the set is not empty."""
        return self._added != 0

    def __contains__(self, x: int) -> bool:
        """Return True if ``x`` is a member; values below 1 are never members."""
        return 1 <= x and (self._added >> (x - 1)) & 1 == 1

    def __iter__(self):
        """Yield the members in increasing order.

        Works on a snapshot of the bitmask taken when iteration starts.
        """
        remaining = self._added
        while remaining > 0:
            # ``remaining & -remaining`` isolates the lowest set bit; its
            # bit_length is that bit's 1-based position, i.e. the member value.
            lowest = remaining & -remaining
            yield lowest.bit_length()
            remaining ^= lowest

    def __repr__(self) -> str:
        """Return ``ClassName([members...])`` with members in increasing order."""
        return "{}({})".format(type(self).__name__, list(self))

    def has_first_n(self, n: int) -> bool:
        """Does set have all of 1,2..n ?

        Members larger than ``n`` do not affect the answer, and ``n == 0`` is
        vacuously True. A negative ``n`` raises ValueError (negative shift count).
        """
        mask = (1 << n) - 1
        # Compare only the low n bits so that extra, larger members are ignored.
        return (self._added & mask) == mask

    def add(self, x: int) -> None:
        """Add the positive integer ``x`` to the set (no effect if already present).

        Raises AssertionError if ``x`` is not positive.
        """
        assert x > 0, f"{x} cannot be added because it is negative or 0"
        self._added |= 1 << (x - 1)

    def add_from_iterable(self, elems: Iterable[int]) -> IntSet:
        """Add every element of ``elems`` via ``add`` and return ``self`` for chaining."""
        for x in elems:
            self.add(x)
        return self

    def discard(self, x: int) -> None:
        """Remove element from set, if present"""
        if x < 1:
            # Values below 1 can never be members, so there is nothing to remove
            # (and shifting by a negative count would raise ValueError).
            return
        self._added &= ~(1 << (x - 1))

    def union(self, intset: IntSet) -> IntSet:
        """Return a new IntSet with the members of either set."""
        return IntSet(self._added | intset._added)

    def intersection(self, intset: IntSet) -> IntSet:
        """Return a new IntSet with the members common to both sets."""
        return IntSet(self._added & intset._added)

    def difference(self, intset: IntSet) -> IntSet:
        """Return a new IntSet with the members of this set that are not in ``intset``."""
        return IntSet(self._added & ~intset._added)
if __name__ == "__main__":
    # Executed as a script: run the doctests found in this module's docstrings.
    import doctest

    doctest.testmod()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment