Created February 1, 2023 13:12
-
-
Save jorisvandenbossche/fa2f72f07be0a328ab2454605ff34d77 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import json | |
import pyarrow as pa | |
class InnerType(pa.ExtensionType):
    """Parameter-free extension type ``test.inner_type`` backed by int64 storage.

    The type carries no serialized metadata, so all instances are
    interchangeable: any two ``InnerType`` objects compare equal and hash
    identically.
    """

    def __init__(self):
        pa.ExtensionType.__init__(self, pa.int64(), 'test.inner_type')

    def __arrow_ext_serialize__(self):
        # Nothing to persist: the type has no parameters.
        return b""

    @classmethod
    def __arrow_ext_deserialize__(cls, storage_type, serialized):
        # ``serialized`` is always empty (see ``__arrow_ext_serialize__``) and
        # ``__init__`` fixes the storage type to int64, so both are ignored.
        # ``cls()`` rather than a hard-coded ``InnerType()`` lets subclasses
        # round-trip as themselves.
        return cls()

    def __eq__(self, other):
        # Only compare against other extension types; returning
        # NotImplemented lets Python try the reflected comparison for
        # anything else (e.g. a plain ``pa.DataType``).
        if isinstance(other, pa.BaseExtensionType):
            return type(self) == type(other)
        return NotImplemented

    def __hash__(self):
        # Defining ``__eq__`` alone sets ``__hash__`` to None and makes the
        # type unhashable. Hash on exactly what ``__eq__`` compares.
        return hash(type(self))
class OuterAnnotatedType(pa.ExtensionType):
    """Extension type ``test.outer_annotated_type`` that annotates a storage type.

    The storage type may itself be an extension type; the script in this file
    wraps ``InnerType``. The annotation is persisted as the JSON document
    ``{"metadata": <metadata>}``, so ``metadata`` must be JSON-serializable.

    Parameters
    ----------
    storage_type : pa.DataType
        Arrow type of the underlying storage array.
    metadata
        JSON-serializable annotation carried with the type.
    """

    def __init__(self, storage_type, metadata):
        # Stored before the base initializer runs, in case it already needs
        # the metadata. NOTE(review): confirm whether pyarrow requires this
        # ordering.
        self._metadata = metadata
        pa.ExtensionType.__init__(self, storage_type, 'test.outer_annotated_type')

    @property
    def metadata(self):
        """The annotation passed to the constructor."""
        return self._metadata

    def __arrow_ext_serialize__(self):
        return json.dumps({"metadata": self._metadata}).encode()

    @classmethod
    def __arrow_ext_deserialize__(cls, storage_type, serialized):
        payload = json.loads(serialized.decode())
        # ``cls`` rather than a hard-coded class name lets subclasses
        # round-trip as themselves.
        return cls(storage_type, payload["metadata"])

    def __eq__(self, other):
        # Two annotated types are equal only when they wrap the same storage
        # type AND carry the same metadata. Ignoring the storage type would
        # make e.g. int64-backed and InnerType-backed instances compare equal.
        if isinstance(other, pa.BaseExtensionType):
            return (type(self) == type(other)
                    and self.storage_type == other.storage_type
                    and self.metadata == other.metadata)
        return NotImplemented

    def __hash__(self):
        # Defining ``__eq__`` alone makes the class unhashable. Hash only on
        # components ``__eq__`` requires to match. The metadata is excluded
        # because it may be an unhashable JSON value (dict/list). The storage
        # type is hashed via its string form so this works even when the
        # storage is itself an extension type without a ``__hash__``.
        return hash((type(self), str(self.storage_type)))
# Demonstrates that a nested extension type (OuterAnnotatedType whose storage
# is InnerType) does not get its inner extension type back after an IPC
# round trip.


def _ipc_roundtrip(record_batch):
    """Write *record_batch* to an in-memory IPC stream and read it back.

    Returns the first batch read from the stream, i.e. what a reader of the
    serialized data sees under the currently registered extension types.
    """
    sink = pa.BufferOutputStream()
    with pa.RecordBatchStreamWriter(sink, record_batch.schema) as writer:
        writer.write_batch(record_batch)
    reader = pa.RecordBatchStreamReader(sink.getvalue())
    return reader.read_next_batch()


t1 = InnerType()
pa.register_extension_type(t1)
# Unregistration below is by extension name. Deserialization receives the
# storage type as an argument, so this instance's int64 storage presumably
# does not constrain what the type may wrap when read back.
pa.register_extension_type(OuterAnnotatedType(pa.int64(), "metadata"))

storage = pa.array([1, 2, 3, 4], pa.int64())
arr1 = pa.ExtensionArray.from_storage(t1, storage)
# Outer extension type wrapping the inner extension type.
t2 = OuterAnnotatedType(t1, "custom_info")
arr2 = pa.ExtensionArray.from_storage(t2, arr1)
batch = pa.RecordBatch.from_arrays([arr2], ["ext"])

# Round trip with both types registered.
batch2 = _ipc_roundtrip(batch)
print(batch["ext"].type.storage_type)
# extension<test.inner_type<InnerType>>
print(batch2["ext"].type.storage_type)
# int64  -- the inner extension type is not restored

# Round trip with neither type registered.
pa.unregister_extension_type('test.inner_type')
pa.unregister_extension_type('test.outer_annotated_type')
batch2 = _ipc_roundtrip(batch)
# With no registered class for the outer type, the field falls back to its
# plain storage type, which is already int64.
print(batch2["ext"].type)
# int64
# The field metadata only describes the outer extension type.
print(batch2.schema.field("ext").metadata)
# {b'ARROW:extension:metadata': b'{"metadata": "custom_info"}',
#  b'ARROW:extension:name': b'test.outer_annotated_type'}
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.