Last active
April 8, 2020 23:45
-
-
Save carver/382e4e84f461ff99838740e974a25a82 to your computer and use it in GitHub Desktop.
Start of an implementation of the Ethereum witness format, as specified at https://github.com/ethereum/stateless-ethereum-specs/blob/master/witness.md
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from abc import ABC, abstractmethod | |
from enum import Enum | |
from typing import ( | |
Iterable, | |
Tuple, | |
) | |
import cbor | |
from eth_typing import Hash32 | |
from eth_utils import ( | |
ValidationError, | |
keccak, | |
to_bytes, | |
to_int, | |
) | |
import rlp | |
from trie import HexaryTrie | |
from hypothesis import ( | |
example, | |
given, | |
settings, | |
strategies as st, | |
) | |
import pytest | |
# Version byte that prefixes an encoded witness: written first by encode_all()
# and checked against the first input byte by decode_all().
WITNESS_VERSION = 1
class WitnessOpcodes(Enum):
    """
    Opcodes of witness instructions.

    The member value is the byte written at the start of the encoded
    instruction.
    """

    LEAF = 0x00
    EXTENSION = 0x01
    BRANCH = 0x02
    HASH = 0x03
    CODE = 0x04
    ACCOUNT_LEAF = 0x05
    NEW_TRIE = 0xBB
class Instruction(ABC):
    """
    Abstract base class for a single witness instruction.

    Subclasses supply ``opcode`` and ``encoded_length`` and implement the
    ``encode``/``decode`` pair.
    """

    # Opcode that identifies the instruction type
    opcode: WitnessOpcodes
    # Number of bytes the encoded instruction occupies, opcode byte included
    encoded_length: int

    @abstractmethod
    def encode(self) -> bytes:
        """Return the encoded form of this instruction."""

    @classmethod
    @abstractmethod
    def decode(cls, encoded: bytes) -> 'Instruction':
        """Build an instruction from the data in ``encoded``."""
class HashInstruction(Instruction):
    """
    Witness instruction that carries a hash.

    Encoded as the HASH opcode byte followed by the raw hash bytes, for a
    fixed total of ``encoded_length`` bytes.
    """
    opcode = WitnessOpcodes.HASH
    # One opcode byte plus the 32 hash bytes
    encoded_length = 33

    def __init__(self, hash_bytes: bytes) -> None:
        self._hash_bytes = hash_bytes

    @property
    def hash_bytes(self) -> Hash32:
        """The hash carried by this instruction."""
        return self._hash_bytes

    def encode(self) -> bytes:
        """Return the opcode byte followed by the hash bytes."""
        return bytes((self.opcode.value,)) + self._hash_bytes

    @classmethod
    def decode(cls, encoded: memoryview) -> 'HashInstruction':
        """
        Decode a hash instruction from the start of ``encoded``.

        Bytes after the first ``encoded_length`` are ignored, so the caller
        can pass the whole remaining witness.

        :raise ValidationError: if the input is empty, starts with a different
            opcode, or is too short to hold the hash
        """
        # Check for emptiness first: indexing an empty buffer raises IndexError
        if len(encoded) == 0:
            raise ValidationError("Cannot decode Hash instruction from empty input")
        elif encoded[0] != cls.opcode.value:
            raise ValidationError(f"Cannot decode Hash instruction if opcode is {encoded[0]}")
        elif len(encoded) < cls.encoded_length:
            raise ValidationError(f"Cannot decode Hash instruction with only {len(encoded)} bytes")
        else:
            return cls(bytes(encoded[1:cls.encoded_length]))

    def __eq__(self, other: object) -> bool:
        # Duck-typed comparison: any object exposing the same opcode and hash is equal
        try:
            other_opcode = other.opcode
            other_hash = other._hash_bytes
        except AttributeError:
            return False

        return self.opcode == other_opcode and self._hash_bytes == other_hash

    def __hash__(self) -> int:
        # Defining __eq__ alone disables hashing; hash the same fields __eq__ compares
        return hash((self.opcode, self._hash_bytes))

    def __repr__(self) -> str:
        return f"{type(self).__name__}({self._hash_bytes!r})"
class LeafInstruction(Instruction):
    """
    Witness instruction that carries a leaf's key and value.

    Encoded as the LEAF opcode byte, followed by the key and then the value,
    each as a CBOR byte string.
    """
    opcode = WitnessOpcodes.LEAF
    # Cached byte length of the encoded instruction; None until first needed
    _encoded_length: int = None

    def __init__(self, key: bytes, value: bytes) -> None:
        self._key = key
        self._value = value

    def encode(self) -> bytes:
        """Return the opcode byte followed by the CBOR-encoded key and value."""
        # Mirrors decode(): the key comes first, then the value
        return bytes((self.opcode.value,)) + cbor_encode(self._key) + cbor_encode(self._value)

    @property
    def key(self) -> bytes:
        return self._key

    @property
    def value(self) -> bytes:
        return self._value

    @property
    def encoded_length(self) -> int:
        """Number of bytes in the encoded instruction, computed once and cached."""
        if self._encoded_length is None:
            self._encoded_length = len(self.encode())
        return self._encoded_length

    @classmethod
    def decode(cls, encoded: memoryview) -> 'LeafInstruction':
        """
        Decode a leaf instruction from the start of ``encoded``.

        Bytes after the encoded value are ignored, so the caller can pass the
        whole remaining witness.

        :raise ValidationError: if the input is empty or starts with a
            different opcode
        """
        if len(encoded) == 0:
            raise ValidationError("Cannot decode Leaf instruction from empty input")
        # Compare raw byte values: building the enum from an unknown byte raises ValueError
        if encoded[0] != cls.opcode.value:
            raise ValidationError(f"Cannot decode Leaf instruction if opcode is {encoded[0]}")

        key, key_length = cbor_decode(encoded[1:])
        value, value_length = cbor_decode(encoded[1 + key_length:])

        leaf = cls(bytes(key), bytes(value))
        # for performance: we already know the encoded length, so avoid calculating it on-demand
        leaf._encoded_length = 1 + key_length + value_length
        return leaf

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, LeafInstruction):
            return NotImplemented
        return self._key == other._key and self._value == other._value

    def __hash__(self) -> int:
        return hash((self.opcode, self._key, self._value))

    def __repr__(self) -> str:
        return f"{type(self).__name__}({self._key!r}, {self._value!r})"
# CBOR major types, stored in the top three bits of the initial byte
CBOR_MAJOR_TYPE_0 = 0  # unsigned integer
CBOR_MAJOR_TYPE_2 = 2  # byte string

# CBOR "additional information" values, stored in the low five bits of the
# initial byte. Values below 24 are the integer itself; these markers say how
# many following bytes hold the integer.
CBOR_UNSIGNED_8BIT = 24
CBOR_UNSIGNED_16BIT = 25
CBOR_UNSIGNED_32BIT = 26
CBOR_UNSIGNED_64BIT = 27

# Number of bytes that follow the initial byte, for each size marker
_CBOR_UNSIGNED_BYTE_LENGTHS = {
    CBOR_UNSIGNED_8BIT: 1,
    CBOR_UNSIGNED_16BIT: 2,
    CBOR_UNSIGNED_32BIT: 4,
    CBOR_UNSIGNED_64BIT: 8,
}


def cbor_int_val(encoded: memoryview) -> Tuple[int, int]:
    """
    Decode the unsigned integer held in a CBOR header at the start of ``encoded``.

    :return: decoded int value, and how many bytes were consumed
    :raise ValidationError: if the input is empty, uses an unsupported
        additional-information value, or ends before the integer is complete
    """
    if len(encoded) == 0:
        raise ValidationError("Cannot decode CBOR integer from empty input")

    # Cannot enforce the major type here, because we also use this for the
    # byte-string type, to extract the length
    additional = encoded[0] & 0b11111
    if additional < CBOR_UNSIGNED_8BIT:
        # Small values live directly in the initial byte
        return additional, 1

    try:
        byte_length = _CBOR_UNSIGNED_BYTE_LENGTHS[additional]
    except KeyError:
        raise ValidationError(
            f"Could not decode CBOR Additional Information {additional}"
        ) from None

    body = encoded[1:1 + byte_length]
    if len(body) < byte_length:
        # Slicing past the end silently shortens the view; reject it instead
        # of decoding a truncated integer
        raise ValidationError(
            f"Cannot decode {byte_length}-byte CBOR integer from only {len(body)} bytes"
        )

    # int.from_bytes accepts any bytes-like object, including memoryview slices
    return int.from_bytes(body, 'big'), 1 + byte_length
def cbor_decode(encoded: memoryview) -> Tuple[memoryview, int]:
    """
    Decode the CBOR byte string at the start of ``encoded``.

    :return: a slice of ``encoded`` holding the payload, and how many bytes
        were consumed (header plus payload)
    :raise ValidationError: if the input is empty, is not a byte string, or
        ends before the declared payload length
    """
    if len(encoded) == 0:
        raise ValidationError("Cannot decode CBOR byte string from empty input")

    major_type = encoded[0] >> 5
    if major_type != CBOR_MAJOR_TYPE_2:
        raise ValidationError(f"Cannot decode non-bytes type {major_type} as bytes")

    byte_length, meta_bytes_used = cbor_int_val(encoded)
    total_bytes_used = meta_bytes_used + byte_length

    if len(encoded) < total_bytes_used:
        # Slicing would quietly return a short payload while still reporting
        # the full length as consumed
        raise ValidationError(
            f"Cannot decode CBOR byte string of {byte_length} bytes: "
            f"only {len(encoded) - meta_bytes_used} bytes available"
        )

    return encoded[meta_bytes_used:total_bytes_used], total_bytes_used
# TODO rename int -> uint
def cbor_encode_int(natural_int: int) -> bytes:
    """
    Encode an int as a CBOR initial byte plus, when needed, a fixed-width body.

    Values below 24 are returned as a single byte. Larger values are
    left-padded with zero bytes to 1, 2, 4 or 8 bytes and prefixed with the
    matching size marker.

    :raise ValidationError: if the value needs more than 8 bytes
    """
    if natural_int < CBOR_UNSIGNED_8BIT:
        return bytes((natural_int,))

    body = to_bytes(natural_int)
    body_size = len(body)
    if body_size > 8:
        raise ValidationError(f"Cannot encode int {natural_int} that uses {body_size} bytes")

    # Choose the smallest supported width that fits the body
    if body_size == 1:
        marker, padded_size = CBOR_UNSIGNED_8BIT, 1
    elif body_size == 2:
        marker, padded_size = CBOR_UNSIGNED_16BIT, 2
    elif body_size <= 4:
        marker, padded_size = CBOR_UNSIGNED_32BIT, 4
    else:
        marker, padded_size = CBOR_UNSIGNED_64BIT, 8

    return bytes((marker,)) + body.rjust(padded_size, b'\0')
def cbor_encode(natural_bytes: bytes) -> bytes:
    """
    Encode ``natural_bytes`` as a CBOR byte string.

    The length is encoded with ``cbor_encode_int``; the byte-string major type
    is then set in the top three bits of its first byte.
    """
    length_header = cbor_encode_int(len(natural_bytes))
    initial_byte = length_header[0] | (CBOR_MAJOR_TYPE_2 << 5)
    header = bytes((initial_byte,)) + length_header[1:]
    return header + natural_bytes
def decode_instruction(encoded: memoryview) -> Instruction:
    """
    Decode the single instruction at the start of ``encoded``.

    The returned instruction always includes information on how many bytes were 'consumed'
    (its ``encoded_length``), so trailing data may follow the instruction.

    :raise ValidationError: if the input is empty or starts with an unknown opcode
    :raise NotImplementedError: if the opcode is known but has no decoder yet
    """
    if len(encoded) == 0:
        raise ValidationError("Cannot decode empty byte-string")

    try:
        opcode = WitnessOpcodes(encoded[0])
    except ValueError as exc:
        # An unknown opcode byte is invalid input, not a programming error
        raise ValidationError(
            f"Cannot decode instruction with unknown opcode {encoded[0]}"
        ) from exc

    if opcode == WitnessOpcodes.HASH:
        return HashInstruction.decode(encoded)
    elif opcode == WitnessOpcodes.LEAF:
        return LeafInstruction.decode(encoded)
    else:
        raise NotImplementedError(f"Cannot decode instruction with opcode {opcode}")
# TODO rename to decode witness?
def decode_all(encoded: bytes) -> Iterable[Instruction]:
    """
    Lazily decode every instruction in an encoded witness.

    The first byte must be the witness version; the rest is decoded one
    instruction at a time. This is a generator, so errors are raised while
    iterating, not when the function is called.

    :raise ValidationError: if the input is empty or has an unsupported version
    """
    remaining = memoryview(encoded)

    # Check for emptiness first: indexing an empty buffer raises IndexError
    if len(remaining) == 0:
        raise ValidationError("Cannot decode empty witness: missing version byte")
    if remaining[0] != WITNESS_VERSION:
        raise ValidationError(f"Cannot decode witness version {remaining[0]}")

    remaining = remaining[1:]
    while len(remaining) > 0:
        instruction = decode_instruction(remaining)
        yield instruction
        # Advance past the bytes the instruction reported as consumed
        remaining = remaining[instruction.encoded_length:]
def encode_all(instructions: Iterable[Instruction]) -> bytes:
    """
    Encode instructions into a witness: the version byte, then each
    instruction's encoding in order.
    """
    version_prefix = bytes((WITNESS_VERSION,))
    body = b''.join(item.encode() for item in instructions)
    return version_prefix + body
def follow_instructions(all_instructions: Iterable[Instruction]) -> Tuple[HexaryTrie, ...]:
    """
    Build tries from decoded witness instructions.

    Only a single LEAF instruction is supported so far. It becomes the root
    node: the RLP encoding of ``[key, value]`` is stored in a dict-backed
    ``HexaryTrie`` under its keccak hash, which is set as the trie's root hash.

    :return: a one-element tuple holding the built trie
    :raise ValidationError: if there are no instructions, or more than one
    :raise NotImplementedError: if the single instruction is not a leaf
    """
    # TODO collapse instructions into root node
    collapsed_instructions = tuple(all_instructions)

    if len(collapsed_instructions) == 0:
        raise ValidationError("Cannot build a trie without any instructions")
    if len(collapsed_instructions) > 1:
        # TODO handle multiple tries
        raise ValidationError(f"Cannot handle multiple tries. Got: {collapsed_instructions}")

    root = collapsed_instructions[0]
    if root.opcode != WitnessOpcodes.LEAF:
        raise NotImplementedError(f"Cannot build trie with root {root.opcode}")

    leaf = root  # will probably need a cast here
    db_value = rlp.encode([leaf.key, leaf.value])
    db_key = keccak(db_value)

    trie_db = {db_key: db_value}
    trie = HexaryTrie(trie_db)
    trie.root_hash = db_key

    return (trie, )
def test_decode_hash_instruction():
    """A HASH opcode byte followed by 32 bytes decodes to a hash instruction."""
    body = b'01234567890123456789012345678901'
    decoded = decode_instruction(memoryview(b'\x03' + body))

    assert decoded.opcode == WitnessOpcodes.HASH
    assert decoded.hash_bytes == body
    assert decoded.encoded_length == 33
@given(st.binary(min_size=32, max_size=32))
def test_decode_encode_hash_instruction(encoded_body):
    """Decoding then re-encoding a hash instruction gives back the input bytes."""
    original = b'\x03' + encoded_body
    decoded = decode_instruction(memoryview(original))
    assert decoded.encode() == original
@given(
    st.binary(min_size=32, max_size=32),
    st.binary(min_size=32, max_size=32),
)
def test_decode_encode_multiple_hash_instructions(body1, body2):
    """A witness holding two hash instructions survives a decode/encode round trip."""
    witness = b'\x01\x03' + body1 + b'\x03' + body2
    decoded = tuple(decode_all(witness))

    assert len(decoded) == 2
    assert encode_all(decoded) == witness
@given(st.binary(min_size=32, max_size=32))
def test_encode_decode_hash_instruction(hash_bytes):
    """Encoding then decoding a hash instruction gives back an equal instruction."""
    original = HashInstruction(hash_bytes)
    round_tripped = decode_instruction(memoryview(original.encode()))
    assert round_tripped == original
@given(
    st.binary(min_size=32, max_size=32),
    st.binary(min_size=32, max_size=32),
)
def test_encode_decode_multiple_hash_instructions(hash_bytes1, hash_bytes2):
    """Two hash instructions survive an encode/decode round trip through a witness."""
    originals = (HashInstruction(hash_bytes1), HashInstruction(hash_bytes2))
    witness = encode_all(originals)
    assert tuple(decode_all(witness)) == originals
@given(st.binary(max_size=2**16 + 1))
@settings(max_examples=2000)
def test_cbor_decode_bytes(original_bytes):
    """cbor_decode reads byte strings produced by the reference cbor library."""
    reference_encoding = cbor.dumps(original_bytes)
    payload, consumed = cbor_decode(memoryview(reference_encoding))

    assert payload == original_bytes
    assert consumed == len(reference_encoding)
@given(st.binary(max_size=2**16 + 1))
@settings(max_examples=2000)
def test_cbor_encode_bytes(original_bytes):
    """The reference cbor library reads byte strings produced by cbor_encode."""
    assert cbor.loads(cbor_encode(original_bytes)) == original_bytes
@given(st.integers(min_value=0, max_value=256**8 - 1))
@settings(max_examples=2000)
def test_cbor_encode_int(original_int):
    """The reference cbor library reads integers produced by cbor_encode_int."""
    assert cbor.loads(cbor_encode_int(original_int)) == original_int
@pytest.mark.parametrize(
    'witness, expected_trie, expected_trie_db',
    (
        (
            (
                # header
                b'\x01'
                # leaf
                b'\x00'
                # encoded original key b'some-key' to b'I some-key':
                b'\x49\x20some-key'
                # encoded original val b'some-val' to b'Hsome-val':
                b'\x48some-val'
            ),
            {b'some-key': b'some-val'},
            {
                b'\x18\x8e\x17\x87\xc627\xde2\xed\xc1\xd8\xf1\xd3\n\x10\x12\x9d\x98\xd5S\xda\xe9\xaf\xccQT}\xa3\xb0O\x0c':
                    b'\xd3\x89 some-key\x88some-val',
            },
        ),
    ),
)
def test_make_one_trie_from_witness(witness, expected_trie, expected_trie_db):
    """A witness with a single leaf builds one trie with the expected contents and database."""
    tries = follow_instructions(decode_all(witness))
    assert len(tries) == 1

    trie = tries[0]
    for key, value in expected_trie.items():
        assert trie[key] == value

    assert expected_trie_db == trie.db
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment