@elyase
Created February 27, 2025 11:30
Example of how to fetch and join raw tables in cherry
import asyncio
import re

import cherry_core
import polars as pl
import pyarrow as pa
import pyarrow.compute as pc
from cherry_core import ingest
from pyarrow import RecordBatch

# Example configuration
signature = "Transfer(address indexed src, address indexed dst, uint wad)"
contract_address = "0x833589fCD6eDb6E08f4c7C32D4f71b54bdA02913"  # USDC on Base network
start_block = 20123123
end_block = 20123223
network = "base"

# Field definitions to reduce duplication
LOG_FIELDS = [
    "block_number",
    "transaction_index",
    "log_index",
    "address",
    "data",
    "topic0",
    "topic1",
    "topic2",
    "topic3",
]
LOG_JOIN_FIELDS = ["block_number", "transaction_index", "log_index", "address"]
BLOCK_FIELDS = ["number", "timestamp"]
TRANSACTION_FIELDS = [
    "block_number",
    "transaction_index",
    "hash",
    "from",
    "to",
    "value",
    "effective_gas_price",
]


def create_fields_config(log_fields, block_fields, transaction_fields):
    """Create an ingest.Fields configuration from field lists."""
    # Handle the 'from' field, which is a Python keyword
    tx_fields_dict = {
        field if field != "from" else "from_": True for field in transaction_fields
    }
    return ingest.Fields(
        block=ingest.BlockFields(**{field: True for field in block_fields}),
        log=ingest.LogFields(**{field: True for field in log_fields}),
        transaction=ingest.TransactionFields(**tx_fields_dict),
    )
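
# Illustrative note: with BLOCK_FIELDS = ["number", "timestamp"], the block part of the returned
# config is equivalent to ingest.BlockFields(number=True, timestamp=True); the "from" entry in
# TRANSACTION_FIELDS is passed as "from_" because "from" is a reserved word in Python.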


def get_hex_columns_from_signature(
    signature: str,
    standard_fields: list[str] = ["address", "from", "to", "hash"],
) -> list[str]:
    """
    Extract hex-encodable fields from the event signature and standard Ethereum fields.

    Args:
        signature: Event signature (e.g., "Transfer(address indexed src, address indexed dst, uint wad)")
        standard_fields: Common Ethereum fields to include

    Returns:
        List of field names that should be hex-encoded
    """
    # Extract address parameter names from the signature and combine with standard fields
    address_fields = re.findall(r"address(?:\s+indexed)?\s+(\w+)", signature)
    hex_fields = set(address_fields) | set(standard_fields)
    return list(hex_fields)
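
# Illustrative example: for the Transfer signature defined above, the regex captures the two
# address parameter names, so the result contains (in arbitrary set order)
# "src", "dst", "address", "from", "to", "hash".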


def get_hex_columns_from_record_batch(
    rb: RecordBatch,
    standard_fields: list[str] = ["address", "from", "to", "hash"],
) -> list[str]:
    """
    Extract hex-encodable fields from a RecordBatch schema.

    Identifies binary fields that are 20 bytes (Ethereum addresses) or
    32 bytes (transaction hashes) in length.

    Args:
        rb: The PyArrow RecordBatch to analyze
        standard_fields: Common Ethereum fields to include if they exist in the schema

    Returns:
        List of field names that should be hex-encoded
    """
    hex_fields = set()
    # Add standard fields if they exist in the schema
    for field_name in standard_fields:
        if field_name in rb.schema.names:
            hex_fields.add(field_name)
    # Check each binary field for an address- or hash-sized payload
    for field in rb.schema:
        if field.name in hex_fields:
            continue
        if pa.types.is_binary(field.type):
            column = rb.column(field.name)
            if len(column) == 0:
                continue
            # Inspect only the first non-null value
            for value in column:
                if value is not None and not pc.is_null(value).as_py():
                    # Addresses are 20 bytes, hashes are 32 bytes
                    if len(value.as_py()) in (20, 32):
                        hex_fields.add(field.name)
                    break
    return list(hex_fields)
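
# Illustrative example (hypothetical mini-batch, not part of the original gist):
#   rb = pa.RecordBatch.from_arrays(
#       [pa.array([b"\x11" * 20], type=pa.binary())], names=["sender"]
#   )
#   get_hex_columns_from_record_batch(rb)  # -> ["sender"], detected via the 20-byte length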


def cast_decimal256_to_decimal128(rb: RecordBatch) -> RecordBatch:
    """Cast decimal256 fields in a PyArrow RecordBatch to decimal128(38, 0)."""
    fields = [
        pa.field(f.name, pa.decimal128(38, 0))
        if isinstance(f.type, pa.Decimal256Type)
        else f
        for f in rb.schema
    ]
    schema = pa.schema(fields)
    return rb.cast(schema)


def cast_decimal256_to_timestamp(
    rb: RecordBatch, timestamp_field_names=["timestamp"], unit="s"
) -> RecordBatch:
    """Convert the specified decimal256 fields to timestamps in two steps."""
    # Step 1: cast to int64
    int_fields = [
        pa.field(f.name, pa.int64()) if f.name in timestamp_field_names else f
        for f in rb.schema
    ]
    rb_int = rb.cast(pa.schema(int_fields))
    # Step 2: cast int64 to timestamp
    timestamp_fields = [
        pa.field(f.name, pa.timestamp(unit)) if f.name in timestamp_field_names else f
        for f in rb_int.schema
    ]
    return rb_int.cast(pa.schema(timestamp_fields))
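
# Arrow does not appear to support casting decimal types directly to timestamps, hence the
# two-step cast through int64. Illustrative example: a block timestamp stored as the decimal
# value 1700000000 ends up as timestamp("s") 2023-11-14 22:13:20 UTC.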


def encode_hex_columns(
    rb: RecordBatch, hex_field_names=None, signature=None
) -> RecordBatch:
    """Convert binary fields to hexadecimal strings with a '0x' prefix.

    This is particularly useful for Ethereum addresses and transaction hashes,
    which are stored as binary but are usually displayed as '0x'-prefixed hex strings.

    Args:
        rb: The PyArrow RecordBatch to process
        hex_field_names: Optional list of field names to encode as hex. If None, fields are
            identified automatically.
        signature: Optional event signature used to identify hex fields

    Returns:
        A new RecordBatch with binary fields converted to hex strings
    """
    # If no hex field names are provided, identify them automatically
    if hex_field_names is None:
        if signature:
            # Use the signature-based approach
            hex_field_names = get_hex_columns_from_signature(
                signature,
                standard_fields=["address", "from", "to", "hash", "sender", "receiver"],
            )
            print(
                f"Automatically identified hex columns from signature: {hex_field_names}"
            )
        else:
            # Fall back to the RecordBatch schema-based approach
            hex_field_names = get_hex_columns_from_record_batch(
                rb,
                standard_fields=["address", "from", "to", "hash", "sender", "receiver"],
            )
            print(
                f"Automatically identified hex columns from schema: {hex_field_names}"
            )
    # Keep only the fields that actually exist in the RecordBatch
    existing_hex_fields = [
        field for field in hex_field_names if field in rb.schema.names
    ]
    # Copy the input arrays so untouched columns pass through unchanged
    arrays = list(rb.columns)
    schema_fields = list(rb.schema)
    # Replace each selected column with its hex-encoded version (with 0x prefix)
    for i, field in enumerate(schema_fields):
        if field.name in existing_hex_fields:
            arrays[i] = cherry_core.prefix_hex_encode_column(rb.column(field.name))
    # Build a new RecordBatch from the updated arrays
    return pa.RecordBatch.from_arrays(arrays, names=rb.schema.names)
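
# Illustrative note: after this step a 20-byte binary address column is rendered as 42-character
# strings ("0x" + 40 hex digits) and a 32-byte hash column as "0x" + 64 hex digits, via
# cherry_core.prefix_hex_encode_column.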


def add_derived_columns_arrow(table: pa.Table) -> pa.Table:
    """Add derived columns using PyArrow compute functions."""
    # Convert wad to float64 and divide by 1e18
    wad_array = table.column("wad")
    wad_float = pc.cast(wad_array, pa.float64())
    amount_eth = pc.divide(wad_float, 1e18)
    # Append the new column to the Table
    return table.append_column("amount_eth", amount_eth)
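
# Note: dividing by 1e18 assumes an 18-decimal token (the "wad" convention of WETH-style
# contracts); for a 6-decimal token such as USDC the divisor would be 1e6.
# Example: a raw wad of 1_500_000_000_000_000_000 yields amount_eth == 1.5.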


async def fetch_batches(provider: ingest.ProviderConfig):
    """Async generator that fetches batches of data from the provider."""
    # Create the fields configuration using the helper above
    fields_config = create_fields_config(LOG_FIELDS, BLOCK_FIELDS, TRANSACTION_FIELDS)
    stream = ingest.start_stream(
        ingest.StreamConfig(
            format=ingest.Format.EVM,
            provider=provider,
            query=ingest.EvmQuery(
                from_block=start_block,
                to_block=end_block,
                logs=[
                    ingest.LogRequest(
                        address=[contract_address],
                        event_signatures=[signature],
                    )
                ],
                fields=fields_config,
            ),
        )
    )
    while True:
        batch = await stream.next()
        if batch is None:
            break
        yield batch


def process(merged: pa.Table, event_signature: str = None) -> pa.Table:
    """Apply further processing steps to the merged table.

    This function applies:
    1. Casting decimal256 fields to decimal128
    2. Converting timestamp fields from decimal to timestamp
    3. Encoding binary address and hash fields to hex strings with a 0x prefix
    4. Adding derived columns (amount_eth)

    Args:
        merged: The merged table to process
        event_signature: Optional event signature used to identify hex fields
    """
    # Get all batches in the table
    batches = merged.to_batches()
    print(f"Processing {len(batches)} batch(es) from the merged table")
    # If an event signature is available, pre-identify the hex columns once
    hex_columns = None
    if event_signature:
        hex_columns = get_hex_columns_from_signature(
            event_signature,
            standard_fields=["address", "from", "to", "hash", "sender", "receiver"],
        )
        print(f"Identified hex columns from signature: {hex_columns}")
    # Process every batch in the table
    processed_batches = []
    for record_batch in batches:
        # Cast decimal256 fields to decimal128
        record_batch = cast_decimal256_to_decimal128(record_batch)
        # Convert timestamp fields from decimal to timestamp
        record_batch = cast_decimal256_to_timestamp(record_batch, ["timestamp"], "s")
        # Encode binary address and hash fields to hex strings with a 0x prefix
        record_batch = encode_hex_columns(record_batch, hex_field_names=hex_columns)
        processed_batches.append(record_batch)
    # Combine the processed batches back into a single Table
    processed = pa.Table.from_batches(processed_batches)
    # Add derived columns
    return add_derived_columns_arrow(processed)


def merge(
    decoded: RecordBatch,
    logs: RecordBatch,
    blocks: RecordBatch,
    transactions: RecordBatch,
) -> pa.Table:
    """Merge raw logs, decoded events, raw blocks, and raw transactions.

    This function places the raw log join columns and the decoded events side by side, then joins
    with raw blocks and raw transactions. The blocks table's 'number' column is renamed to
    'block_number' to match the join keys.
    """
    # Select only the columns needed from raw logs for joining
    logs_for_join = logs.select(LOG_JOIN_FIELDS)
    # Combine logs_for_join and the decoded events side by side
    combined_logs = pa.RecordBatch.from_arrays(
        arrays=logs_for_join.columns + decoded.columns,
        names=logs_for_join.schema.names + decoded.schema.names,
    )
    # Select only the columns needed from raw blocks for joining
    blocks_for_join = blocks.select(BLOCK_FIELDS)
    # Select only the columns needed from raw transactions for joining
    transactions_for_join = transactions.select(TRANSACTION_FIELDS)
    # Convert RecordBatches to Tables
    logs_table = pa.Table.from_batches([combined_logs])
    blocks_table = pa.Table.from_batches([blocks_for_join])
    transactions_table = pa.Table.from_batches([transactions_for_join])
    # Rename 'number' to 'block_number' in blocks_table using a dictionary mapping
    blocks_table = blocks_table.rename_columns({"number": "block_number"})
    # Join transactions with blocks on 'block_number'
    tx_with_blocks = transactions_table.join(blocks_table, keys="block_number")
    # Join logs with transactions+blocks on 'block_number' and 'transaction_index'
    merged = logs_table.join(tx_with_blocks, keys=["block_number", "transaction_index"])
    return merged
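
# Illustrative note (assuming cherry_core.evm_decode_events names decoded columns after the
# signature parameters, i.e. src/dst/wad): the merged table then carries the log join keys
# (block_number, transaction_index, log_index, address), the decoded event columns, the selected
# transaction columns (hash, from, to, value, effective_gas_price) and the block timestamp.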


async def run(provider: ingest.ProviderConfig):
    """Fetch data, merge the raw tables, and then process the merged table."""
    async for batch in fetch_batches(provider):
        # Retrieve the raw tables directly from the batch (without pre-processing)
        raw_logs = batch["logs"]
        raw_blocks = batch["blocks"]
        raw_transactions = batch["transactions"]
        # Decode events from raw_logs without further processing
        decoded = cherry_core.evm_decode_events(signature, raw_logs)
        # Merge raw logs, decoded events, raw blocks, and raw transactions
        merged = merge(decoded, raw_logs, raw_blocks, raw_transactions)
        # Apply further processing to the merged table, passing the event signature
        processed = process(merged, event_signature=signature)
        # Convert to a Polars DataFrame for display
        df = pl.from_arrow(processed)
        pl.Config.set_tbl_cols(100)
        print(df)
        # For demonstration, stop after the first batch
        break


# Run the script
asyncio.run(
    run(
        ingest.ProviderConfig(
            kind=ingest.ProviderKind.HYPERSYNC,
            url=f"https://{network}.hypersync.xyz",
        )
    )
)