Skip to content

Instantly share code, notes, and snippets.

@carver
Last active April 8, 2020 23:45
Show Gist options
  • Save carver/382e4e84f461ff99838740e974a25a82 to your computer and use it in GitHub Desktop.
Save carver/382e4e84f461ff99838740e974a25a82 to your computer and use it in GitHub Desktop.
Starting an implementation of the Ethereum witness format, as specified at https://github.com/ethereum/stateless-ethereum-specs/blob/master/witness.md
from abc import ABC, abstractmethod
from enum import Enum
from typing import (
Iterable,
Tuple,
)
import cbor
from eth_typing import Hash32
from eth_utils import (
ValidationError,
keccak,
to_bytes,
to_int,
)
import rlp
from trie import HexaryTrie
from hypothesis import (
example,
given,
settings,
strategies as st,
)
import pytest
# Version byte that an encoded witness must start with.
WITNESS_VERSION = 1


class WitnessOpcodes(Enum):
    """
    Instruction opcodes; every encoded instruction begins with one of these values as a byte.
    """

    LEAF = 0
    EXTENSION = 1
    BRANCH = 2
    HASH = 3
    CODE = 4
    ACCOUNT_LEAF = 5
    # NOTE(review): presumably marks the start of another trie in the witness; not decoded here yet.
    NEW_TRIE = 0xBB
class Instruction(ABC):
    """
    Base class for a single witness instruction.

    Subclasses set ``opcode`` and expose ``encoded_length``, the number of bytes the
    instruction occupies in an encoded witness; the witness decoder uses it to step
    from one instruction to the next.
    """

    opcode: WitnessOpcodes
    encoded_length: int

    @abstractmethod
    def encode(self) -> bytes:
        """Serialize this instruction, beginning with its opcode byte."""
        ...

    @classmethod
    @abstractmethod
    def decode(cls, encoded: memoryview) -> 'Instruction':
        """
        Decode the instruction found at the start of ``encoded``.

        ``encoded`` may continue past the end of this instruction (it is usually a
        view over the rest of the witness); those trailing bytes are not consumed.
        """
        ...
class HashInstruction(Instruction):
    """
    HASH instruction: the HASH opcode byte followed by 32 raw hash bytes.
    """

    opcode = WitnessOpcodes.HASH
    # 1 opcode byte + 32 hash bytes
    encoded_length = 33

    def __init__(self, hash_bytes: bytes) -> None:
        """
        :param hash_bytes: the hash payload; must be exactly 32 bytes, otherwise
            ``encode()`` would not produce ``encoded_length`` bytes and a witness
            containing it could not be decoded again
        :raises ValidationError: if ``hash_bytes`` has any other length
        """
        expected_length = self.encoded_length - 1
        if len(hash_bytes) != expected_length:
            raise ValidationError(
                f"Hash instruction requires {expected_length} hash bytes, got {len(hash_bytes)}"
            )
        self._hash_bytes = hash_bytes

    @property
    def hash_bytes(self) -> Hash32:
        return self._hash_bytes

    def encode(self) -> bytes:
        return bytes((self.opcode.value,)) + self._hash_bytes

    @classmethod
    def decode(cls, encoded: memoryview) -> 'HashInstruction':
        """
        Decode a HASH instruction from the start of ``encoded``; bytes after the
        first ``encoded_length`` are ignored.

        :raises ValidationError: if ``encoded`` is empty, starts with a different
            opcode, or is shorter than ``encoded_length``
        """
        if len(encoded) == 0:
            raise ValidationError("Cannot decode Hash instruction from empty byte-string")
        elif encoded[0] != cls.opcode.value:
            raise ValidationError(f"Cannot decode Hash instruction if opcode is {encoded[0]}")
        elif len(encoded) < cls.encoded_length:
            raise ValidationError(f"Cannot decode Hash instruction with only {len(encoded)} bytes")
        else:
            # bytes() copies out of the (possibly memoryview) input buffer
            return cls(bytes(encoded[1:cls.encoded_length]))

    def __eq__(self, other: object) -> bool:
        # Duck-typed: anything exposing an equal opcode and hash payload compares equal.
        try:
            other_opcode = other.opcode
            other_hash = other._hash_bytes
        except AttributeError:
            return False
        return self.opcode == other_opcode and self._hash_bytes == other_hash

    def __hash__(self) -> int:
        # Defined alongside __eq__ so instances stay usable in sets and as dict keys.
        return hash((self.opcode, self._hash_bytes))

    def __repr__(self) -> str:
        return f"{type(self).__name__}({self._hash_bytes!r})"
class LeafInstruction(Instruction):
    """
    LEAF instruction: the LEAF opcode byte, then the key and the value, each
    encoded as a CBOR byte-string.
    """

    opcode = WitnessOpcodes.LEAF
    # Cached byte length of this instruction; computed lazily, or set by decode().
    _encoded_length: int = None

    def __init__(self, key: bytes, value: bytes) -> None:
        self._key = key
        self._value = value

    def encode(self) -> bytes:
        return (
            bytes((self.opcode.value,))
            + cbor_encode(self._key)
            + cbor_encode(self._value)
        )

    @property
    def key(self) -> bytes:
        return self._key

    @property
    def value(self) -> bytes:
        return self._value

    @property
    def encoded_length(self) -> int:
        """Number of bytes this instruction occupies in an encoded witness."""
        if self._encoded_length is None:
            self._encoded_length = len(self.encode())
        return self._encoded_length

    @classmethod
    def decode(cls, encoded: memoryview) -> 'LeafInstruction':
        """
        Decode a LEAF instruction from the start of ``encoded``; any bytes after
        the value are ignored.

        :raises ValidationError: if the opcode byte is a known opcode other than
            LEAF, or a CBOR field cannot be decoded
        """
        if WitnessOpcodes(encoded[0]) != cls.opcode:
            raise ValidationError(f"Cannot decode Leaf instruction if opcode is {encoded[0]}")
        key, key_length = cbor_decode(encoded[1:])
        value, value_length = cbor_decode(encoded[1 + key_length:])
        leaf = cls(bytes(key), bytes(value))
        # for performance: we already know how many input bytes were consumed,
        # so avoid re-encoding to compute the length on demand
        leaf._encoded_length = 1 + key_length + value_length
        return leaf

    def __eq__(self, other: object) -> bool:
        # Duck-typed like HashInstruction.__eq__: compare opcode, key and value.
        try:
            other_opcode = other.opcode
            other_key = other._key
            other_value = other._value
        except AttributeError:
            return False
        return (
            self.opcode == other_opcode
            and self._key == other_key
            and self._value == other_value
        )

    def __hash__(self) -> int:
        return hash((self.opcode, self._key, self._value))

    def __repr__(self) -> str:
        return f"{type(self).__name__}({self._key!r}, {self._value!r})"
# CBOR major types, stored in the top three bits of an item's initial byte.
CBOR_MAJOR_TYPE_0 = 0  # unsigned integer
CBOR_MAJOR_TYPE_2 = 2  # byte string

# CBOR "additional information" values (low five bits of the initial byte)
# announcing that a 1-, 2-, 4- or 8-byte big-endian argument follows.
(
    CBOR_UNSIGNED_8BIT,
    CBOR_UNSIGNED_16BIT,
    CBOR_UNSIGNED_32BIT,
    CBOR_UNSIGNED_64BIT,
) = (24, 25, 26, 27)
def cbor_int_val(encoded: memoryview) -> Tuple[int, int]:
    """
    Read the unsigned integer argument of the CBOR item at the start of ``encoded``.

    Only the low five bits ("additional information") of the initial byte are
    inspected. The major type in the top three bits is deliberately not enforced,
    because the byte-string decoder reuses this to read its length prefix.

    :param encoded: bytes-like buffer starting at a CBOR initial byte; bytes after
        the argument are ignored
    :return: decoded int value, and how many bytes were consumed
    :raises ValidationError: if ``encoded`` is empty, the additional information
        is not one of 0..27, or the buffer ends before the announced argument bytes
    """
    if len(encoded) == 0:
        raise ValidationError("Cannot decode CBOR integer from empty byte-string")
    additional = encoded[0] & 0b11111
    if additional < CBOR_UNSIGNED_8BIT:
        # Small values are stored directly in the initial byte.
        return additional, 1
    elif additional == CBOR_UNSIGNED_8BIT:
        byte_length = 1
    elif additional == CBOR_UNSIGNED_16BIT:
        byte_length = 2
    elif additional == CBOR_UNSIGNED_32BIT:
        byte_length = 4
    elif additional == CBOR_UNSIGNED_64BIT:
        byte_length = 8
    else:
        raise ValidationError(f"Could not decode CBOR Additional Information {additional}")
    consumed = 1 + byte_length
    if len(encoded) < consumed:
        # Without this check a truncated buffer would silently decode to a smaller value.
        raise ValidationError(
            f"CBOR integer needs {byte_length} argument bytes, "
            f"but only {len(encoded) - 1} are available"
        )
    # int.from_bytes accepts any bytes-like object, including memoryview slices.
    return int.from_bytes(encoded[1:consumed], 'big'), consumed
def cbor_decode(encoded: memoryview) -> Tuple[memoryview, int]:
    """
    Decode the CBOR byte-string (major type 2) at the start of ``encoded``.

    :param encoded: buffer starting at the byte-string's initial byte; bytes after
        the byte-string are ignored
    :return: a slice of ``encoded`` holding the payload (a zero-copy view when
        ``encoded`` is a memoryview), and the total number of bytes consumed,
        including the header
    :raises ValidationError: if ``encoded`` is empty, is not a byte-string, or
        ends before the declared payload length
    """
    if len(encoded) == 0:
        raise ValidationError("Cannot decode CBOR byte-string from empty byte-string")
    major_type = encoded[0] >> 5
    if major_type != CBOR_MAJOR_TYPE_2:
        raise ValidationError(f"Cannot decode non-bytes type {major_type} as bytes")
    byte_length, meta_bytes_used = cbor_int_val(encoded)
    total_bytes_used = meta_bytes_used + byte_length
    if len(encoded) < total_bytes_used:
        # Slicing would silently truncate, and the consumed count would overrun the buffer.
        raise ValidationError(
            f"CBOR byte-string declares {byte_length} bytes, "
            f"but only {len(encoded) - meta_bytes_used} are available"
        )
    return encoded[meta_bytes_used:total_bytes_used], total_bytes_used
# TODO rename int -> uint
def cbor_encode_int(natural_int: int) -> bytes:
    """
    Encode a non-negative int as a CBOR unsigned-integer head (major type 0).

    Values below 24 fit in the initial byte itself. Larger values use the smallest
    of the 1/2/4/8-byte big-endian argument widths that can hold them. The
    byte-string encoder reuses this output as a length prefix by OR-ing its own
    major type into the first byte.

    :raises ValidationError: if ``natural_int`` is negative or needs more than 8 bytes
    """
    if natural_int < 0:
        raise ValidationError(f"Cannot encode negative int {natural_int} as a CBOR unsigned int")
    if natural_int < CBOR_UNSIGNED_8BIT:
        return bytes((natural_int,))
    for first_byte, argument_width in (
        (CBOR_UNSIGNED_8BIT, 1),
        (CBOR_UNSIGNED_16BIT, 2),
        (CBOR_UNSIGNED_32BIT, 4),
        (CBOR_UNSIGNED_64BIT, 8),
    ):
        if natural_int < 1 << (8 * argument_width):
            return bytes((first_byte,)) + natural_int.to_bytes(argument_width, 'big')
    byte_length = (natural_int.bit_length() + 7) // 8
    raise ValidationError(f"Cannot encode int {natural_int} that uses {byte_length} bytes")
def cbor_encode(natural_bytes: bytes) -> bytes:
    """
    Encode ``natural_bytes`` as a CBOR byte-string (major type 2).

    The length header comes from ``cbor_encode_int``, whose first byte only uses
    the low five bits, so the byte-string major type is OR-ed into it.
    """
    header = bytearray(cbor_encode_int(len(natural_bytes)))
    header[0] |= CBOR_MAJOR_TYPE_2 << 5
    return bytes(header) + natural_bytes
def decode_instruction(encoded: memoryview) -> Instruction:
    """
    Decode the single instruction at the start of ``encoded``, chosen by its opcode byte.

    The returned instruction always includes information on how many bytes were
    'consumed' (its ``encoded_length``), so callers can advance past it.

    :raises ValidationError: if ``encoded`` is empty
    :raises NotImplementedError: for a known opcode that has no decoder yet
    """
    if not len(encoded):
        raise ValidationError("Cannot decode empty byte-string")
    opcode = WitnessOpcodes(encoded[0])
    decoders = {
        WitnessOpcodes.HASH: HashInstruction,
        WitnessOpcodes.LEAF: LeafInstruction,
    }
    instruction_type = decoders.get(opcode)
    if instruction_type is None:
        raise NotImplementedError(f"Cannot decode instruction with opcode {opcode}")
    return instruction_type.decode(encoded)
# TODO rename to decode witness?
def decode_all(encoded: bytes) -> Iterable[Instruction]:
    """
    Lazily decode every instruction in an encoded witness.

    The witness is a version byte, which must equal ``WITNESS_VERSION``, followed by
    back-to-back instructions. Each instruction's ``encoded_length`` is used to
    advance to the next one. Because this is a generator, nothing is validated
    until iteration starts.

    :raises ValidationError: if the witness is empty or has an unsupported version
    """
    remaining = memoryview(encoded)
    if len(remaining) == 0:
        raise ValidationError("Cannot decode empty witness")
    version = remaining[0]
    if version != WITNESS_VERSION:
        raise ValidationError(f"Cannot decode witness version {version}")
    # memoryview slices keep stepping through the witness zero-copy
    remaining = remaining[1:]
    while remaining:
        instruction = decode_instruction(remaining)
        yield instruction
        remaining = remaining[instruction.encoded_length:]
def encode_all(instructions: Iterable[Instruction]) -> bytes:
    """Serialize a witness: the version byte, then each instruction's encoding in order."""
    pieces = [bytes((WITNESS_VERSION,))]
    pieces.extend(instruction.encode() for instruction in instructions)
    return b''.join(pieces)
def follow_instructions(all_instructions: Iterable[Instruction]) -> Tuple[HexaryTrie, ...]:
    """
    Build tries from decoded witness instructions.

    Currently only a witness made of exactly one LEAF instruction is supported. It
    becomes a one-node trie whose root is that leaf: the database maps
    keccak(rlp([key, value])) to rlp([key, value]).

    :raises ValidationError: if there are no instructions, or more than one
    :raises NotImplementedError: if the single instruction is not a LEAF
    """
    # TODO collapse instructions into root node
    collapsed_instructions = tuple(all_instructions)
    if not collapsed_instructions:
        raise ValidationError("Cannot build a trie without any instructions")
    if len(collapsed_instructions) > 1:
        # TODO handle multiple tries
        raise ValidationError(f"Cannot handle multiple tries. Got: {collapsed_instructions}")
    root = collapsed_instructions[0]
    if root.opcode != WitnessOpcodes.LEAF:
        raise NotImplementedError(f"Cannot build trie with root {root.opcode}")

    trie_db = {}
    trie = HexaryTrie(trie_db)
    leaf = root  # will probably need a cast here
    db_value = rlp.encode([leaf.key, leaf.value])
    db_key = keccak(db_value)
    trie_db[db_key] = db_value
    trie.root_hash = db_key
    return (trie,)
def test_decode_hash_instruction():
    """A HASH opcode byte followed by 32 bytes decodes to a 33-byte HashInstruction."""
    hash_payload = b'01234567890123456789012345678901'
    decoded = decode_instruction(memoryview(b'\x03' + hash_payload))
    assert (decoded.opcode, decoded.hash_bytes, decoded.encoded_length) == (
        WitnessOpcodes.HASH,
        hash_payload,
        33,
    )
@given(st.binary(min_size=32, max_size=32))
def test_decode_encode_hash_instruction(encoded_body):
    """Decoding and then re-encoding a HASH instruction reproduces the input bytes."""
    original = b'\x03' + encoded_body
    assert decode_instruction(memoryview(original)).encode() == original
@given(
    st.binary(min_size=32, max_size=32),
    st.binary(min_size=32, max_size=32),
)
def test_decode_encode_multiple_hash_instructions(body1, body2):
    """A version-1 witness holding two HASH instructions round-trips byte-for-byte."""
    witness = b''.join((b'\x01', b'\x03', body1, b'\x03', body2))
    decoded = tuple(decode_all(witness))
    assert len(decoded) == 2
    assert encode_all(decoded) == witness
@given(st.binary(min_size=32, max_size=32))
def test_encode_decode_hash_instruction(hash_bytes):
    """Encoding a HashInstruction and decoding it yields an equal instruction."""
    original = HashInstruction(hash_bytes)
    round_tripped = decode_instruction(memoryview(original.encode()))
    assert round_tripped == original
@given(
    st.binary(min_size=32, max_size=32),
    st.binary(min_size=32, max_size=32),
)
def test_encode_decode_multiple_hash_instructions(hash_bytes1, hash_bytes2):
    """A witness built from two HashInstructions decodes back to the same tuple."""
    instructions = tuple(map(HashInstruction, (hash_bytes1, hash_bytes2)))
    assert tuple(decode_all(encode_all(instructions))) == instructions
@given(st.binary(max_size=2**16 + 1))
@settings(max_examples=2000)
def test_cbor_decode_bytes(original_bytes):
    """cbor_decode reads back what the reference ``cbor`` library wrote."""
    reference_encoding = cbor.dumps(original_bytes)
    payload, consumed = cbor_decode(memoryview(reference_encoding))
    assert payload == original_bytes
    assert consumed == len(reference_encoding)
@given(st.binary(max_size=2**16 + 1))
@settings(max_examples=2000)
def test_cbor_encode_bytes(original_bytes):
    """The reference ``cbor`` library reads back what cbor_encode wrote."""
    assert cbor.loads(cbor_encode(original_bytes)) == original_bytes
@given(st.integers(min_value=0, max_value=256**8 - 1))
@settings(max_examples=2000)
def test_cbor_encode_int(original_int):
    """Every 64-bit unsigned int encoded by cbor_encode_int decodes with the reference library."""
    assert cbor.loads(cbor_encode_int(original_int)) == original_int
@pytest.mark.parametrize(
    'witness, expected_trie, expected_trie_db',
    (
        (
            # header (version 0x1), LEAF opcode (0x0), then two CBOR byte-strings:
            #   0x49 + b' some-key'  -- original key b'some-key' behind a 0x20 prefix byte
            #   0x48 + b'some-val'   -- original value b'some-val'
            bytes((0x1, 0x0)) + b'I some-key' + b'Hsome-val',
            {b'some-key': b'some-val'},
            {
                b'\x18\x8e\x17\x87\xc627\xde2\xed\xc1\xd8\xf1\xd3\n\x10\x12\x9d\x98\xd5S\xda\xe9\xaf\xccQT}\xa3\xb0O\x0c':
                    b'\xd3\x89 some-key\x88some-val',
            },
        ),
    ),
)
def test_make_one_trie_from_witness(witness, expected_trie, expected_trie_db):
    """A single-leaf witness builds one trie with the expected contents and database."""
    built_tries = follow_instructions(decode_all(witness))
    assert len(built_tries) == 1
    (built_trie,) = built_tries
    for key, value in expected_trie.items():
        assert built_trie[key] == value
    assert expected_trie_db == built_trie.db
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment