Skip to content

Instantly share code, notes, and snippets.

@sam-goodwin
Last active December 24, 2024 04:08
Show Gist options
  • Save sam-goodwin/9b5ae19cc59f1349f362823454e31376 to your computer and use it in GitHub Desktop.
Save sam-goodwin/9b5ae19cc59f1349f362823454e31376 to your computer and use it in GitHub Desktop.
from types import UnionType
from typing import Any, Optional, Union, get_args, get_origin
import numpy as np
import pyarrow as pa
from pandera import Check, DataFrameModel, DataFrameSchema, Index, MultiIndex, dtypes
from pandera.engines import pandas_engine
from pandera.engines.engine import _is_namedtuple, _is_typeddict
from pandera.typing import Series
from pydantic import BaseModel
def to_pyarrow_schema(
    model: type[DataFrameModel] | DataFrameSchema,
    preserve_index: Optional[bool] = None,
) -> pa.Schema:
    """Convert a pandera ``DataFrameModel`` or ``DataFrameSchema`` to a `pa.Schema`.

    :param model: a :class:`~pandera.DataFrameModel` subclass (converted with
        ``to_schema()``) or a :class:`~pandera.schemas.DataFrameSchema` instance
    :param preserve_index: whether to store the index as an additional column
        (or columns, for MultiIndex) in the resulting Table. The default of
        None will store the index as a column, except for RangeIndex which is
        stored as metadata only. Use `preserve_index=True` to force it to be
        stored as a column.
    :returns: `pa.Schema` representation of the schema
    :raises NotImplementedError: whenever an index column is emitted, because
        ``field_to_arrow`` does not support ``Index`` fields yet
    """
    is_model = isinstance(model, type) and issubclass(model, DataFrameModel)
    dataframe_schema = model.to_schema() if is_model else model

    # Columns that will be present in the pyarrow schema. Work on a shallow copy so
    # that adding index columns below never mutates the caller's schema object.
    columns: dict[str, Series | Index] = dict(dataframe_schema.columns)

    # pyarrow schema metadata
    metadata: dict[str, bytes] = {}

    index = dataframe_schema.index
    if index is None:
        if preserve_index:
            # TODO(sam): need a unit test covering this
            # Create column for RangeIndex
            name = _get_index_name(0)
            columns[name] = Index(dtypes.Int64, nullable=False, name=name)
        else:
            # Only preserve metadata of index; the payload must be valid JSON,
            # so the missing name is spelled `null`.
            meta_val = b'[{"kind": "range", "name": null, "step": 1}]'
            metadata["index_columns"] = meta_val
    elif preserve_index is not False:
        # Add column(s) for index(es), placed before the data columns
        if isinstance(index, Index):
            name = index.name or _get_index_name(0)
            columns = {name: index, **columns}
        elif isinstance(index, MultiIndex):
            # Number unnamed levels by their real position in the MultiIndex so the
            # first level is `__index_level_0__`, the second `__index_level_1__`, ...
            index_columns = {
                (level_index.name or _get_index_name(level)): level_index
                for level, level_index in enumerate(index.indexes)
            }
            columns = {**index_columns, **columns}

    return pa.schema(
        [pa.field(name, field_to_arrow(field), field.nullable) for name, field in columns.items()],
        metadata=metadata,
    )
def _get_index_name(level: int) -> str:
"""Generate an index name for pyarrow if none is specified"""
return f"__index_level_{level}__"
def field_to_arrow(field: Series | Index) -> pa.DataType:
    """Translate a single schema field into the pyarrow type of its values.

    The field's ``dtype`` may wrap a richer Python type in ``generic_type`` or
    ``special_type``; when one is present and truthy it is converted instead of
    the bare dtype. The field's ``checks`` are forwarded to ``type_to_arrow``.

    :param field: schema field to convert
    :returns: the pyarrow data type for the field
    :raises NotImplementedError: if ``field`` is an ``Index``
    """
    if isinstance(field, Index):
        # TODO(sam): not sure how to store indexes in pyarrow
        msg = "Index conversion to pyarrow not yet implemented"
        raise NotImplementedError(msg)

    dtype, checks = field.dtype, field.checks
    # `generic_type` is required for `typing.List`;
    # `special_type` is required for TypedDict and NamedTuple.
    for attribute in ("generic_type", "special_type"):
        wrapped_type = getattr(dtype, attribute, None)
        if wrapped_type:
            return type_to_arrow(wrapped_type, checks)
    return type_to_arrow(dtype, checks)
def is_union(origin: Any) -> bool:
return origin == Union or origin is Union or origin is UnionType
# pyarrow's decimal128 can hold at most this many digits; wider values need decimal256
_DECIMAL128_MAX_PRECISION = 38


def _struct_field_to_arrow_field(field_name: str, field_type: Any) -> pa.Field:
    """Build the arrow field for one annotated member of a struct-like type.

    A plain annotation becomes a non-nullable field. ``Optional[T]`` / ``T | None``
    becomes a nullable field of ``T``. Any other union is rejected.

    :raises ValueError: for unions with more than two members, or unions that do
        not include ``None``
    """
    if not is_union(get_origin(field_type)):
        return pa.field(field_name, type_to_arrow(field_type), nullable=False)

    members = get_args(field_type)
    if len(members) > 2:  # noqa: PLR2004 - we only support T | None, None | T
        msg = "Cannot handle unions with more than two types"
        raise ValueError(msg)

    none_type = type(None)
    if not any(member is none_type for member in members):
        msg = "Union types can only be used to represent Optional[T]"
        raise ValueError(msg)

    value_type = next(member for member in members if member is not none_type)
    return pa.field(field_name, type_to_arrow(value_type), nullable=True)


def type_to_arrow(some_type: Any, checks: list[Check] | None = None) -> pa.DataType:
    """Convert a Python / pandera type into the matching pyarrow data type.

    Handled shapes, in order:

    * ``list[T]`` -> ``pa.list_``
    * ``dict[K, V]`` -> ``pa.map_``
    * ``tuple[...]`` -> ``pa.list_`` of the common element type (string when the
      element types differ, because pyarrow has no tuple type)
    * NamedTuple / TypedDict / pydantic ``BaseModel`` classes -> ``pa.struct``
      built from the class's ``__annotations__``
    * pandera ``Decimal``, ``Date`` and ``Category`` dtypes
    * anything else is resolved through the pandera pandas engine

    :param some_type: the type to convert
    :param checks: the field's checks; only consulted for ``Category`` to infer
        the value type from an ``isin`` check
    :returns: the pyarrow data type
    :raises ValueError: for struct members annotated with unsupported unions
    """
    origin = get_origin(some_type)

    if origin is list:
        element_type = get_args(some_type)[0]
        return pa.list_(type_to_arrow(element_type))

    if origin is dict:
        key_type, value_type = get_args(some_type)
        return pa.map_(type_to_arrow(key_type), type_to_arrow(value_type))

    if origin is tuple:
        # TODO(sam): best we can do is convert a tuple[x, y] to a list[z < x & y]
        # because PyArrow does not support tuples
        # alternative is a class Centroid(NamedTuple) but that requires converting tuple[x, y] to Centroid(x, y)
        #
        # Convert every member to arrow *before* unifying: `infer_type` falls back to
        # `pa.string()`, which is already an arrow type and must not be converted again.
        # The `...` in `tuple[T, ...]` is a marker, not an element type, so skip it.
        element_types = [type_to_arrow(arg) for arg in get_args(some_type) if arg is not Ellipsis]
        return pa.list_(infer_type(element_types))

    if isinstance(some_type, type) and (
        _is_namedtuple(some_type) or _is_typeddict(some_type) or issubclass(some_type, BaseModel)
    ):
        return pa.struct(
            [
                _struct_field_to_arrow_field(member_name, member_type)
                for member_name, member_type in some_type.__annotations__.items()
            ],
        )

    if isinstance(some_type, pandas_engine.Decimal):
        # decimal128 rejects precisions above 38, so widen to decimal256 only then
        if some_type.precision > _DECIMAL128_MAX_PRECISION:
            return pa.decimal256(some_type.precision, some_type.scale)
        return pa.decimal128(some_type.precision, some_type.scale)

    if isinstance(some_type, pandas_engine.Date):
        # TODO(sam): is date64 right, what about date32?
        return pa.date64()

    if isinstance(some_type, dtypes.Category):
        # TODO(sam): surely there is a utility in pandera I can use to infer this type
        # when debugging, some_type.categories was None
        value_types = [
            # the arrow type of each item in the `isin` check
            type_to_arrow(type(value))
            for check in checks or []
            if check.name == "isin" and check.statistics is not None and "allowed_values" in check.statistics
            for value in check.statistics["allowed_values"]
        ]
        return pa.dictionary(
            pa.int8(),
            infer_type(value_types),
            ordered=some_type.ordered,
        )

    # fall back to pandas dtype and map that to pyarrow
    pandas_dtype = pandas_engine.Engine.dtype(some_type).type
    if pandas_dtype.type == np.object_:
        return pa.string()
    return pa.from_numpy_dtype(pandas_dtype)
def infer_type(types: list[type]) -> pa.DataType:
    """Collapse a list of candidate types into a single type.

    When every entry is the same (e.g. all strings or all ints) that one entry
    is returned unchanged. Otherwise, including for an empty list, the result is
    ``pa.string()``.
    """
    # de-dupe them to a unique set of types
    distinct = set(types)
    if len(distinct) == 1:
        # we can infer the type if they are all the same
        return next(iter(distinct))

    # TODO(sam): re-visit this behavior
    # there is no "or" type in pyarrow - we use string to match np.object_ (top type) behavior.
    # we could perhaps widen the type to the most precise common ancestor
    #
    # I originally thought pa.union(*unique_types) would work, but a union takes
    # named fields, not just types. We don't have field names, only type names.
    return pa.string()
import os
from typing import List, NamedTuple, Optional
import pandera as pd
import pandera.typing as pdt
import pyarrow  # noqa: ICN001 - conflicts with pandera
from typing_extensions import TypedDict
from noetik_pipeline_methods.assets.orion_human.tissue_props import OrionHumanTissueProps
# See: https://github.com/unionai-oss/pandera/pull/1556
# from pandera.typing import Series
from noetik_pipeline_methods.io.pandera_patch import Series
from noetik_pipeline_methods.io.pandera_to_arrow import to_pyarrow_schema
from noetik_pipeline_methods.io.table_io import pyarrow_schema_to_glue_columns
# NOTE(review): the purpose of this flag is not visible in this file. pandera is
# already imported above, so if the flag must be set before that import it is set
# too late here - confirm.
os.environ["PYARROW_IGNORE_TIMEZONE"] = "1"
import pandera as pa
# TypedDict fixture used as a struct-valued column type in `TodoList`.
class TodoItemDict(TypedDict):
    name: str
    priority: int
    pd_uint8: pdt.UInt8
    # The two nullable members use different union spellings (`Optional[...]` and
    # `X | None`); the tests below expect both to become nullable struct fields.
    optional: Optional[str]
    int_or_none: int | None
# NamedTuple fixture used as a struct-valued column type in `TodoList`.
class TodoItemTuple(NamedTuple):
    name: str
    priority: int
    pd_uint8: pdt.UInt8
    # the tests below expect this member to become a nullable struct field
    optional: Optional[str]
# Allowed values for the `category` column's `isin` check: all strings.
categories = list("ABCD")
# Allowed values for the `heterogenous_category` column's `isin` check: strings mixed with ints.
heterogenous_categories = ["A", "B"] + [1, 2]
# Schema fixture with one column per supported type shape (scalars, lists, temporal,
# decimal, categories, structs, maps). The tests below assert the converted columns
# in this exact declaration order, so keep the order in sync with them.
class TodoList(pa.DataFrameModel):
    # TODO(sam): remove the redundant Series[T] wrapper when fixed:
    # https://github.com/unionai-oss/pandera/issues/1546

    # --- scalar and list columns ---
    bool_: Series[pdt.Bool] = pa.Field()
    bool_list: Series[list[pdt.Bool]] = pa.Field()
    float32_list: Series[list[pdt.Float32]] = pa.Field()
    float32: Series[pdt.Float32] = pa.Field()
    float64_list: Series[list[pdt.Float64]] = pa.Field()
    float64: Series[pdt.Float64] = pa.Field()
    int_list: Series[list[int]] = pa.Field()
    int16_List: Series[List[pdt.Int16]] = pa.Field()  # noqa: N815 - uppercase List matches what is being tested
    int16: Series[pdt.Int16] = pa.Field()
    int32_list: Series[list[pdt.Int32]] = pa.Field()
    int32: Series[pdt.Int32] = pa.Field()
    int64_list: Series[list[pdt.Int64]] = pa.Field()
    int64: Series[pdt.Int64] = pa.Field()
    int8_list: Series[list[pdt.Int8]] = pa.Field()
    int8: Series[pdt.Int8] = pa.Field()
    str_list: Series[list[str]] = pa.Field()
    string_list: Series[list[pdt.String]] = pa.Field()
    string: Series[pdt.String] = pa.Field()
    uint16_list: Series[list[pdt.UInt16]] = pa.Field()
    uint16: Series[pdt.UInt16] = pa.Field()
    uint32_list: Series[list[pdt.UInt32]] = pa.Field()
    uint32: Series[pdt.UInt32] = pa.Field()
    uint64_list: Series[list[pdt.UInt64]] = pa.Field()
    uint64: Series[pdt.UInt64] = pa.Field()
    uint8_list: Series[list[pdt.UInt8]] = pa.Field()
    uint8: Series[pdt.UInt8] = pa.Field()

    # --- temporal and decimal columns ---
    # `pd` is the first alias given to pandera in this module's imports
    timestamp: Series[pd.Timestamp] = pa.Field()
    date: Series[pdt.Date] = pa.Field()
    datetime: Series[pdt.DateTime] = pa.Field()
    # TODO(sam): how to specify precision?
    decimal: Series[pdt.Decimal] = pa.Field()

    # --- categorical columns; allowed values come from the module-level lists ---
    category: Series[pa.Category] = pa.Field(isin=categories, coerce=True)
    heterogenous_category: Series[pa.Category] = pa.Field(isin=heterogenous_categories, coerce=True)

    # --- struct-like columns ---
    named_tuple: Series[TodoItemTuple] = pa.Field()
    named_tuple_list: Series[list[TodoItemTuple]] = pa.Field()
    typed_dict: Series[TodoItemDict] = pa.Field()
    typed_dict_list: Series[list[TodoItemDict]] = pa.Field()

    # --- map columns ---
    map_str_str: Series[dict[str, pdt.String]] = pa.Field()
    map_str_dict: Series[dict[str, TodoItemDict]] = pa.Field()
    map_str_tuple: Series[dict[str, TodoItemTuple]] = pa.Field()
def test_to_arrow() -> None:
    """Check that every ``TodoList`` column converts to the expected pyarrow field."""

    def required(name: str, arrow_type: pyarrow.DataType) -> pyarrow.Field:
        return pyarrow.field(name, arrow_type, nullable=False)

    def optional(name: str, arrow_type: pyarrow.DataType) -> pyarrow.Field:
        return pyarrow.field(name, arrow_type, nullable=True)

    schema = to_pyarrow_schema(TodoList)

    # struct layouts shared by the NamedTuple- and TypedDict-based columns
    tuple_members = [
        required("name", pyarrow.string()),
        required("priority", pyarrow.int64()),
        required("pd_uint8", pyarrow.uint8()),
        optional("optional", pyarrow.string()),
    ]
    tuple_struct = pyarrow.struct(tuple_members)
    dict_struct = pyarrow.struct([*tuple_members, optional("int_or_none", pyarrow.int64())])
    category_type = pyarrow.dictionary(pyarrow.int8(), pyarrow.string(), ordered=False)

    expected_columns = [
        ("bool_", pyarrow.bool_()),
        ("bool_list", pyarrow.list_(pyarrow.bool_())),
        ("float32_list", pyarrow.list_(pyarrow.float32())),
        ("float32", pyarrow.float32()),
        ("float64_list", pyarrow.list_(pyarrow.float64())),
        ("float64", pyarrow.float64()),
        ("int_list", pyarrow.list_(pyarrow.int64())),
        ("int16_List", pyarrow.list_(pyarrow.int16())),
        ("int16", pyarrow.int16()),
        ("int32_list", pyarrow.list_(pyarrow.int32())),
        ("int32", pyarrow.int32()),
        ("int64_list", pyarrow.list_(pyarrow.int64())),
        ("int64", pyarrow.int64()),
        ("int8_list", pyarrow.list_(pyarrow.int8())),
        ("int8", pyarrow.int8()),
        ("str_list", pyarrow.list_(pyarrow.string())),
        ("string_list", pyarrow.list_(pyarrow.string())),
        ("string", pyarrow.string()),
        ("uint16_list", pyarrow.list_(pyarrow.uint16())),
        ("uint16", pyarrow.uint16()),
        ("uint32_list", pyarrow.list_(pyarrow.uint32())),
        ("uint32", pyarrow.uint32()),
        ("uint64_list", pyarrow.list_(pyarrow.uint64())),
        ("uint64", pyarrow.uint64()),
        ("uint8_list", pyarrow.list_(pyarrow.uint8())),
        ("uint8", pyarrow.uint8()),
        ("timestamp", pyarrow.timestamp("ns")),
        ("date", pyarrow.date64()),
        ("datetime", pyarrow.timestamp("ns")),
        # these defaults come from pandera
        ("decimal", pyarrow.decimal128(28, 0)),
        ("category", category_type),
        ("heterogenous_category", category_type),
        ("named_tuple", tuple_struct),
        ("named_tuple_list", pyarrow.list_(tuple_struct)),
        ("typed_dict", dict_struct),
        ("typed_dict_list", pyarrow.list_(dict_struct)),
        ("map_str_str", pyarrow.map_(pyarrow.string(), pyarrow.string())),
        ("map_str_dict", pyarrow.map_(pyarrow.string(), dict_struct)),
        ("map_str_tuple", pyarrow.map_(pyarrow.string(), tuple_struct)),
    ]
    expected_schema = pyarrow.schema(
        [required(column_name, arrow_type) for column_name, arrow_type in expected_columns],
    )

    assert schema == expected_schema, "Generated schema does not match expected schema"
def test_orion_human_tissue_props() -> None:
    """Check the pyarrow schema generated for the ``OrionHumanTissueProps`` model."""
    schema = to_pyarrow_schema(OrionHumanTissueProps)

    # arrow types that recur across the expected columns
    float64 = pyarrow.float64()
    int64 = pyarrow.int64()
    float64_list = pyarrow.list_(float64)

    expected_columns = [
        ("area", float64),
        ("axis_major_length", float64),
        ("axis_minor_length", float64),
        ("centroid", float64_list),
        ("eccentricity", float64),
        ("euler_number", int64),
        ("extent", float64),
        ("intensity_max", float64_list),
        ("intensity_mean", float64_list),
        ("intensity_min", float64_list),
        ("label", int64),
        ("solidity", float64),
        ("entropy", float64_list),
        ("noise", float64_list),
        ("blur", float64_list),
        ("low_contrast", pyarrow.list_(pyarrow.bool_())),
        ("pixel_histogram", pyarrow.list_(pyarrow.list_(int64))),
        ("core_image_id", pyarrow.string()),
    ]
    expected_schema = pyarrow.schema(
        [pyarrow.field(column_name, arrow_type, nullable=False) for column_name, arrow_type in expected_columns],
    )

    assert schema == expected_schema, "Generated schema does not match expected schema"
def test_pyarrow_to_glue_column() -> None:
    """Check the Glue column definitions derived from the ``TodoList`` pyarrow schema."""
    schema = to_pyarrow_schema(TodoList)
    columns = pyarrow_schema_to_glue_columns(schema=schema)

    # Glue spellings of the two struct layouts used by the NamedTuple / TypedDict columns
    tuple_struct = "struct<name:string,priority:bigint,pd_uint8:smallint,optional:string>"
    dict_struct = "struct<name:string,priority:bigint,pd_uint8:smallint,optional:string,int_or_none:bigint>"

    expected_types = [
        ("bool_", "boolean"),
        ("bool_list", "array<boolean>"),
        ("float32_list", "array<float>"),
        ("float32", "float"),
        ("float64_list", "array<double>"),
        ("float64", "double"),
        ("int_list", "array<bigint>"),
        ("int16_List", "array<smallint>"),
        ("int16", "smallint"),
        ("int32_list", "array<int>"),
        ("int32", "int"),
        ("int64_list", "array<bigint>"),
        ("int64", "bigint"),
        ("int8_list", "array<tinyint>"),
        ("int8", "tinyint"),
        ("str_list", "array<string>"),
        ("string_list", "array<string>"),
        ("string", "string"),
        ("uint16_list", "array<int>"),
        ("uint16", "int"),
        ("uint32_list", "array<bigint>"),
        ("uint32", "bigint"),
        ("uint64_list", "array<bigint>"),
        ("uint64", "bigint"),
        ("uint8_list", "array<smallint>"),
        ("uint8", "smallint"),
        ("timestamp", "timestamp"),
        ("date", "date"),
        ("datetime", "timestamp"),
        ("decimal", "decimal(28,0)"),
        ("category", "map<tinyint,string>"),
        ("heterogenous_category", "map<tinyint,string>"),
        ("named_tuple", tuple_struct),
        ("named_tuple_list", f"array<{tuple_struct}>"),
        ("typed_dict", dict_struct),
        ("typed_dict_list", f"array<{dict_struct}>"),
        ("map_str_str", "map<string,string>"),
        ("map_str_dict", f"map<string,{dict_struct}>"),
        ("map_str_tuple", f"map<string,{tuple_struct}>"),
    ]
    expected = [{"Name": column_name, "Type": glue_type} for column_name, glue_type in expected_types]

    assert columns == expected
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment