Skip to content

Instantly share code, notes, and snippets.

@pepyakin
Created February 5, 2025 14:04
Show Gist options
  • Save pepyakin/da532804b78ff2506047fcbf2e3ef84c to your computer and use it in GitHub Desktop.
Save pepyakin/da532804b78ff2506047fcbf2e3ef84c to your computer and use it in GitHub Desktop.
EIP 7864 implementation
from __future__ import annotations
import blake3
#
# Constants and helper functions
#
# These constants come from the specification.
# Sub-indexes of the two account header leaves.
BASIC_DATA_LEAF_KEY: int = 0
CODE_HASH_LEAF_KEY: int = 1
# Position offsets used when deriving storage and code keys.
HEADER_STORAGE_OFFSET: int = 64
CODE_OFFSET: int = 128
# Number of value slots addressed by the last key byte.
STEM_SUBTREE_WIDTH: int = 256
MAIN_STORAGE_OFFSET: int = 1 << 248  # same value as 256 ** 31
# Opcode bounds used by chunkify_code to recognise PUSH1..PUSH32.
PUSH_OFFSET = 0x5F  # 95 in decimal
PUSH1 = 0x60  # PUSH_OFFSET + 1
PUSH32 = 0x7F  # PUSH_OFFSET + 32
def zero32() -> bytes:
    """Build a bytes object made of 32 zero bytes."""
    return bytes(32)
def zero64() -> bytes:
    """Build a bytes object made of 64 zero bytes."""
    return bytes(64)
def blake3_hash(data: bytes) -> bytes:
    """Return the BLAKE3 digest of ``data``."""
    hasher = blake3.blake3(data)
    return hasher.digest()
def hash_32_or_64(data: bytes | None) -> bytes:
    """
    Apply the tree's hashing rule to ``data``.

    - ``None`` and 64 zero bytes both map to 32 zero bytes (no hashing).
    - Any other input must be 32 or 64 bytes long and maps to its BLAKE3 hash.

    Raises:
        ValueError: if ``data`` is neither None, 64 zero bytes, nor 32/64 bytes long.
    """
    if data is not None and data != zero64():
        if len(data) in (32, 64):
            return blake3_hash(data)
        raise ValueError("data must be 32 or 64 bytes if not None or zero64().")
    # Empty input: represented by the all-zero hash.
    return zero32()
def bytes_to_bits(b: bytes) -> list[int]:
    """Expand ``b`` into a flat list of 0/1 ints, most significant bit of each byte first."""
    bits: list[int] = []
    for octet in b:
        # Walk the shifts from 7 down to 0 so the high bit comes out first.
        for shift in range(7, -1, -1):
            bits.append((octet >> shift) & 1)
    return bits
def bits_to_bytes(bits: list[int]) -> bytes:
    """
    Pack a list of bits (most significant bit first) into bytes.

    Raises:
        ValueError: if the number of bits is not a multiple of 8.
    """
    if len(bits) % 8:
        raise ValueError("Number of bits must be a multiple of 8.")
    packed = bytearray()
    for base in range(0, len(bits), 8):
        octet = 0
        # The first bit of each group of eight lands in the highest position.
        for offset in range(8):
            octet |= bits[base + offset] << (7 - offset)
        packed.append(octet)
    return bytes(packed)
#
# Node definitions
#
class InternalNode:
    """
    Branch node of the binary tree.

    Holds a 'left' and a 'right' child. Each child is an InternalNode,
    a StemNode, or None when that side is empty.
    """

    def __init__(self) -> None:
        # A freshly created branch has no children on either side.
        self.left: InternalNode | StemNode | None = None
        self.right: InternalNode | StemNode | None = None
class StemNode:
    """
    A stem node stores a 31-byte 'stem' and an array of 256 value slots.

    Each slot is either None (empty) or a 32-byte value; the slots act as the
    256 "leaves" that share this stem.
    """

    def __init__(self, stem: bytes) -> None:
        """
        Create a stem node with all 256 slots empty.

        Raises:
            ValueError: if ``stem`` is not exactly 31 bytes long.
        """
        if len(stem) != 31:
            raise ValueError("Stem must be 31 bytes.")
        self.stem: bytes = stem
        # Each entry can be None (empty) or a 32-byte value.
        self.values: list[bytes | None] = [None] * 256

    def set_value(self, index: int, value: bytes) -> None:
        """
        Store ``value`` in slot ``index``.

        Raises:
            ValueError: if ``value`` is not exactly 32 bytes long.
            IndexError: if ``index`` is outside 0..255.
        """
        if len(value) != 32:
            raise ValueError("Value must be 32 bytes.")
        # Reject negative indices explicitly: plain list indexing would wrap
        # them around and silently overwrite a different slot.
        if not 0 <= index < len(self.values):
            raise IndexError("Index must be in range 0..255.")
        self.values[index] = value
#
# The main BinaryTree data structure
#
class BinaryTree:
    """
    Binary tree mapping 32-byte keys to 32-byte values.

    The first 31 bytes of a key (its stem) are read bit by bit, most
    significant bit first, to go left (0) or right (1) through InternalNodes.
    The last byte selects one of the 256 value slots of the StemNode found at
    the end of that path.
    """

    def __init__(self) -> None:
        # None means the tree is empty.
        self.root: InternalNode | StemNode | None = None

    def insert(self, key: bytes, value: bytes) -> None:
        """
        Insert a new key -> value into the tree. The key is 32 bytes:
        - The first 31 bytes (key[:31]) define the 'stem' path in the tree.
        - The last byte (key[31]) is the subindex in that stem's 256-value array.

        Raises:
            ValueError: if ``key`` or ``value`` is not exactly 32 bytes long.
        """
        if len(key) != 32:
            raise ValueError("key must be 32 bytes.")
        if len(value) != 32:
            raise ValueError("value must be 32 bytes.")
        # _insert creates the first StemNode itself when the tree is empty.
        self.root = self._insert(self.root, key[:31], key[31], value, depth=0)

    def _insert(
        self,
        node: InternalNode | StemNode | None,
        stem: bytes,
        subindex: int,
        value: bytes,
        depth: int
    ) -> InternalNode | StemNode:
        """
        Recursive helper to insert into the subtree rooted at ``node``.

        The bit of 'stem' used at level 'depth' determines going left (0) or
        right (1). Returns the node that should take ``node``'s place.

        Raises:
            RuntimeError: if an InternalNode is met at depth 248 or deeper,
                where the 248-bit stem has no bit left to choose a side.
        """
        if node is None:
            # Empty slot: the new stem lives here.
            new_node = StemNode(stem)
            new_node.set_value(subindex, value)
            return new_node
        if isinstance(node, StemNode):
            if node.stem == stem:
                # Same stem: just update the subindex. This is valid at any
                # depth, including 248, where two stems differing only in
                # their last bit end up after a split.
                node.set_value(subindex, value)
                return node
            # Different stems: split them apart with one or more InternalNodes.
            return self._split_leaf(
                node,
                bytes_to_bits(stem),
                bytes_to_bits(node.stem),
                subindex,
                value,
                depth,
            )
        # node is an InternalNode: a stem bit is consumed here, so the depth
        # bound applies only on this branch.
        if depth >= 248:
            raise RuntimeError("Depth must be less than 248")
        # Pick out bit number 'depth' of the stem (MSB first) directly instead
        # of expanding the whole stem into a bit list at every level.
        bit = (stem[depth // 8] >> (7 - depth % 8)) & 1
        if bit == 0:
            node.left = self._insert(node.left, stem, subindex, value, depth + 1)
        else:
            node.right = self._insert(node.right, stem, subindex, value, depth + 1)
        return node

    def _split_leaf(
        self,
        leaf: StemNode,
        stem_bits: list[int],
        existing_stem_bits: list[int],
        subindex: int,
        value: bytes,
        depth: int
    ) -> InternalNode:
        """
        Separate an existing StemNode from a new stem that shares its path.

        The two stems share a path up to 'depth'. One InternalNode is created
        per level for as long as their bits still match, and at the first
        differing bit the existing leaf and a new StemNode become siblings.

        Raises:
            RuntimeError: if no differing bit exists below depth 248.
        """
        if depth >= 248:
            # Only reachable if both bit lists are identical from 'depth' on.
            raise RuntimeError("Depth must be less than 248")
        new_internal = InternalNode()
        bit = stem_bits[depth]
        if bit == existing_stem_bits[depth]:
            # They match at this bit, so we need another internal level deeper.
            deeper = self._split_leaf(
                leaf, stem_bits, existing_stem_bits, subindex, value, depth + 1
            )
            if bit == 0:
                new_internal.left = deeper
            else:
                new_internal.right = deeper
            return new_internal
        # They differ exactly at this bit: place each stem on its own side.
        new_node = StemNode(bits_to_bytes(stem_bits))
        new_node.set_value(subindex, value)
        if bit == 0:
            new_internal.left = new_node
            new_internal.right = leaf
        else:
            new_internal.right = new_node
            new_internal.left = leaf
        return new_internal

    #
    # Merkleization
    #
    def merkelize(self) -> bytes:
        """
        Compute the overall Merkle root of the tree using the rules from the EIP:
        - internal_node_hash = hash(left_hash || right_hash)
        - stem_node_hash = hash(stem || 0x00 || hash_of_values)
        - a subindex value = leaf_node_hash = hash(value) (with special rule for zero64)
        - empty_node_hash = zero32()
        """
        def _merkelize(node: InternalNode | StemNode | None) -> bytes:
            if node is None:
                return zero32()
            if isinstance(node, InternalNode):
                left_hash = _merkelize(node.left)
                right_hash = _merkelize(node.right)
                return hash_32_or_64(left_hash + right_hash)
            # node is a StemNode: hash each of its 256 slots, then fold the
            # hashes pairwise, level by level, until a single hash remains.
            level = [hash_32_or_64(v) for v in node.values]
            while len(level) > 1:
                level = [
                    hash_32_or_64(level[i] + level[i + 1])
                    for i in range(0, len(level), 2)
                ]
            # Combine with the stem to get the final node hash.
            return hash_32_or_64(node.stem + b"\x00" + level[0])

        return _merkelize(self.root)
#
# Code chunkification
#
def chunkify_code(code: bytes) -> list[bytes]:
    """
    Cut ``code`` into 32-byte chunks that each carry 31 bytes of code.

    - Byte 0 of a chunk counts how many of its leading code bytes are pushdata
      continuing from an earlier PUSH instruction (capped at 31).
    - Bytes 1..31 are the code bytes themselves.
    Code whose length is not a multiple of 31 is zero-padded at the end first.
    """
    # Bring the length up to the next multiple of 31 (no-op if already aligned).
    padded = code + b"\x00" * (-len(code) % 31)
    total = len(padded)

    # remaining_pushdata[i] = number of pushdata bytes from position i onwards
    # (0 for positions that are not pushdata). The 32 spare entries leave room
    # for a PUSH32 near the very end of the code.
    remaining_pushdata = [0] * (total + 32)
    cursor = 0
    while cursor < total:
        opcode = padded[cursor]
        cursor += 1
        if PUSH1 <= opcode <= PUSH32:
            width = opcode - PUSH_OFFSET
            for offset in range(width):
                remaining_pushdata[cursor + offset] = width - offset
            # Skip the pushdata so it is never read as an opcode.
            cursor += width

    # Assemble one chunk per 31-byte window.
    return [
        bytes([min(remaining_pushdata[begin], 31)]) + padded[begin:begin + 31]
        for begin in range(0, total, 31)
    ]
#
# Key derivation for the single unified tree
#
def tree_hash(inp: bytes) -> bytes:
    """
    Hash function used for key derivation: the BLAKE3 hash of ``inp``.

    Callers such as get_tree_key keep only the first 31 bytes of the result
    as the stem.
    """
    digest = blake3_hash(inp)
    return digest
def get_tree_key(address32: bytes, tree_index: int, sub_index: int) -> bytes:
    """
    Build the 32-byte tree key for (address32, tree_index, sub_index).

    - Bytes 0..30: first 31 bytes of tree_hash(address32 + tree_index as
      32 little-endian bytes).
    - Byte 31: sub_index.

    Raises:
        ValueError: if ``address32`` is not exactly 32 bytes long.
    """
    if len(address32) != 32:
        raise ValueError("address32 must be 32 bytes.")
    index_bytes = tree_index.to_bytes(32, "little")
    stem = tree_hash(address32 + index_bytes)[:31]
    return stem + bytes((sub_index,))
def get_tree_key_for_basic_data(address32: bytes) -> bytes:
    """
    Key of the single 32-byte leaf that packs the account's version, nonce,
    balance and code_size (tree index 0, sub-index BASIC_DATA_LEAF_KEY).
    """
    return get_tree_key(address32, tree_index=0, sub_index=BASIC_DATA_LEAF_KEY)
def get_tree_key_for_code_hash(address32: bytes) -> bytes:
    """Key of the 32-byte leaf holding the account's code_hash (tree index 0)."""
    return get_tree_key(address32, tree_index=0, sub_index=CODE_HASH_LEAF_KEY)
def get_tree_key_for_code_chunk(address32: bytes, chunk_id: int) -> bytes:
    """
    Key of code chunk number ``chunk_id``.

    The chunk sits at position CODE_OFFSET + chunk_id; positions are grouped
    into stems of STEM_SUBTREE_WIDTH = 256 slots each.
    """
    tree_index, sub_index = divmod(CODE_OFFSET + chunk_id, STEM_SUBTREE_WIDTH)
    return get_tree_key(address32, tree_index, sub_index)
def get_tree_key_for_storage_slot(address32: bytes, storage_key: int) -> bytes:
    """
    Key of storage slot ``storage_key``.

    The first CODE_OFFSET - HEADER_STORAGE_OFFSET (= 64) slots are placed at
    HEADER_STORAGE_OFFSET + storage_key, which falls in the same stem as the
    account basic data. All other slots are placed at
    MAIN_STORAGE_OFFSET + storage_key, in groups of 256 per stem.
    """
    header_slot_count = CODE_OFFSET - HEADER_STORAGE_OFFSET
    # Small keys live next to the account header; everything else is shifted
    # into the main storage area.
    base = (
        HEADER_STORAGE_OFFSET
        if storage_key < header_slot_count
        else MAIN_STORAGE_OFFSET
    )
    tree_index, sub_index = divmod(base + storage_key, STEM_SUBTREE_WIDTH)
    return get_tree_key(address32, tree_index, sub_index)
def old_style_address_to_address32(address20: bytes) -> bytes:
    """
    Widen a 20-byte address to 32 bytes by prepending 12 zero bytes.

    Raises:
        ValueError: if ``address20`` is not exactly 20 bytes long.
    """
    if len(address20) == 20:
        return bytes(12) + address20
    raise ValueError("Expected 20-byte address.")
#
# Example usage
#
if __name__ == "__main__":
    # Start from an empty tree and fill it with one demo account.
    demo_tree = BinaryTree()

    # Widen a legacy 20-byte address to the 32-byte form used for key derivation.
    legacy_address = bytes.fromhex("1234567890abcdef1234567890abcdef12345678")
    account = old_style_address_to_address32(legacy_address)

    # Basic-data leaf: big-endian fields packed back to back in this order:
    #   version (1 byte) | reserved (4 bytes) | code_size (3 bytes) |
    #   nonce (8 bytes) | balance (16 bytes)
    header_fields = (
        (0, 1),        # version
        (0, 4),        # reserved
        (100, 3),      # code_size
        (5, 8),        # nonce
        (10**9, 16),   # balance
    )
    basic_data_value = b"".join(
        field.to_bytes(width, "big") for field, width in header_fields
    )
    if len(basic_data_value) != 32:
        raise ValueError("basic_data_value must be exactly 32 bytes.")
    demo_tree.insert(get_tree_key_for_basic_data(account), basic_data_value)

    # Code-hash leaf (a 32-byte digest).
    demo_tree.insert(
        get_tree_key_for_code_hash(account),
        blake3_hash(b"example code"),
    )

    # Only the first code chunk is stored in this demo.
    first_chunk = chunkify_code(b"example code for EVM push instructions")[0]
    demo_tree.insert(get_tree_key_for_code_chunk(account, 0), first_chunk)

    # Storage slot 0 holds the value 999.
    demo_tree.insert(
        get_tree_key_for_storage_slot(account, 0),
        (999).to_bytes(32, "big"),
    )

    # Print the resulting Merkle root.
    root_hash = demo_tree.merkelize()
    print("Merkle root:", root_hash.hex())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment