Last active
November 25, 2022 03:18
-
-
Save parvezmrobin/771ee424044676b1fea4fbbbe4665cd4 to your computer and use it in GitHub Desktop.
Simple and efficient Python implementation of a bit array
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import annotations | |
class BitArray(bytearray):
    """Fixed-size array of bits stored eight to a byte in a ``bytearray``.

    Bit ``k`` lives in byte ``k // 8`` at bit position ``k % 8`` (least
    significant bit first).  Only indexing (``[]``) and :meth:`fill` work at
    bit granularity; ``len()`` and every other inherited ``bytearray``
    operation still operate on whole bytes.
    """

    # mask_for[i] selects bit i of a byte.
    mask_for = [1 << i for i in range(8)]
    # inverse_mask_for[i] keeps every bit of a byte except bit i.
    inverse_mask_for = [0xFF ^ (1 << i) for i in range(8)]
    valid_values = (0, 1)

    def __init__(self, source, *args, **kwargs) -> None:
        """Create the array.

        Args:
            source: An ``int`` is read as the number of bits wanted and is
                rounded up to a whole number of bytes.  Any other value is
                handed to ``bytearray`` unchanged.
            *args, **kwargs: Forwarded to ``bytearray.__init__``.

        Raises:
            ValueError: If an integer *source* is negative.
        """
        if isinstance(source, int):
            if source < 0:
                raise ValueError('Number of bits cannot be negative')
            # Integer ceiling division: a float byte count would be rejected
            # by bytearray whenever the bit count is a multiple of 8.
            source = (source + 7) // 8
        super().__init__(source, *args, **kwargs)

    def fill(self, val):
        """Set every bit to *val* (0 or 1) and return ``self``.

        Raises:
            ValueError: If *val* is neither 0 nor 1.
        """
        if val not in self.valid_values:
            raise ValueError(f"You want to put {val} in a BitArray!")
        byte_value = 0 if val == 0 else 255
        # Overwrite all bytes at once through the byte-level parent setter.
        super().__setitem__(slice(None), bytes([byte_value]) * len(self))
        return self

    def _assert_key(self, key):
        """Validate an index; slices are accepted as-is here.

        Raises:
            KeyError: If *key* is neither int nor slice, or is a negative int.
            IndexError: If an int *key* is past the last stored bit.
        """
        if isinstance(key, slice):
            return
        if not isinstance(key, int):
            raise KeyError(f'{type(self).__name__} only support int and slices as index')
        if key < 0:
            raise KeyError('Only positive keys are supported')
        if key >= len(self) * 8:
            raise IndexError(key)

    def _resolve_slice(self, key):
        """Return concrete ``(start, stop, step)`` bit positions for *key*.

        Missing bounds default to the whole array and bounds past the end are
        clamped to the number of stored bits.

        Raises:
            KeyError: If start or stop is negative.
            ValueError: If step is not positive.
        """
        bit_count = len(self) * 8
        start = 0 if key.start is None else key.start
        stop = bit_count if key.stop is None else key.stop
        step = 1 if key.step is None else key.step
        if start < 0 or stop < 0:
            raise KeyError('Only positive keys are supported')
        if step <= 0:
            raise ValueError('Only positive slice steps are supported')
        return min(start, bit_count), min(stop, bit_count), step

    @staticmethod
    def _get_byte_index_and_offset(key: slice):
        """Map a bit slice onto the bytes that contain it.

        *key* must have concrete ``start`` and ``stop`` with
        ``0 <= start < stop``.

        Returns:
            ``(start_byte_idx, start_offset, stop_byte_idx, stop_offset)``:
            the half-open byte range covering the slice, the number of
            leading bits to skip in the first byte and the number of trailing
            bits to skip in the last byte.
        """
        # 7th bit in 0th byte, 8/9th bit in 1st byte
        start_byte_idx = key.start // 8
        # 23rd bit in 2nd byte, 24/25th bit in 3rd byte
        # also exclude the last index
        stop_byte_idx = ((key.stop - 1) // 8) + 1
        # if key.start is 10, then ignore first two values from first byte
        start_offset = key.start % 8
        # if key.stop is 19, then ignore last 5 values from last byte;
        # a stop on a byte boundary ignores nothing (0, not 8)
        stop_offset = -key.stop % 8
        return start_byte_idx, start_offset, stop_byte_idx, stop_offset

    def __getitem__(self, key) -> int | list[int]:
        """Return the bit at *key*, or a list of bits for a slice.

        Raises:
            KeyError: For unsupported or negative keys.
            IndexError: If an int *key* is past the last stored bit.
            ValueError: If a slice has a non-positive step.
        """
        if isinstance(key, slice):
            start, stop, step = self._resolve_slice(key)
            if start >= stop:
                return []
            start_byte_idx, start_offset, stop_byte_idx, stop_offset = (
                self._get_byte_index_and_offset(slice(start, stop))
            )
            byte_list = super().__getitem__(slice(start_byte_idx, stop_byte_idx))
            bits = [
                (byte >> bit_idx) & 1
                for byte in byte_list
                for bit_idx in range(8)
            ]
            # An explicit end index is needed: "-stop_offset" would mean
            # "up to index 0" when no trailing bits are to be dropped.
            return bits[start_offset:len(bits) - stop_offset:step]

        self._assert_key(key)
        byte = super().__getitem__(key // 8)
        return int(bool(byte & self.mask_for[key % 8]))

    def __setitem__(self, key, val):
        """Set the bit at *key*, or the bits selected by a slice.

        For a slice, *val* is an iterable supplying one 0/1 value per
        selected position; surplus values are ignored.  Nothing is modified
        if validation fails.

        Raises:
            KeyError: For unsupported or negative keys.
            IndexError: If an int *key* is past the last stored bit.
            ValueError: If a value is not 0 or 1, if *val* is too short for
                the slice, or if the slice step is not positive.
        """
        self._assert_key(key)
        if isinstance(key, slice):
            val_iter = iter(val)
            start, stop, step = self._resolve_slice(key)
            if start >= stop:
                return
            positions = range(start, stop, step)
            # zip stops pulling from val_iter once positions is exhausted.
            values = [next_val for _, next_val in zip(positions, val_iter)]
            if len(values) < len(positions):
                raise ValueError(
                    f'Need {len(positions)} values for this slice, got {len(values)}'
                )
            for next_val in values:
                if next_val not in self.valid_values:
                    raise ValueError(f"You want to put {next_val} in a BitArray!")

            start_byte_idx, _, stop_byte_idx, _ = self._get_byte_index_and_offset(
                slice(start, stop)
            )
            # Work on a copy of the affected bytes and write it back once.
            byte_list = super().__getitem__(slice(start_byte_idx, stop_byte_idx))
            first_bit = start_byte_idx * 8
            for position, next_val in zip(positions, values):
                byte_idx, bit_idx = divmod(position - first_bit, 8)
                if next_val == 0:
                    byte_list[byte_idx] &= self.inverse_mask_for[bit_idx]
                else:
                    byte_list[byte_idx] |= self.mask_for[bit_idx]
            super().__setitem__(slice(start_byte_idx, stop_byte_idx), byte_list)
            return

        if val not in self.valid_values:
            raise ValueError(f"You want to put {val} in a BitArray!")
        byte_idx, bit_idx = divmod(key, 8)
        byte = super().__getitem__(byte_idx)
        if val == 0:
            new_byte = byte & self.inverse_mask_for[bit_idx]
        else:
            new_byte = byte | self.mask_for[bit_idx]
        super().__setitem__(byte_idx, new_byte)
# Usage demo: 20 bits are requested, which are stored in three bytes.
array = BitArray(20)

# Switch on a few individual bits.
array[1] = 1
array[3] = 1
array[11] = 1
# Read every second bit from position 1 up to (not including) 13.
print(array[1:13:2])  # [1, 1, 0, 0, 0, 1]

# Switch on the contiguous range of bits 1..6.
array[1:7] = (1,) * 6
print(array[1:13])  # [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 0]

# A stepped slice assignment touches bits 11 and 13 only; the extra
# values supplied on the right-hand side are not used.
array[11:14:2] = [1] * 4
print(array[1:15])  # [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 0, 1, 0]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment