Skip to content

Instantly share code, notes, and snippets.

@azengard
Forked from henriquebastos/adt.py
Created July 4, 2020 21:08
Show Gist options
  • Save azengard/f3f0f07dca6921d8a7581d41a5c06150 to your computer and use it in GitHub Desktop.
Save azengard/f3f0f07dca6921d8a7581d41a5c06150 to your computer and use it in GitHub Desktop.
from bisect import bisect_right, insort

import pytest
class IntervalMap:
    """Map half-open intervals to values, keyed by each interval's exclusive upper bound.

    After ``m[10] = 'a'`` and ``m[20] = 'b'``, ``m.get(x)`` returns ``'a'``
    for ``x < 10`` and ``'b'`` for ``10 <= x < 20``; the lowest interval is
    unbounded below and anything ``>= 20`` is a KeyError.
    """

    def __init__(self):
        self.limits = []  # sorted, duplicate-free upper bounds
        self.map = {}  # upper bound -> value

    def __setitem__(self, upper_bound, value):
        """Associate ``value`` with the interval that ends just below ``upper_bound``.

        Re-assigning an existing bound only replaces its value; the bound is
        not inserted into ``limits`` a second time.
        """
        if upper_bound not in self.map:
            # insort keeps ``limits`` sorted without a full re-sort per insert.
            insort(self.limits, upper_bound)
        self.map[upper_bound] = value

    def get(self, index):
        """Return the value of the first interval whose upper bound is greater than ``index``.

        Raises:
            KeyError: if ``index`` is not below any registered upper bound,
                including when no bound has been registered yet.
        """
        # bisect_right finds the first bound strictly greater than ``index``.
        position = bisect_right(self.limits, index)
        if position == len(self.limits):
            raise KeyError(index)
        return self.map[self.limits[position]]
def test_interval_map():
    """Each lookup resolves to the first registered upper bound above the index."""
    intervals = IntervalMap()
    for bound, label in ((15, 'quinze'), (10, 'dez'), (5, 'cinco')):
        intervals[bound] = label
    expectations = {
        0: 'cinco',
        4: 'cinco',
        5: 'dez',
        9: 'dez',
        10: 'quinze',
        14: 'quinze',
    }
    for index, label in expectations.items():
        assert intervals.get(index) == label
    # At or beyond the highest bound there is no interval left.
    with pytest.raises(KeyError):
        intervals.get(15)
from itertools import zip_longest
class Bits(bytes):
    """Non-negative integer stored as raw bytes, least-significant byte first.

    Comparing or indexing a ``Bits`` works on the underlying byte values, but
    *iterating* it yields individual bits (least-significant first), which is
    the form ``adder`` and ``multiplier`` consume.
    """

    CHUNK = 8  # bits per stored byte

    def __new__(cls, n):
        return super().__new__(cls, cls.number_to_bits(n))

    @staticmethod
    def number_to_bits(n, size=CHUNK):
        """Yield ``n`` split into ``size``-bit chunks, least-significant chunk first.

        Despite the name, the items are chunk values (bytes, for the default
        size), not single bits. At least one chunk is always produced, so
        ``0`` becomes ``[0]``.
        """
        length = Bits.how_many_bytes(n, size=size)
        mask = (1 << size) - 1
        return ((n >> (i * size)) & mask for i in range(length))

    def to_number(self):
        """Return the integer encoded by the stored bytes."""
        # The storage is little-endian (see number_to_bits), which is exactly
        # what int.from_bytes decodes; no hand-rolled positional sum needed.
        return int.from_bytes(self, byteorder='little')

    @staticmethod
    def how_many_bytes(value, size=CHUNK):
        """Return how many ``size``-bit chunks are needed to hold ``value`` (at least 1)."""
        count = 1
        limit = (1 << size) - 1
        while value > limit:
            value >>= size
            count += 1
        return count

    def __repr__(self):
        return f'{self.to_number()}'

    def __iter__(self):
        """Yield the stored bits, starting with the least-significant bit of the first byte.

        Every byte except the last is zero-padded to ``CHUNK`` bits so bit
        positions stay aligned across bytes; the last (most significant) byte
        stops at its highest set bit.
        """
        if self == bytes([0]):
            # Zero has no set bits; emit one explicit 0 so iteration is never empty.
            yield 0
            return
        size = self.CHUNK
        last_byte = len(self) - 1
        for byte_number, value in enumerate(super().__iter__()):
            count = 0
            while value:
                yield value & 1
                value >>= 1
                count += 1
            if byte_number < last_byte:
                yield from (0 for _ in range(size - count))

    def __lshift__(self, other):
        """Return a new ``Bits`` holding this value shifted left by ``other``.

        ``other`` may be another ``Bits`` or a plain ``int`` shift count.
        """
        shift = other if isinstance(other, int) else other.to_number()
        return Bits(self.to_number() << shift)
class Byte(Bits):
    """Unsigned value that must fit in ``SIZE`` bytes."""

    SIZE = 1  # width in bytes; the subclasses below widen it

    def __new__(cls, n):
        # Check both ends of the range the error message promises. Without the
        # lower bound, Bits silently wraps a negative n (e.g. -1 -> bytes([255])).
        if not 0 <= n < cls.upper_bound():
            raise ValueError(f'{n}. Byte must be in range(0, {cls.upper_bound()})')
        return super().__new__(cls, n)

    @classmethod
    def upper_bound(cls):
        """Return the exclusive upper bound ``2 ** (SIZE * CHUNK)``."""
        return 2 ** (cls.SIZE * cls.CHUNK)


class Word(Byte):
    """Two-byte unsigned value."""

    SIZE = 2


class Tribyte(Byte):
    """Three-byte unsigned value."""

    SIZE = 3


class DoubleWord(Byte):
    """Four-byte unsigned value."""

    SIZE = 4
def fulladder(a, b, cin):
    """One-bit full adder.

    Takes input bits ``a`` and ``b`` plus carry-in ``cin`` and returns the
    pair ``(sum_bit, carry_out)``.
    """
    half = a ^ b
    carry_out = (a & b) | (cin & half)
    return half ^ cin, carry_out
def adder(n, m):
    """Ripple-carry add two bit sequences and return the sum as an ``int``.

    ``n`` and ``m`` are iterables of 0/1 bits, least-significant first (for
    example ``Bits``); the shorter one is zero-extended. Each position goes
    through ``fulladder`` and the final carry becomes the top bit. Two empty
    inputs add to ``0``.
    """
    acc = 0
    carry = 0
    width = 0  # stays 0 for empty inputs, so the final carry shift is always defined
    for width, (a, b) in enumerate(zip_longest(n, m, fillvalue=0), start=1):
        bit, carry = fulladder(a, b, carry)
        acc |= bit << (width - 1)
    # ``width`` is now the number of positions consumed, i.e. the index just
    # above the most significant sum bit, where the final carry belongs.
    return acc | (carry << width)
def multiplier(n, m):
    """Multiply two ``Bits`` by shift-and-add and return the product as an ``int``.

    For every set bit of ``m`` (least-significant first), ``n`` shifted left
    by that bit's position is accumulated with ``adder``.
    """
    product = Bits(0)
    for position, bit in enumerate(m):
        if bit != 1:
            continue
        partial = n << Bits(position)
        product = Bits(adder(product, partial))
    return product.to_number()
def test_how_many_bytes():
    """The byte count grows exactly when a value crosses a power of 256."""
    cases = [(0, 1), (1, 1), (255, 1), (256, 2), (65535, 2), (65536, 3)]
    for value, expected in cases:
        assert Bits.how_many_bytes(value) == expected
def test_number_to_bits():
    """number_to_bits splits a value into byte chunks, least-significant first."""
    table = [
        (0, [0]),
        (1, [1]),
        (255, [255]),
        (256, [0, 1]),
        (65535, [255, 255]),
        (65536, [0, 0, 1]),
    ]
    for number, chunks in table:
        assert list(Bits.number_to_bits(number)) == chunks
def test_bits_big_endian_layout():
    """Check the raw byte layout of Bits.

    NOTE: the expected layouts put the least-significant byte first
    (e.g. 256 -> [0, 1]), i.e. little-endian, despite this test's name.
    """
    expected = [
        (1, [1]),
        (255, [255]),
        (256, [0, 1]),
        (65535, [255, 255]),
        (65536, [0, 0, 1]),
        (131_070, [0xfe, 0xff, 0x1]),
    ]
    for number, raw in expected:
        assert Bits(number) == bytes(raw)
def test_bits_iter():
    """Iteration yields bits LSB first; only the most significant byte is unpadded."""
    expected = {
        0: [0],
        1: [1],
        2: [0, 1],
        128: [0] * 7 + [1],
        255: [1] * 8,
        257: [1] + [0] * 7 + [1],
        256: [0] * 8 + [1],
        131_070: [0] + [1] * 16,
    }
    for number, bits in expected.items():
        assert list(Bits(number)) == bits
def test_bits_to_number():
    """A value survives the round trip through Bits."""
    number = 257
    assert Bits(number).to_number() == number
def test_bits_shift_left():
    """Shifting one left by one gives two."""
    one = Bits(1)
    assert one << one == Bits(2)
def test_adder():
    """adder matches integer addition, including a carry into a new byte."""
    cases = [
        ((0, 0), 0),
        ((1, 0), 1),
        ((0, 1), 1),
        ((1, 1), 2),
        ((5, 12), 17),
        ((255, 1), 256),
    ]
    for (x, y), total in cases:
        assert adder(Bits(x), Bits(y)) == total
def test_multiplier():
    """multiplier matches integer multiplication for small operands."""
    products = [
        ((0, 0), 0),
        ((1, 0), 0),
        ((0, 1), 0),
        ((1, 1), 1),
        ((2, 3), 6),
        ((127, 2), 254),
    ]
    for (x, y), product in products:
        assert multiplier(Bits(x), Bits(y)) == product
from abc import ABC
import pytest
from adt import IntervalMap
from binary import Byte, adder, multiplier, Word, Tribyte, DoubleWord
class Integer(ABC):
    """Unsigned integer whose concrete type is chosen from a shared registry.

    Concrete subclasses set ``STORAGE`` to a byte-backed type and are added
    with ``register(upper_bound, klass)``. Sums and products are built through
    ``factory``, so a result's type follows its magnitude rather than the
    operands' types.
    """

    STORAGE = bytes
    # One registry shared by every subclass; ``register`` mutates it via ``cls.types``.
    types = IntervalMap()

    def __init__(self, n):
        self.value = self.STORAGE(n)

    def __add__(self, other):
        if not isinstance(other, Integer):
            # Let Python try the reflected operation or raise TypeError,
            # instead of failing with AttributeError on ``other.value``.
            return NotImplemented
        return self.factory(adder(self.value, other.value))

    def __eq__(self, other):
        if not isinstance(other, Integer):
            return NotImplemented
        # Compares stored bytes only, so different widths holding the same
        # number compare equal.
        return self.value == other.value

    def __hash__(self):
        # Defining __eq__ alone sets __hash__ to None; hash exactly what
        # __eq__ compares so equal Integers hash equally.
        return hash(self.value)

    def __mul__(self, other):
        if not isinstance(other, Integer):
            return NotImplemented
        return self.factory(multiplier(self.value, other.value))

    def __repr__(self):
        return f'{self.__class__.__name__}({self.value!r})'

    @classmethod
    def factory(cls, n):
        """Wrap ``n`` in the subclass registered for the interval that contains it.

        Raises:
            KeyError: propagated from ``IntervalMap.get`` when ``n`` is not
                below any registered upper bound.
        """
        klass = cls.types.get(n)
        return klass(n)

    @classmethod
    def register(cls, upper_bound, klass):
        """Use ``klass`` for values from the previous registered bound up to ``upper_bound`` (exclusive)."""
        cls.types[upper_bound] = klass
def Int(n):
    """Return ``n`` wrapped in the narrowest registered Integer subclass that can hold it."""
    integer = Integer.factory(n)
    return integer
class Int8(Integer):
    """Integer stored in a single Byte."""

    STORAGE = Byte


class Int16(Integer):
    """Integer stored in a two-byte Word."""

    STORAGE = Word


class Int24(Integer):
    """Integer stored in a three-byte Tribyte."""

    STORAGE = Tribyte


class Int32(Integer):
    """Integer stored in a four-byte DoubleWord."""

    STORAGE = DoubleWord


# Each type handles values below its storage's exclusive upper bound:
# 2**8, 2**16, 2**24 and 2**32 respectively.
Integer.register(Int8.STORAGE.upper_bound(), Int8)
Integer.register(Int16.STORAGE.upper_bound(), Int16)
Integer.register(Int24.STORAGE.upper_bound(), Int24)
Integer.register(Int32.STORAGE.upper_bound(), Int32)
def test_int8():
    """Int8 accepts 0..255 and keeps small sums and products in 8 bits."""
    for n in (0, 255):
        assert isinstance(Int8(n), Int8)
    with pytest.raises(ValueError):
        Int8(256)
    sums = [((0, 0), 0), ((1, 1), 2), ((254, 1), 255)]
    for (a, b), total in sums:
        assert Int8(a) + Int8(b) == Int8(total)
    assert Int8(127) * Int8(2) == Int8(254)
def test_int16():
    """Int16 accepts 0..65535 and supports addition and multiplication."""
    for n in (0, 65535):
        assert isinstance(Int16(n), Int16)
    with pytest.raises(ValueError):
        Int16(65536)
    sums = [((0, 0), 0), ((255, 1), 256)]
    for (a, b), total in sums:
        assert Int16(a) + Int16(b) == Int16(total)
    assert Int16(256) * Int16(2) == Int16(512)
def test_int24():
    """Int24 accepts 0..16_777_215 and supports addition and multiplication."""
    for n in (0, 16_777_215):
        assert isinstance(Int24(n), Int24)
    with pytest.raises(ValueError):
        Int24(16_777_216)
    sums = [((0, 0), 0), ((65535, 1), 65536)]
    for (a, b), total in sums:
        assert Int24(a) + Int24(b) == Int24(total)
    assert Int24(65536) * Int24(2) == Int24(131_072)
def test_int32():
    """Int32 accepts 0..4_294_967_295 and supports addition and multiplication."""
    for n in (0, 4_294_967_295):
        assert isinstance(Int32(n), Int32)
    with pytest.raises(ValueError):
        Int32(4_294_967_296)
    sums = [((0, 0), 0), ((16_777_215, 1), 16_777_216)]
    for (a, b), total in sums:
        assert Int32(a) + Int32(b) == Int32(total)
    assert Int32(16_777_216) * Int32(2) == Int32(33_554_432)
def test_scale_up():
    """Sums and products that overflow a type come back in the next wider type."""
    steps = [
        (Int8, Int16, 255, 128),
        (Int16, Int24, 65535, 32_768),
        (Int24, Int32, 16_777_215, 8_388_608),
    ]
    for narrow, wide, max_value, half in steps:
        assert narrow(max_value) + narrow(1) == wide(max_value + 1)
        assert narrow(half) * narrow(2) == wide(half * 2)
def test_scale_down():
    """Small results from wide operands shrink to the narrowest fitting type."""
    cases = [
        (Int8, 254, Int8),
        (Int16, 254, Int8),
        (Int24, 65534, Int16),
        (Int32, 16_777_214, Int24),
    ]
    for source, start, expected in cases:
        assert isinstance(source(start) + source(1), expected)
[pytest]
python_files = *.py
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment