Example of how to fetch and join raw tables in cherry
import asyncio

import cherry_core
import polars as pl
import pyarrow as pa
import pyarrow.compute as pc
from cherry_core import ingest
from pyarrow import RecordBatch

# Configuration for teaching purposes
signature = "Transfer(address indexed src, address indexed dst, uint wad)"
contract_address = "0x833589fCD6eDb6E08f4c7C32D4f71b54bdA02913"  # USDC contract on Base network
start_block = 20123123
end_block = 20123223
network = "base"

# Field definitions to reduce duplication
LOG_FIELDS = [
    "block_number",
    "transaction_index",
    "log_index",
    "address",
    "data",
    "topic0",
    "topic1",
    "topic2",
    "topic3",
]
LOG_JOIN_FIELDS = ["block_number", "transaction_index", "log_index", "address"]
BLOCK_FIELDS = ["number", "timestamp"]
TRANSACTION_FIELDS = [
    "block_number",
    "transaction_index",
    "hash",
    "from",
    "to",
    "value",
    "effective_gas_price",
]

def create_fields_config(log_fields, block_fields, transaction_fields):
    """Create an ingest.Fields configuration from field lists."""
    # Handle the 'from' field which is a Python keyword
    tx_fields_dict = {
        field if field != "from" else "from_": True for field in transaction_fields
    }
    return ingest.Fields(
        block=ingest.BlockFields(**{field: True for field in block_fields}),
        log=ingest.LogFields(**{field: True for field in log_fields}),
        transaction=ingest.TransactionFields(**tx_fields_dict),
    )
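
# With the lists defined above, the call is roughly equivalent to writing out
# (illustrative only, not an exhaustive expansion):
#   ingest.Fields(
#       block=ingest.BlockFields(number=True, timestamp=True),
#       log=ingest.LogFields(block_number=True, ..., topic3=True),
#       transaction=ingest.TransactionFields(block_number=True, ..., from_=True),
#   )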

def get_hex_columns_from_signature(
    signature: str,
    standard_fields: list[str] = ["address", "from", "to", "hash"],
) -> list[str]:
    """
    Extract hex-encodable fields from event signature and standard Ethereum fields.

    Args:
        signature: Event signature (e.g., "Transfer(address indexed src, address indexed dst, uint wad)")
        standard_fields: Common Ethereum fields to include

    Returns:
        List of field names that should be hex-encoded
    """
    import re

    # Extract address fields from signature and combine with standard fields
    address_fields = re.findall(r"address(?:\s+indexed)?\s+(\w+)", signature)
    hex_fields = set(address_fields) | set(standard_fields)
    return list(hex_fields)
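
# Worked example with the signature used in this script: the regex picks up the two
# address parameters "src" and "dst", which are combined with the default standard
# fields, giving {"src", "dst", "address", "from", "to", "hash"} (returned as a list
# in arbitrary order, since a set is used).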

def get_hex_columns_from_record_batch(
    rb: RecordBatch,
    standard_fields: list[str] = ["address", "from", "to", "hash"],
) -> list[str]:
    """
    Extract hex-encodable fields from a RecordBatch schema.

    Identifies binary fields that are 20 bytes (Ethereum addresses) or
    32 bytes (transaction hashes) in length.

    Args:
        rb: The PyArrow RecordBatch to analyze
        standard_fields: Common Ethereum fields to include if they exist in the schema

    Returns:
        List of field names that should be hex-encoded
    """
    hex_fields = set()
    # Add standard fields if they exist in the schema
    for field_name in standard_fields:
        if field_name in rb.schema.names:
            hex_fields.add(field_name)
    # Check each binary field for correct length
    for field in rb.schema:
        if field.name in hex_fields:
            continue
        if pa.types.is_binary(field.type):
            column = rb.column(field.name)
            if len(column) == 0:
                continue
            # Find first non-null value
            for value in column:
                if value is not None and not pc.is_null(value).as_py():
                    # Check if length is 20 or 32 bytes
                    if len(value.as_py()) in (20, 32):
                        hex_fields.add(field.name)
                    break
    return list(hex_fields)
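
# Minimal sketch (hypothetical data) of the length-based detection above:
#   rb = pa.RecordBatch.from_arrays(
#       [pa.array([b"\x00" * 20], type=pa.binary()), pa.array([1], type=pa.int64())],
#       names=["sender", "log_index"],
#   )
#   get_hex_columns_from_record_batch(rb)  # -> ["sender"]: a 20-byte binary column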

def cast_decimal256_to_decimal128(rb: RecordBatch) -> RecordBatch:
    """Cast decimal256 fields in PyArrow RecordBatch to decimal128(38,0)."""
    fields = [
        pa.field(f.name, pa.decimal128(38, 0))
        if isinstance(f.type, pa.Decimal256Type)
        else f
        for f in rb.schema
    ]
    schema = pa.schema(fields)
    return rb.cast(schema)
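
# Note (assumption): uint256 values arrive here as decimal256; narrowing them to
# decimal128(38, 0) keeps them usable in downstream tools such as Polars, but the
# cast can fail for values above ~10**38, i.e. extremely large raw token amounts.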

def cast_decimal256_to_timestamp(
    rb: RecordBatch, timestamp_field_names=["timestamp"], unit="s"
) -> RecordBatch:
    """Convert specified decimal256 fields to timestamp in two steps."""
    # Step 1: cast to int64
    int_fields = [
        pa.field(f.name, pa.int64()) if f.name in timestamp_field_names else f
        for f in rb.schema
    ]
    rb_int = rb.cast(pa.schema(int_fields))
    # Step 2: cast int64 to timestamp
    timestamp_fields = [
        pa.field(f.name, pa.timestamp(unit)) if f.name in timestamp_field_names else f
        for f in rb_int.schema
    ]
    return rb_int.cast(pa.schema(timestamp_fields))
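
# The intermediate int64 hop avoids relying on a direct decimal-to-timestamp cast,
# which Arrow does not generally support. Block timestamps are unix epoch values,
# so with unit="s" the int64 seconds map directly onto timestamp(s).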

def encode_hex_columns(
    rb: RecordBatch, hex_field_names=None, signature=None
) -> RecordBatch:
    """Convert binary hex fields to properly formatted hexadecimal strings with '0x' prefix.

    This is particularly useful for Ethereum addresses and transaction hashes
    which are stored as binary but need to be displayed as hex strings with '0x' prefix.

    Args:
        rb: The PyArrow RecordBatch to process
        hex_field_names: Optional list of field names to encode as hex. If None, fields will be
            automatically identified.
        signature: Optional event signature to help identify hex fields from the schema

    Returns:
        A new RecordBatch with binary fields converted to hex strings
    """
    # If no hex field names provided, automatically identify them
    if hex_field_names is None:
        if signature:
            # Use the signature-based approach
            hex_field_names = get_hex_columns_from_signature(
                signature,
                standard_fields=["address", "from", "to", "hash", "sender", "receiver"],
            )
            print(
                f"Automatically identified hex columns from signature: {hex_field_names}"
            )
        else:
            # Use the RecordBatch schema-based approach
            hex_field_names = get_hex_columns_from_record_batch(
                rb,
                standard_fields=["address", "from", "to", "hash", "sender", "receiver"],
            )
            print(
                f"Automatically identified hex columns from schema: {hex_field_names}"
            )
    # Filter hex_field_names to only include fields that actually exist in the RecordBatch
    existing_hex_fields = [
        field for field in hex_field_names if field in rb.schema.names
    ]
    # Create a copy of the input RecordBatch
    arrays = list(rb.columns)
    schema_fields = list(rb.schema)
    # For each field that should be converted to hex
    for i, field in enumerate(schema_fields):
        if field.name in existing_hex_fields:
            # Replace the array with its hex-encoded version (with 0x prefix)
            arrays[i] = cherry_core.prefix_hex_encode_column(rb.column(field.name))
    # Create a new RecordBatch with the updated arrays
    return pa.RecordBatch.from_arrays(arrays, names=rb.schema.names)
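
# For example, the 20-byte binary form of the contract address configured above comes
# back as a "0x"-prefixed string like "0x833589fcd6edb6e08f4c7c32d4f71b54bda02913"
# (exact casing depends on cherry_core.prefix_hex_encode_column).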

def add_derived_columns_arrow(table: pa.Table) -> pa.Table:
    """Add derived columns using PyArrow compute functions."""
    # Convert wad to float64 and divide by 1e18
    wad_array = table.column("wad")
    wad_float = pc.cast(wad_array, pa.float64())
    amount_eth = pc.divide(wad_float, 1e18)
    # Add the new column to the Table
    return table.append_column("amount_eth", amount_eth)
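
# Worked example of the derivation above: wad = 1_500_000_000_000_000_000 gives
# amount_eth = 1.5. Note the fixed 1e18 divisor assumes an 18-decimal token
# (WETH-style "wad" amounts); USDC uses 6 decimals, so adjust the divisor and
# column name accordingly when pointing this script at other tokens.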

async def fetch_batches(provider: ingest.ProviderConfig):
    """Async generator to fetch batches of data from the provider."""
    # Create fields configuration using our helper function
    fields_config = create_fields_config(LOG_FIELDS, BLOCK_FIELDS, TRANSACTION_FIELDS)
    stream = ingest.start_stream(
        ingest.StreamConfig(
            format=ingest.Format.EVM,
            provider=provider,
            query=ingest.EvmQuery(
                from_block=start_block,
                to_block=end_block,
                logs=[
                    ingest.LogRequest(
                        address=[contract_address],
                        event_signatures=[signature],
                    )
                ],
                fields=fields_config,
            ),
        )
    )
    while True:
        batch = await stream.next()
        if batch is None:
            break
        yield batch
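
# Each yielded batch maps table names to Arrow data; run() below consumes the
# "logs", "blocks", and "transactions" entries and decodes the logs separately.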

def process(merged: pa.Table, event_signature: str = None) -> pa.Table:
    """Apply further processing steps to the merged table.

    This function applies:
    1. Casting decimal256 fields to decimal128
    2. Converting timestamp fields from decimal to timestamp
    3. Encoding binary address and hash fields to hex strings with 0x prefix
    4. Adding derived columns (amount_eth)

    Args:
        merged: The merged table to process
        event_signature: Optional event signature to help identify hex fields
    """
    # Get all batches in the table
    batches = merged.to_batches()
    print(f"Processing {len(batches)} batch(es) from the merged table")
    # If we have an event signature, pre-identify hex columns
    hex_columns = None
    if event_signature:
        hex_columns = get_hex_columns_from_signature(
            event_signature,
            standard_fields=["address", "from", "to", "hash", "sender", "receiver"],
        )
        print(f"Identified hex columns from signature: {hex_columns}")
    # Process all batches in the table
    processed_batches = []
    for record_batch in batches:
        # Cast decimal256 fields to decimal128
        record_batch = cast_decimal256_to_decimal128(record_batch)
        # Convert timestamp fields from decimal to timestamp
        record_batch = cast_decimal256_to_timestamp(record_batch, ["timestamp"], "s")
        # Encode binary address and hash fields to hex strings with 0x prefix
        record_batch = encode_hex_columns(record_batch, hex_field_names=hex_columns)
        processed_batches.append(record_batch)
    # Convert all processed batches back to a single Table
    processed = pa.Table.from_batches(processed_batches)
    # Add derived columns
    return add_derived_columns_arrow(processed)

def merge(
    decoded: RecordBatch,
    logs: RecordBatch,
    blocks: RecordBatch,
    transactions: RecordBatch,
) -> pa.Table:
    """Merge raw logs, decoded events, raw blocks, and raw transactions.

    This function combines the raw logs with the decoded events side by side, then joins with
    raw blocks and raw transactions. The blocks table's 'number' column is renamed to
    'block_number' to match the join keys.
    """
    # Select only necessary columns from raw logs for joining
    logs_for_join = logs.select(LOG_JOIN_FIELDS)
    # Combine logs_for_join and decoded events side by side
    combined_logs = pa.RecordBatch.from_arrays(
        arrays=logs_for_join.columns + decoded.columns,
        names=logs_for_join.schema.names + decoded.schema.names,
    )
    # Select only necessary columns from raw blocks for joining
    blocks_for_join = blocks.select(BLOCK_FIELDS)
    # Select only necessary columns from raw transactions for joining
    transactions_for_join = transactions.select(TRANSACTION_FIELDS)
    # Convert RecordBatches to Tables
    logs_table = pa.Table.from_batches([combined_logs])
    blocks_table = pa.Table.from_batches([blocks_for_join])
    transactions_table = pa.Table.from_batches([transactions_for_join])
    # Rename 'number' to 'block_number' in blocks_table using dictionary mapping
    blocks_table = blocks_table.rename_columns({"number": "block_number"})
    # Join transactions with blocks on 'block_number'
    tx_with_blocks = transactions_table.join(blocks_table, keys="block_number")
    # Join logs with transactions+blocks on 'block_number' and 'transaction_index'
    merged = logs_table.join(tx_with_blocks, keys=["block_number", "transaction_index"])
    return merged
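
# With the Transfer signature configured above (and assuming the decoded columns take
# their names from the signature parameters), the merged table roughly contains:
#   block_number, transaction_index, log_index, address, src, dst, wad,
#   hash, from, to, value, effective_gas_price, timestamp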

async def run(provider: ingest.ProviderConfig):
    """Main function to fetch data, merge raw data, and then process the merged table."""
    async for batch in fetch_batches(provider):
        # Retrieve raw data directly from the batch (without pre-processing)
        raw_logs = batch["logs"]
        raw_blocks = batch["blocks"]
        raw_transactions = batch["transactions"]
        # Compute decoded events from raw_logs without further processing
        decoded = cherry_core.evm_decode_events(signature, raw_logs)
        # Merge raw logs, decoded events, raw blocks, and raw transactions
        merged = merge(decoded, raw_logs, raw_blocks, raw_transactions)
        # Apply further processing to the merged table, passing the event signature
        processed = process(merged, event_signature=signature)
        # Convert to Polars DataFrame for display
        df = pl.from_arrow(processed)
        pl.Config.set_tbl_cols(100)
        print(df)
        # For demonstration, break after first batch
        break

# Run the script
asyncio.run(
    run(
        ingest.ProviderConfig(
            kind=ingest.ProviderKind.HYPERSYNC,
            url=f"https://{network}.hypersync.xyz",
        )
    )
)