Gist by @mrn-aglic — created December 6, 2024 18:29.

Instantly share code, notes, and snippets.

Save mrn-aglic/19a76ce59135b0d46ec9beac0158db97 to your computer and use it in GitHub Desktop.
import typing
from dataclasses import asdict, dataclass, fields, is_dataclass
from types import SimpleNamespace
from typing import get_args, get_origin
import pyarrow as pa
def get_field_arrow_type(field):
    """Map a dataclass field's Python type annotation to a pyarrow DataType.

    Supported annotations:
      * nested dataclasses  -> the struct storage type of their extension type
      * ``str``/``int``/``float``/``bool`` primitives
      * homogeneous ``list[X]`` / ``tuple[X]`` / ``tuple[X, ...]`` generics

    Raises:
        TypeError: for heterogeneous tuples or any unsupported annotation.
    """
    # Nested dataclasses are represented by the struct type backing their
    # extension type, so they nest naturally inside the parent struct.
    if is_dataclass(field.type):
        return DataclassArrowType(field.type).storage_type

    origin = get_origin(field.type)  # list/tuple for generics, else None
    args = get_args(field.type)

    if field.type is str:
        return pa.string()
    if field.type is int:
        return pa.int32()
    if field.type is float:
        return pa.float64()
    if field.type is bool:
        # Previously unsupported: bool annotations fell through to the
        # TypeError branch below.
        return pa.bool_()
    if origin in (list, tuple):
        # tuple[X, ...] is the homogeneous variable-length tuple form;
        # normalize it to the single-element-type case.
        if len(args) == 2 and args[1] is Ellipsis:
            args = args[:1]
        if len(args) == 1:  # e.g. list[int], tuple[int, ...]
            return pa.list_(get_field_arrow_type(SimpleNamespace(type=args[0])))
        raise TypeError("Tuples with heterogeneous types are not supported.")
    raise TypeError(f"Unsupported field type: {field.type}")
class DataclassArrowScalar(pa.ExtensionScalar):
    """Scalar for DataclassArrowType that converts back into the dataclass."""

    def as_py(self):
        """Rebuild and return an instance of the wrapped dataclass.

        Walks the dataclass fields and, for nested dataclass values (either
        direct or inside list/tuple generics), casts the raw struct storage
        back to the matching extension type so that ``as_py`` recursion
        yields dataclass instances instead of plain dicts.
        """
        cls = self.type.dataclass
        values = {}
        for field in fields(cls):
            pa_value = self.value.get(field.name)
            origin = get_origin(field.type)
            if is_dataclass(field.type):
                # Nested dataclass: stored as a plain struct; cast it to its
                # extension type so its as_py() produces the dataclass.
                field_type = DataclassArrowType(field.type)
                values[field.name] = pa_value.cast(field_type).as_py()
            elif origin in (list, tuple):
                # FIX: tuple was not checked here although
                # get_field_arrow_type accepts tuple annotations too.
                args = get_args(field.type)
                element_type = args[0] if len(args) == 1 else None
                if element_type is not None and is_dataclass(element_type):
                    field_type = DataclassArrowType(element_type)
                    values[field.name] = pa_value.values.cast(field_type).to_pylist()
                else:
                    # Plain element types convert fine without a cast.
                    values[field.name] = pa_value.as_py()
            else:
                values[field.name] = pa_value.as_py()
        return cls(**values)
# Maps dataclass __name__ -> dataclass. __arrow_ext_serialize__ stores only
# the class name, so this registry is how __arrow_ext_deserialize__ recovers
# the Python class when an IPC stream is read back.
DATACLASS_REGISTRY = {}
class DataclassArrowType(pa.ExtensionType):
    """Arrow extension type that stores a dataclass as a struct.

    Registers the dataclass in ``DATACLASS_REGISTRY`` so the type can be
    reconstructed from its serialized name during IPC deserialization.
    """

    def __init__(self, dataclass_type):
        if not is_dataclass(dataclass_type):
            # FIX: the original only handled the dataclass branch, so a
            # non-dataclass argument crashed later with an unrelated
            # NameError on the unbound `field_types`.
            raise TypeError(f"Expected a dataclass, got: {dataclass_type!r}")
        DATACLASS_REGISTRY[dataclass_type.__name__] = dataclass_type
        field_types = [
            (field.name, get_field_arrow_type(field))
            # Use dataclass_type directly instead of the original's
            # pointless round-trip through the registry.
            for field in fields(dataclass_type)
        ]
        storage_type = pa.struct(field_types)
        self._type = dataclass_type.__name__
        self._dataclass = dataclass_type
        super().__init__(storage_type, f"dataclass-{self._type}")

    @property
    def dataclass(self):
        """The Python dataclass this extension type wraps."""
        return self._dataclass

    def __arrow_ext_serialize__(self) -> bytes:
        # The class name alone is sufficient: DATACLASS_REGISTRY maps it
        # back to the class on deserialization.
        return self._dataclass.__name__.encode()

    @classmethod
    def __arrow_ext_deserialize__(cls, storage_type, serialized):
        decoded = serialized.decode()
        dataclass_type = DATACLASS_REGISTRY.get(decoded)
        if dataclass_type is None:
            # FIX: the original forwarded None into __init__, producing a
            # confusing downstream failure instead of a clear error.
            raise KeyError(
                f"Dataclass {decoded!r} is not registered; define it (and "
                "construct its DataclassArrowType) before deserializing."
            )
        return cls(dataclass_type)

    def __arrow_ext_scalar_class__(self):
        # Scalars of this type materialize as dataclass instances.
        return DataclassArrowScalar
@dataclass
class Address:
    """Demo dataclass nested directly inside Person (stored as a struct)."""
    street: str
    number: str
@dataclass
class Product:
    """Demo dataclass used as a list element type (list[Product])."""
    id: int
    name: str
@dataclass
class Person:
    """Top-level demo dataclass combining primitives, a nested dataclass,
    and a list of dataclasses."""
    name: str
    age: int
    address: Address
    products: list[Product]
# --- Demo: round-trip a nested dataclass through Arrow IPC -----------------

person = Person(
    name="John",
    age=30,
    address=Address(street="Main Street", number="123"),
    products=[Product(id=99, name="shampoo")],
)

address_data_type = DataclassArrowType(Address)
product_data_type = DataclassArrowType(Product)
person_data_type = DataclassArrowType(Person)

# Registration lets Arrow resolve the "dataclass-*" extension names when the
# IPC stream is read back below.
pa.register_extension_type(address_data_type)
pa.register_extension_type(product_data_type)
pa.register_extension_type(person_data_type)

print(f"person_data_type.storage_type:> {person_data_type.storage_type}")
print(f"person_data_type.extension_name:> {person_data_type.extension_name}")

# Building the array directly with the extension type converts the dataclass
# (via asdict) into the underlying struct storage.
person_storage_array = pa.array([asdict(person)], type=person_data_type)
print(person_storage_array)

# to_pylist goes through DataclassArrowScalar.as_py, so we get Person
# instances back, not dicts.
persons_again = person_storage_array.to_pylist()
print(f"persons_again:> {persons_again}")

# Serialize to an in-memory IPC stream and read it back.
batch = pa.RecordBatch.from_arrays([person_storage_array], ["people"])
sink = pa.BufferOutputStream()
with pa.RecordBatchStreamWriter(sink, batch.schema) as writer:
    writer.write_batch(batch)
buf = sink.getvalue()

with pa.ipc.open_stream(buf) as reader:
    result = reader.read_all()

people_col = result.column("people")
print(f"people_col.type:> {people_col.type}")

# Cast back to the extension type before to_pylist so scalars convert via
# DataclassArrowScalar again after the IPC round-trip.
persons_after_buffer = people_col.cast(person_data_type).to_pylist()
print("persons_after_buffer:> ", persons_after_buffer)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment