from types import UnionType
from typing import Any, Optional, Union, get_args, get_origin

import numpy as np
import pyarrow as pa
from pandera import Check, DataFrameModel, DataFrameSchema, Index, MultiIndex, dtypes
from pandera.engines import pandas_engine
from pandera.engines.engine import _is_namedtuple, _is_typeddict
from pandera.typing import Series
from pydantic import BaseModel


def to_pyarrow_schema(
    model: type[DataFrameModel] | DataFrameSchema,
    preserve_index: Optional[bool] = None,
) -> pa.Schema:
"""Convert a :class:`~pandera.schemas.DataFrameSchema` to `pa.Schema`. | |
:param dataframe_schema: schema to convert to `pa.Schema` | |
:param preserve_index: whether to store the index as an additional column | |
(or columns, for MultiIndex) in the resulting Table. The default of | |
None will store the index as a column, except for RangeIndex which is | |
stored as metadata only. Use `preserve_index=True` to force it to be | |
stored as a column. | |
:returns: `pa.Schema` representation of DataFrameSchema | |
""" | |
    dataframe_schema = model.to_schema() if isinstance(model, type) and issubclass(model, DataFrameModel) else model
    # List of columns that will be present in the pyarrow schema
    columns: dict[str, Series | Index] = dataframe_schema.columns
    # pyarrow schema metadata
    metadata: dict[str, bytes] = {}
    index = dataframe_schema.index
    if index is None:
        if preserve_index:
            # TODO(sam): need a unit test covering this
            # Create column for RangeIndex
            name = _get_index_name(0)
            columns[name] = Index(dtypes.Int64, nullable=False, name=name)
        else:
            # Only preserve metadata of index
            meta_val = b'[{"kind": "range", "name": null, "step": 1}]'
            metadata["index_columns"] = meta_val
    elif preserve_index is not False:
        # Add column(s) for index(es)
        if isinstance(index, Index):
            name = index.name or _get_index_name(0)
            # Ensure index is added at dictionary beginning
            columns = {name: index, **columns}
        elif isinstance(index, MultiIndex):
            for i, value in enumerate(reversed(index.indexes)):
                name = value.name or _get_index_name(i)
                columns = {name: value, **columns}
    return pa.schema(
        [pa.field(name, field_to_arrow(field), field.nullable) for name, field in columns.items()],
        metadata=metadata,
    )


def _get_index_name(level: int) -> str:
    """Generate an index name for pyarrow if none is specified"""
    return f"__index_level_{level}__"


def field_to_arrow(field: Series | Index) -> pa.DataType:
    if isinstance(field, Index):
        # TODO(sam): not sure how to store indexes in pyarrow
        msg = "Index conversion to pyarrow not yet implemented"
        raise NotImplementedError(msg)
    dtype = field.dtype
    checks = field.checks
    # required for `typing.List`
    if hasattr(dtype, "generic_type") and dtype.generic_type:
        return type_to_arrow(dtype.generic_type, checks)
    # required for TypedDict and NamedTuple
    if hasattr(dtype, "special_type") and dtype.special_type:
        return type_to_arrow(dtype.special_type, checks)
    return type_to_arrow(dtype, checks)


def is_union(origin: Any) -> bool:
    return origin is Union or origin is UnionType


def type_to_arrow(some_type: Any, checks: list[Check] | None = None) -> pa.DataType:
    origin = get_origin(some_type)
    if origin is list:
        element_type = get_args(some_type)[0]
        return pa.list_(type_to_arrow(element_type))
    elif origin is dict:
        key_type, value_type = get_args(some_type)
        return pa.map_(type_to_arrow(key_type), type_to_arrow(value_type))
    elif origin is tuple:
        # TODO(sam): best we can do is convert a tuple[x, y] to a list[z < x & y]
        # because PyArrow does not support tuples
        # alternative is a class Centroid(NamedTuple) but that requires converting tuple[x, y] to Centroid(x, y)
        types = list(get_args(some_type))
        element_type = infer_type(types)
        return pa.list_(type_to_arrow(element_type))
    elif isinstance(some_type, type) and (
        _is_namedtuple(some_type) or _is_typeddict(some_type) or issubclass(some_type, BaseModel)
    ):
        annotations = some_type.__annotations__.items()

        def struct_field_to_arrow_field(field_name: str, field_type: Any) -> pa.Field:
            if is_union(get_origin(field_type)):
                items = get_args(field_type)
                if len(items) > 2:  # noqa: PLR2004 - we only support T | None, None | T
                    msg = "Cannot handle unions with more than two types"
                    raise ValueError(msg)
                is_optional = any(item is type(None) for item in items)
                if is_optional:
                    return pa.field(
                        field_name,
                        type_to_arrow(next(item for item in items if item is not type(None))),
                        nullable=True,
                    )
                else:
                    msg = "Union types can only be used to represent Optional[T]"
                    raise ValueError(msg)
            else:
                return pa.field(
                    field_name,
                    type_to_arrow(field_type),
                    nullable=False,
                )

        return pa.struct(
            [struct_field_to_arrow_field(field_name, field_type) for field_name, field_type in annotations],
        )
    elif isinstance(some_type, pandas_engine.Decimal):
        # TODO(sam): 128 or 256 bits?
        return pa.decimal128(some_type.precision, some_type.scale)
    elif isinstance(some_type, pandas_engine.Date):
        # TODO(sam): is date64 right, what about date32?
        return pa.date64()
    elif isinstance(some_type, dtypes.Category):
        # TODO(sam): surely there is a utility in pandera I can use to infer this type
        # when debugging, some_type.categories was None
        inferred_types = (
            [
                # get the type of each item in the `isin` check
                type_to_arrow(type(value))
                for check in checks
                if check.name == "isin" and check.statistics is not None and "allowed_values" in check.statistics
                for value in check.statistics["allowed_values"]
            ]
            if checks is not None
            else []
        )
        inferred_type = infer_type(inferred_types)
        return pa.dictionary(
            pa.int8(),
            inferred_type,
            ordered=some_type.ordered,
        )
    # fall back to pandas dtype and map that to pyarrow
    pandas_dtype = pandas_engine.Engine.dtype(some_type).type
    if pandas_dtype.type == np.object_:
        return pa.string()
    else:
        return pa.from_numpy_dtype(pandas_dtype)


def infer_type(types: list[Any]) -> Any:
    # NOTE: callers pass either Python types (tuple element types) or pa.DataTypes (category values)
    # de-dupe them to a unique set of types
    unique_types = list(set(types))
    return (
        # we can infer the type if they are all the same (e.g. all strings or all ints)
        unique_types[0]
        if len(unique_types) == 1
        # TODO(sam): re-visit this behavior
        # there is no "or" type in pyarrow - we use string to match np.object_ (top type) behavior.
        # we could perhaps widen the type to the most precise common ancestor
        else pa.string()
        # I originally thought this would work, but a union takes named fields, not just types.
        # We don't have field names, only type names.
        # else pa.union(*unique_types)
    )
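

# Usage sketch (not part of the original gist): a minimal example of calling `to_pyarrow_schema`
# on a small DataFrameModel. `_ExampleModel` and its fields are hypothetical names introduced
# only for illustration.
if __name__ == "__main__":

    class _ExampleModel(DataFrameModel):
        name: Series[str]
        score: Series[float]

    # With no index on the model, the RangeIndex survives only as schema metadata.
    example_schema = to_pyarrow_schema(_ExampleModel)
    # Expected to print something along the lines of:
    #   name: string not null
    #   score: double not null
    print(example_schema)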


# === second file in the gist: tests for to_pyarrow_schema ===
import os
from typing import List, NamedTuple, Optional

import pandera as pd
import pandera.typing as pdt
import pyarrow  # noqa: ICN001 - conflicts with pandera
from typing_extensions import TypedDict

from noetik_pipeline_methods.assets.orion_human.tissue_props import OrionHumanTissueProps
# See: https://github.com/unionai-oss/pandera/pull/1556
# from pandera.typing import Series
from noetik_pipeline_methods.io.pandera_patch import Series
from noetik_pipeline_methods.io.pandera_to_arrow import to_pyarrow_schema
from noetik_pipeline_methods.io.table_io import pyarrow_schema_to_glue_columns

os.environ["PYARROW_IGNORE_TIMEZONE"] = "1"

import pandera as pa


class TodoItemDict(TypedDict):
    name: str
    priority: int
    pd_uint8: pdt.UInt8
    optional: Optional[str]
    int_or_none: int | None


class TodoItemTuple(NamedTuple):
    name: str
    priority: int
    pd_uint8: pdt.UInt8
    optional: Optional[str]


categories = ["A", "B", "C", "D"]
heterogenous_categories = ["A", "B", 1, 2]


class TodoList(pa.DataFrameModel):
    # TODO(sam): remove the redundant Series[T] wrapper when fixed:
    # https://github.com/unionai-oss/pandera/issues/1546
    bool_: Series[pdt.Bool] = pa.Field()
    bool_list: Series[list[pdt.Bool]] = pa.Field()
    float32_list: Series[list[pdt.Float32]] = pa.Field()
    float32: Series[pdt.Float32] = pa.Field()
    float64_list: Series[list[pdt.Float64]] = pa.Field()
    float64: Series[pdt.Float64] = pa.Field()
    int_list: Series[list[int]] = pa.Field()
    int16_List: Series[List[pdt.Int16]] = pa.Field()  # noqa: N815 - uppercase List matches what is being tested
    int16: Series[pdt.Int16] = pa.Field()
    int32_list: Series[list[pdt.Int32]] = pa.Field()
    int32: Series[pdt.Int32] = pa.Field()
    int64_list: Series[list[pdt.Int64]] = pa.Field()
    int64: Series[pdt.Int64] = pa.Field()
    int8_list: Series[list[pdt.Int8]] = pa.Field()
    int8: Series[pdt.Int8] = pa.Field()
    str_list: Series[list[str]] = pa.Field()
    string_list: Series[list[pdt.String]] = pa.Field()
    string: Series[pdt.String] = pa.Field()
    uint16_list: Series[list[pdt.UInt16]] = pa.Field()
    uint16: Series[pdt.UInt16] = pa.Field()
    uint32_list: Series[list[pdt.UInt32]] = pa.Field()
    uint32: Series[pdt.UInt32] = pa.Field()
    uint64_list: Series[list[pdt.UInt64]] = pa.Field()
    uint64: Series[pdt.UInt64] = pa.Field()
    uint8_list: Series[list[pdt.UInt8]] = pa.Field()
    uint8: Series[pdt.UInt8] = pa.Field()
    timestamp: Series[pd.Timestamp] = pa.Field()
    date: Series[pdt.Date] = pa.Field()
    datetime: Series[pdt.DateTime] = pa.Field()
    # TODO(sam): how to specify precision?
    decimal: Series[pdt.Decimal] = pa.Field()
    category: Series[pa.Category] = pa.Field(isin=categories, coerce=True)
    heterogenous_category: Series[pa.Category] = pa.Field(isin=heterogenous_categories, coerce=True)
    named_tuple: Series[TodoItemTuple] = pa.Field()
    named_tuple_list: Series[list[TodoItemTuple]] = pa.Field()
    typed_dict: Series[TodoItemDict] = pa.Field()
    typed_dict_list: Series[list[TodoItemDict]] = pa.Field()
    map_str_str: Series[dict[str, pdt.String]] = pa.Field()
    map_str_dict: Series[dict[str, TodoItemDict]] = pa.Field()
    map_str_tuple: Series[dict[str, TodoItemTuple]] = pa.Field()


def test_to_arrow() -> None:
    schema = to_pyarrow_schema(TodoList)
    expected_schema = pyarrow.schema(
        [
            pyarrow.field("bool_", pyarrow.bool_(), nullable=False),
            pyarrow.field("bool_list", pyarrow.list_(pyarrow.bool_()), nullable=False),
            pyarrow.field("float32_list", pyarrow.list_(pyarrow.float32()), nullable=False),
            pyarrow.field("float32", pyarrow.float32(), nullable=False),
            pyarrow.field("float64_list", pyarrow.list_(pyarrow.float64()), nullable=False),
            pyarrow.field("float64", pyarrow.float64(), nullable=False),
            pyarrow.field("int_list", pyarrow.list_(pyarrow.int64()), nullable=False),
            pyarrow.field("int16_List", pyarrow.list_(pyarrow.int16()), nullable=False),
            pyarrow.field("int16", pyarrow.int16(), nullable=False),
            pyarrow.field("int32_list", pyarrow.list_(pyarrow.int32()), nullable=False),
            pyarrow.field("int32", pyarrow.int32(), nullable=False),
            pyarrow.field("int64_list", pyarrow.list_(pyarrow.int64()), nullable=False),
            pyarrow.field("int64", pyarrow.int64(), nullable=False),
            pyarrow.field("int8_list", pyarrow.list_(pyarrow.int8()), nullable=False),
            pyarrow.field("int8", pyarrow.int8(), nullable=False),
            pyarrow.field("str_list", pyarrow.list_(pyarrow.string()), nullable=False),
            pyarrow.field("string_list", pyarrow.list_(pyarrow.string()), nullable=False),
            pyarrow.field("string", pyarrow.string(), nullable=False),
            pyarrow.field("uint16_list", pyarrow.list_(pyarrow.uint16()), nullable=False),
            pyarrow.field("uint16", pyarrow.uint16(), nullable=False),
            pyarrow.field("uint32_list", pyarrow.list_(pyarrow.uint32()), nullable=False),
            pyarrow.field("uint32", pyarrow.uint32(), nullable=False),
            pyarrow.field("uint64_list", pyarrow.list_(pyarrow.uint64()), nullable=False),
            pyarrow.field("uint64", pyarrow.uint64(), nullable=False),
            pyarrow.field("uint8_list", pyarrow.list_(pyarrow.uint8()), nullable=False),
            pyarrow.field("uint8", pyarrow.uint8(), nullable=False),
            pyarrow.field("timestamp", pyarrow.timestamp("ns"), nullable=False),
            pyarrow.field("date", pyarrow.date64(), nullable=False),
            pyarrow.field("datetime", pyarrow.timestamp("ns"), nullable=False),
            # these defaults come from pandera
            pyarrow.field("decimal", pyarrow.decimal128(28, 0), nullable=False),
            pyarrow.field(
                "category",
                pyarrow.dictionary(pyarrow.int8(), pyarrow.string(), ordered=False),
                nullable=False,
            ),
            pyarrow.field(
                "heterogenous_category",
                pyarrow.dictionary(pyarrow.int8(), pyarrow.string(), ordered=False),
                nullable=False,
            ),
            pyarrow.field(
                "named_tuple",
                pyarrow.struct(
                    [
                        pyarrow.field("name", pyarrow.string(), nullable=False),
                        pyarrow.field("priority", pyarrow.int64(), nullable=False),
                        pyarrow.field("pd_uint8", pyarrow.uint8(), nullable=False),
                        pyarrow.field("optional", pyarrow.string(), nullable=True),
                    ],
                ),
                nullable=False,
            ),
            pyarrow.field(
                "named_tuple_list",
                pyarrow.list_(
                    pyarrow.struct(
                        [
                            pyarrow.field("name", pyarrow.string(), nullable=False),
                            pyarrow.field("priority", pyarrow.int64(), nullable=False),
                            pyarrow.field("pd_uint8", pyarrow.uint8(), nullable=False),
                            pyarrow.field("optional", pyarrow.string(), nullable=True),
                        ],
                    ),
                ),
                nullable=False,
            ),
            pyarrow.field(
                "typed_dict",
                pyarrow.struct(
                    [
                        pyarrow.field("name", pyarrow.string(), nullable=False),
                        pyarrow.field("priority", pyarrow.int64(), nullable=False),
                        pyarrow.field("pd_uint8", pyarrow.uint8(), nullable=False),
                        pyarrow.field("optional", pyarrow.string(), nullable=True),
                        pyarrow.field("int_or_none", pyarrow.int64(), nullable=True),
                    ],
                ),
                nullable=False,
            ),
            pyarrow.field(
                "typed_dict_list",
                pyarrow.list_(
                    pyarrow.struct(
                        [
                            pyarrow.field("name", pyarrow.string(), nullable=False),
                            pyarrow.field("priority", pyarrow.int64(), nullable=False),
                            pyarrow.field("pd_uint8", pyarrow.uint8(), nullable=False),
                            pyarrow.field("optional", pyarrow.string(), nullable=True),
                            pyarrow.field("int_or_none", pyarrow.int64(), nullable=True),
                        ],
                    ),
                ),
                nullable=False,
            ),
            pyarrow.field(
                "map_str_str",
                pyarrow.map_(
                    pyarrow.string(),
                    pyarrow.string(),
                ),
                nullable=False,
            ),
            pyarrow.field(
                "map_str_dict",
                pyarrow.map_(
                    pyarrow.string(),
                    pyarrow.struct(
                        [
                            pyarrow.field("name", pyarrow.string(), nullable=False),
                            pyarrow.field("priority", pyarrow.int64(), nullable=False),
                            pyarrow.field("pd_uint8", pyarrow.uint8(), nullable=False),
                            pyarrow.field("optional", pyarrow.string(), nullable=True),
                            pyarrow.field("int_or_none", pyarrow.int64(), nullable=True),
                        ],
                    ),
                ),
                nullable=False,
            ),
            pyarrow.field(
                "map_str_tuple",
                pyarrow.map_(
                    pyarrow.string(),
                    pyarrow.struct(
                        [
                            pyarrow.field("name", pyarrow.string(), nullable=False),
                            pyarrow.field("priority", pyarrow.int64(), nullable=False),
                            pyarrow.field("pd_uint8", pyarrow.uint8(), nullable=False),
                            pyarrow.field("optional", pyarrow.string(), nullable=True),
                        ],
                    ),
                ),
                nullable=False,
            ),
        ],
    )
    assert schema == expected_schema, "Generated schema does not match expected schema"


def test_orion_human_tissue_props() -> None:
    schema = to_pyarrow_schema(OrionHumanTissueProps)
    expected_schema = pyarrow.schema(
        [
            pyarrow.field("area", pyarrow.float64(), nullable=False),
            pyarrow.field("axis_major_length", pyarrow.float64(), nullable=False),
            pyarrow.field("axis_minor_length", pyarrow.float64(), nullable=False),
            pyarrow.field(
                "centroid",
                pyarrow.list_(pyarrow.float64()),
                nullable=False,
            ),
            pyarrow.field("eccentricity", pyarrow.float64(), nullable=False),
            pyarrow.field("euler_number", pyarrow.int64(), nullable=False),
            pyarrow.field("extent", pyarrow.float64(), nullable=False),
            pyarrow.field("intensity_max", pyarrow.list_(pyarrow.float64()), nullable=False),
            pyarrow.field("intensity_mean", pyarrow.list_(pyarrow.float64()), nullable=False),
            pyarrow.field("intensity_min", pyarrow.list_(pyarrow.float64()), nullable=False),
            pyarrow.field("label", pyarrow.int64(), nullable=False),
            pyarrow.field("solidity", pyarrow.float64(), nullable=False),
            pyarrow.field("entropy", pyarrow.list_(pyarrow.float64()), nullable=False),
            pyarrow.field("noise", pyarrow.list_(pyarrow.float64()), nullable=False),
            pyarrow.field("blur", pyarrow.list_(pyarrow.float64()), nullable=False),
            pyarrow.field("low_contrast", pyarrow.list_(pyarrow.bool_()), nullable=False),
            pyarrow.field("pixel_histogram", pyarrow.list_(pyarrow.list_(pyarrow.int64())), nullable=False),
            pyarrow.field("core_image_id", pyarrow.string(), nullable=False),
        ],
    )
    assert schema == expected_schema, "Generated schema does not match expected schema"


def test_pyarrow_to_glue_column() -> None:
    schema = to_pyarrow_schema(TodoList)
    columns = pyarrow_schema_to_glue_columns(schema=schema)
    assert columns == [
        {"Name": "bool_", "Type": "boolean"},
        {"Name": "bool_list", "Type": "array<boolean>"},
        {"Name": "float32_list", "Type": "array<float>"},
        {"Name": "float32", "Type": "float"},
        {"Name": "float64_list", "Type": "array<double>"},
        {"Name": "float64", "Type": "double"},
        {"Name": "int_list", "Type": "array<bigint>"},
        {"Name": "int16_List", "Type": "array<smallint>"},
        {"Name": "int16", "Type": "smallint"},
        {"Name": "int32_list", "Type": "array<int>"},
        {"Name": "int32", "Type": "int"},
        {"Name": "int64_list", "Type": "array<bigint>"},
        {"Name": "int64", "Type": "bigint"},
        {"Name": "int8_list", "Type": "array<tinyint>"},
        {"Name": "int8", "Type": "tinyint"},
        {"Name": "str_list", "Type": "array<string>"},
        {"Name": "string_list", "Type": "array<string>"},
        {"Name": "string", "Type": "string"},
        {"Name": "uint16_list", "Type": "array<int>"},
        {"Name": "uint16", "Type": "int"},
        {"Name": "uint32_list", "Type": "array<bigint>"},
        {"Name": "uint32", "Type": "bigint"},
        {"Name": "uint64_list", "Type": "array<bigint>"},
        {"Name": "uint64", "Type": "bigint"},
        {"Name": "uint8_list", "Type": "array<smallint>"},
        {"Name": "uint8", "Type": "smallint"},
        {"Name": "timestamp", "Type": "timestamp"},
        {"Name": "date", "Type": "date"},
        {"Name": "datetime", "Type": "timestamp"},
        {"Name": "decimal", "Type": "decimal(28,0)"},
        {"Name": "category", "Type": "map<tinyint,string>"},
        {"Name": "heterogenous_category", "Type": "map<tinyint,string>"},
        {"Name": "named_tuple", "Type": "struct<name:string,priority:bigint,pd_uint8:smallint,optional:string>"},
        {
            "Name": "named_tuple_list",
            "Type": "array<struct<name:string,priority:bigint,pd_uint8:smallint,optional:string>>",
        },
        {
            "Name": "typed_dict",
            "Type": "struct<name:string,priority:bigint,pd_uint8:smallint,optional:string,int_or_none:bigint>",
        },
        {
            "Name": "typed_dict_list",
            "Type": "array<struct<name:string,priority:bigint,pd_uint8:smallint,optional:string,int_or_none:bigint>>",
        },
        {"Name": "map_str_str", "Type": "map<string,string>"},
        {
            "Name": "map_str_dict",
            "Type": "map<string,struct<name:string,priority:bigint,pd_uint8:smallint,optional:string,int_or_none:bigint>>",  # noqa: E501
        },
        {
            "Name": "map_str_tuple",
            "Type": "map<string,struct<name:string,priority:bigint,pd_uint8:smallint,optional:string>>",
        },
    ]
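

# Additional sketch (not in the original gist): a round-trip check that the schema generated by
# `to_pyarrow_schema` is accepted by pyarrow/parquet end to end. `MiniModel` is a hypothetical
# model defined only for this test; `tmp_path` is pytest's built-in temporary-directory fixture.
def test_roundtrip_parquet(tmp_path) -> None:
    import pyarrow.parquet as pq

    class MiniModel(pa.DataFrameModel):
        name: Series[str] = pa.Field()
        score: Series[pdt.Float64] = pa.Field()

    schema = to_pyarrow_schema(MiniModel)
    # build a tiny table that conforms to the generated schema and write it to parquet
    table = pyarrow.Table.from_pydict({"name": ["a", "b"], "score": [1.0, 2.0]}, schema=schema)
    path = tmp_path / "mini.parquet"
    pq.write_table(table, path)
    read_back = pq.read_table(path)
    # compare names and types only; the file round trip may attach extra schema metadata
    assert read_back.schema.names == schema.names
    assert read_back.schema.types == schema.types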