Created
December 6, 2024 18:29
-
-
Save mrn-aglic/19a76ce59135b0d46ec9beac0158db97 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import typing | |
from dataclasses import asdict, dataclass, fields, is_dataclass | |
from types import SimpleNamespace | |
from typing import get_args, get_origin | |
import pyarrow as pa | |
def get_field_arrow_type(field):
    """Map a dataclass field's Python type annotation to an Arrow type.

    Supports nested dataclasses (mapped to their extension type's storage
    struct), ``bool``/``str``/``int``/``float`` scalars, and homogeneous
    ``list``/``tuple`` generics, including the ``tuple[X, ...]`` spelling.

    Args:
        field: a ``dataclasses.Field`` (or any object with a ``type``
            attribute holding a real type object, e.g. ``SimpleNamespace``).

    Returns:
        The corresponding ``pyarrow.DataType``.

    Raises:
        TypeError: for heterogeneous tuples or any unsupported annotation.
    """
    if is_dataclass(field.type):  # Handle nested dataclasses
        return DataclassArrowType(field.type).storage_type

    origin = get_origin(field.type)  # list/tuple/... for generic annotations
    args = get_args(field.type)

    # bool is checked explicitly: it is not matched by the int branch and
    # has its own Arrow type.
    if field.type == bool:
        return pa.bool_()
    elif field.type == str:
        return pa.string()
    elif field.type == int:
        return pa.int32()
    elif field.type == float:
        return pa.float64()
    elif origin in (list, tuple):
        # tuple[X, ...] is the homogeneous variable-length tuple spelling;
        # normalize it to a single element type.
        if len(args) == 2 and args[1] is Ellipsis:
            args = args[:1]
        if len(args) == 1:  # e.g., list[int], tuple[int, ...]
            return pa.list_(get_field_arrow_type(SimpleNamespace(type=args[0])))
        raise TypeError("Tuples with heterogeneous types are not supported.")
    else:
        raise TypeError(f"Unsupported field type: {field.type}")
class DataclassArrowScalar(pa.ExtensionScalar):
    """Scalar for ``DataclassArrowType`` values.

    Overrides ``as_py`` so that converting a scalar to Python rebuilds an
    instance of the wrapped dataclass instead of returning a plain dict.
    """

    def as_py(self):
        """Reconstruct an instance of the wrapped dataclass from this struct scalar."""
        cls = self.type.dataclass  # dataclass stored on the extension type
        values = {}
        for field in fields(cls):
            pa_value = self.value.get(field.name)
            origin = get_origin(field.type)
            if is_dataclass(field.type):
                # Nested dataclass: cast the struct scalar to its own
                # extension type so this as_py() recurses through it.
                field_type = DataclassArrowType(field.type)
                py_value = pa_value.cast(field_type).as_py()
                values[field.name] = py_value
            elif origin in (list,):
                args = get_args(field.type)
                list_type = args[0] if len(args) == 1 else None
                if list_type is not None and is_dataclass(list_type):
                    # list[SomeDataclass]: cast the list's child values
                    # array to the element's extension type, then
                    # materialize each element as a dataclass instance.
                    field_type = DataclassArrowType(list_type)
                    val = pa_value.values.cast(field_type)
                    py_value = val.to_pylist()
                else:
                    # List of plain scalars: direct Python conversion.
                    py_value = pa_value.as_py()
                values[field.name] = py_value
            else:
                # Plain scalar field (str/int/float): direct conversion.
                values[field.name] = self.value.get(field.name).as_py()
        return cls(**values)
        # return cls(**self.value.as_py())
# Maps dataclass name -> dataclass type. Populated by DataclassArrowType so
# __arrow_ext_deserialize__ can recover the Python class from the name
# stored in the serialized extension metadata.
DATACLASS_REGISTRY = {}
class DataclassArrowType(pa.ExtensionType):
    """Arrow extension type that stores a Python dataclass as a struct.

    The storage type is a ``pa.struct`` derived from the dataclass fields.
    Each constructed instance records its dataclass in ``DATACLASS_REGISTRY``
    so the type can be rebuilt by name when deserialized from IPC metadata.
    """

    def __init__(self, dataclass_type):
        if not is_dataclass(dataclass_type):
            # Fail fast with a clear message instead of crashing later in
            # fields() on a non-dataclass (or on None from the registry).
            raise TypeError(f"Expected a dataclass, got {dataclass_type!r}")
        # Record the class so __arrow_ext_deserialize__ can look it up by name.
        DATACLASS_REGISTRY[dataclass_type.__name__] = dataclass_type
        field_types = [
            (field.name, get_field_arrow_type(field))
            for field in fields(dataclass_type)
        ]
        storage_type = pa.struct(field_types)
        self._type = dataclass_type.__name__
        self._dataclass = dataclass_type
        super().__init__(storage_type, f"dataclass-{self._type}")

    @property
    def dataclass(self):
        """The Python dataclass this extension type wraps."""
        return self._dataclass

    def __arrow_ext_serialize__(self) -> bytes:
        # The dataclass name is sufficient metadata: deserialization
        # resolves it through DATACLASS_REGISTRY.
        return self._dataclass.__name__.encode()

    @classmethod
    def __arrow_ext_deserialize__(cls, storage_type, serialized):
        decoded = serialized.decode()
        dataclass_type = DATACLASS_REGISTRY.get(decoded)
        if dataclass_type is None:
            # Previously this passed None through and crashed in __init__
            # with an unrelated AttributeError; raise a descriptive error.
            raise LookupError(
                f"Dataclass {decoded!r} is not registered; construct its "
                "DataclassArrowType before reading data that uses it."
            )
        return cls(dataclass_type)

    def __arrow_ext_scalar_class__(self):
        # Scalars of this type convert back to dataclass instances.
        return DataclassArrowScalar
@dataclass
class Address:
    """A street address; both components map to Arrow strings."""
    street: str
    number: str
@dataclass
class Product:
    """A product record; used below as the element type of a list field."""
    id: int
    name: str
@dataclass
class Person:
    """A person with a nested dataclass field and a list-of-dataclass field."""
    name: str
    age: int
    address: Address            # nested dataclass -> nested struct
    products: list[Product]     # list of dataclasses -> list of structs
# --- Demo: round-trip a dataclass through Arrow extension types and IPC ---

person = Person(
    name="John",
    age=30,
    address=Address(street="Main Street", number="123"),
    products=[Product(id=99, name="shampoo")],
)

# Construct (and thereby register in DATACLASS_REGISTRY) one extension type
# per dataclass, then register them with pyarrow so IPC deserialization can
# rebuild them from metadata.
address_data_type = DataclassArrowType(Address)
product_data_type = DataclassArrowType(Product)
person_data_type = DataclassArrowType(Person)

pa.register_extension_type(address_data_type)
pa.register_extension_type(product_data_type)
pa.register_extension_type(person_data_type)

print(f"person_data_type.storage_type:> {person_data_type.storage_type}")
print(f"person_data_type.extension_name:> {person_data_type.extension_name}")

# Build an extension array directly from the dataclass converted to a dict.
person_storage_array = pa.array([asdict(person)], type=person_data_type)
print(person_storage_array)

# to_pylist() goes through DataclassArrowScalar.as_py(), so this yields
# Person instances, not dicts.
persons_again = person_storage_array.to_pylist()
print(f"persons_again:> {persons_again}")

# Serialize through the Arrow IPC stream format and read it back.
batch = pa.RecordBatch.from_arrays([person_storage_array], ["people"])

sink = pa.BufferOutputStream()
with pa.RecordBatchStreamWriter(sink, batch.schema) as writer:
    writer.write_batch(batch)
buf = sink.getvalue()

with pa.ipc.open_stream(buf) as reader:
    result = reader.read_all()

people_col = result.column("people")
print(f"people_col.type:> {people_col.type}")

# Cast back to the extension type to recover dataclass instances.
persons_after_buffer = people_col.cast(person_data_type).to_pylist()
print("persons_after_buffer:> ", persons_after_buffer)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment