Last active
April 2, 2024 01:06
-
-
Save sam-goodwin/85c44d0241f6848e4a183a39c1abfb58 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import Any, Optional, get_args, get_origin | |
import numpy as np | |
import pandas as pd | |
import pandera.typing as pdt | |
import pyarrow as pa | |
from pandera import DataFrameModel, Index, MultiIndex, dtypes | |
from pandera.engines import numpy_engine, pandas_engine | |
from pandera.engines.engine import _is_namedtuple, _is_typeddict | |
from pandera.typing.common import SeriesBase | |
# see: https://github.com/unionai-oss/pandera/issues/689 - "generate pyarrow schema from pandera schema" | |
# forked from (un-merged): https://github.com/unionai-oss/pandera/pull/1047 | |
def to_pyarrow_schema(
    model: type[DataFrameModel],
    preserve_index: Optional[bool] = None,
) -> pa.Schema:
    """Convert a pandera :class:`~pandera.DataFrameModel` to a `pa.Schema`.

    :param model: DataFrameModel class whose schema is converted to `pa.Schema`
    :param preserve_index: whether to store the index as an additional column
        (or columns, for MultiIndex) in the resulting schema. The default of
        None will store the index as a column, except for RangeIndex (a schema
        without an index) which is stored as metadata only. Use
        `preserve_index=True` to force it to be stored as a column, and
        `preserve_index=False` to leave a declared index out entirely.
    :returns: `pa.Schema` representation of the model's DataFrameSchema
    """
    dataframe_schema = model.to_schema()

    # Columns that will be present in the pyarrow schema. Copy the mapping so
    # that adding index columns below never writes into the schema's own dict.
    columns: dict[str, SeriesBase] = dict(dataframe_schema.columns)  # type: ignore[assignment]

    # pyarrow schema metadata
    metadata: dict[str, bytes] = {}

    index = dataframe_schema.index
    if index is None:
        if preserve_index:
            # Create column for RangeIndex
            name = _get_index_name(0)
            columns[name] = Index(dtypes.Int64, nullable=False, name=name)
        else:
            # Only preserve metadata of index. The payload is JSON, so the
            # missing name must be spelled `null`.
            metadata["index_columns"] = b'[{"kind": "range", "name": null, "step": 1}]'
    elif preserve_index is not False:
        # Add column(s) for index(es), placed before the regular columns.
        if isinstance(index, Index):
            name = index.name or _get_index_name(0)
            columns = {name: index, **columns}
        elif isinstance(index, MultiIndex):
            # Unnamed levels get a placeholder derived from their real
            # position in the MultiIndex (level 0 is the outermost level).
            index_columns = {
                (level_index.name or _get_index_name(level)): level_index
                for level, level_index in enumerate(index.indexes)
            }
            columns = {**index_columns, **columns}

    return pa.schema(
        [to_pyarrow_field(k, v) for k, v in columns.items()],
        metadata=metadata,
    )
# Maps instances of pandas' nullable extension dtypes to pyarrow data types.
# NOTE(review): no live code in the visible part of this file reads this
# mapping — confirm whether it is still needed.
pandas_types = {
    pd.BooleanDtype(): pa.bool_(),
    pd.Int8Dtype(): pa.int8(),
    pd.Int16Dtype(): pa.int16(),
    pd.Int32Dtype(): pa.int32(),
    pd.Int64Dtype(): pa.int64(),
    pd.UInt8Dtype(): pa.uint8(),
    pd.UInt16Dtype(): pa.uint16(),
    pd.UInt32Dtype(): pa.uint32(),
    pd.UInt64Dtype(): pa.uint64(),
    pd.Float32Dtype(): pa.float32(),  # type: ignore[attr-defined]
    pd.Float64Dtype(): pa.float64(),  # type: ignore[attr-defined]
    pd.StringDtype(): pa.string(),
}
def to_pyarrow_field(
    name: str,
    pandera_field: SeriesBase,
) -> pa.Field:
    """Build the `pa.Field` for a single pandera schema component.

    :param name: name given to the resulting field
    :param pandera_field: pandera Index or Column; its `dtype` and `nullable`
        attributes are read
    :returns: `pa.Field` carrying the converted type and the component's
        nullability
    """
    return pa.field(
        name,
        to_pyarrow_type(pandera_field.dtype),
        pandera_field.nullable,
    )
def to_pyarrow_type(pandera_dtype: Any) -> pa.DataType:
    """Translate a pandera data type into the matching `pa.DataType`.

    The checks run in a fixed order and the first match wins: date/datetime,
    categorical, Python generic types, special types, object dtype, and
    finally a direct numpy dtype conversion.

    :param pandera_dtype: pandera DataType
    :returns: `pa.DataType` representation of `pandera_dtype`
    """
    engine_dtype = pandas_engine.Engine.dtype(pandera_dtype)
    raw_type = pandera_dtype.type

    # NOTE(review): this and the next check test the dtype's `.type` attribute
    # rather than the pandera dtype itself — confirm that is intended.
    if isinstance(raw_type, (pandas_engine.Date, numpy_engine.DateTime64)):
        return pa.date64()

    if isinstance(raw_type, dtypes.Category):
        # Categorical data types
        return pa.dictionary(
            pa.int8(),
            raw_type.categories.inferred_type,
            ordered=pandera_dtype.ordered,  # type: ignore[attr-defined]
        )

    # A plain Python class as `.type`: convert the dtype's `generic_type`.
    if isinstance(raw_type, type):
        return type_to_arrow(pandera_dtype.generic_type)

    # Otherwise use `special_type` when the dtype has one and it is a class.
    special_type = getattr(pandera_dtype, "special_type", None)
    if isinstance(special_type, type):
        return type_to_arrow(special_type)

    if engine_dtype.type == np.object_:
        return pa.string()

    return pa.from_numpy_dtype(raw_type)
# Exact-match lookup from a Python / pandera / pandas / numpy type to the
# pyarrow data type it is converted to.
_SCALAR_ARROW_TYPES = {
    # builtin types
    str: pa.string(),
    int: pa.int64(),
    float: pa.float64(),
    bool: pa.bool_(),
    # pandera types
    pdt.UInt8: pa.uint8(),
    pdt.UInt16: pa.uint16(),
    pdt.UInt32: pa.uint32(),
    pdt.UInt64: pa.uint64(),
    pdt.Int8: pa.int8(),
    pdt.Int16: pa.int16(),
    pdt.Int32: pa.int32(),
    pdt.Int64: pa.int64(),
    pdt.Float32: pa.float32(),
    pdt.Float64: pa.float64(),
    pdt.String: pa.string(),
    pdt.Bool: pa.bool_(),
    # pandas types
    # TODO: pd.DateOffset is not mapped yet
    pd.Timestamp: pa.timestamp("ns"),
    pd.Timedelta: pa.duration("ns"),
    pd.Categorical: pa.dictionary(pa.int8(), pa.string()),
    # NOTE(review): Interval and Period are mapped to a nanosecond duration,
    # exactly as before — confirm this is the intended Arrow representation.
    pd.Interval: pa.duration("ns"),
    pd.Period: pa.duration("ns"),
    # numpy types
    np.datetime64: pa.timestamp("ns"),
    np.timedelta64: pa.duration("ns"),
    np.int8: pa.int8(),
    np.int16: pa.int16(),
    np.int32: pa.int32(),
    np.int64: pa.int64(),
    np.uint8: pa.uint8(),
    np.uint16: pa.uint16(),
    np.uint32: pa.uint32(),
    np.uint64: pa.uint64(),
    np.float32: pa.float32(),
    np.float64: pa.float64(),
    np.bool_: pa.bool_(),
    # TODO: is this right?
    np.object_: pa.string(),
}


def _struct_from_annotations(record_type: Any) -> pa.DataType:
    """Build a struct type with one field per annotation of `record_type`."""
    return pa.struct(
        [
            pa.field(
                field_name,
                type_to_arrow(field_type),
                # TODO(sgoodwin): determine this based on the type of Optional[T] or T | None
                nullable=False,
            )
            for field_name, field_type in record_type.__annotations__.items()
        ],
    )


def type_to_arrow(python_type: type) -> pa.DataType:
    """Convert a Python type annotation to a `pa.DataType`.

    Supported inputs are the scalar types listed in `_SCALAR_ARROW_TYPES`,
    `list[T]` / `List[T]` (converted recursively to a pyarrow list), and
    NamedTuple or TypedDict classes (converted to a struct with one
    non-nullable field per annotation).

    :param python_type: type to convert
    :returns: `pa.DataType` representation of `python_type`
    :raises TypeError: if `python_type` is not one of the supported types,
        including a list annotation without an element type
    """
    try:
        scalar_type = _SCALAR_ARROW_TYPES.get(python_type)
    except TypeError:
        # Unhashable objects cannot be table keys; treat them like any other
        # unknown type so the error raised below stays uniform.
        scalar_type = None
    if scalar_type is not None:
        return scalar_type

    if get_origin(python_type) is list:
        item_types = get_args(python_type)
        # A bare `List` has no element type and is reported as unsupported
        # instead of failing on an empty argument tuple.
        if item_types:
            return pa.list_(type_to_arrow(item_types[0]))
    elif _is_namedtuple(python_type) or _is_typeddict(python_type):
        return _struct_from_annotations(python_type)

    error = f"Unsupported type: {python_type}"
    raise TypeError(error)
def _get_index_name(level: int) -> str: | |
"""Generate an index name for pyarrow if none is specified""" | |
return f"__index_level_{level}__" |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
from typing import List, NamedTuple | |
import pandera.typing as pdt | |
import pyarrow | |
from pandera.typing import Series | |
from noetik_pipeline_methods.io.pandera import to_pyarrow_schema | |
# NOTE(review): this assignment sits before `import pandera as pa` below, but
# `pandera.typing` has already been imported above, so the pandera package is
# loaded before the variable is set — confirm whether the ordering matters.
os.environ["PYARROW_IGNORE_TIMEZONE"] = "1"
import pandera as pa | |
# NamedTuple used as the element type of the `named_tuple` column of `TodoList`.
class TodoItem(NamedTuple):
    name: str
    priority: int
    # numpy scalar annotations aren't supported by pandera's to_schema(),
    # so the following field stays disabled:
    # np_uint8: np.uint8
    pd_uint8: pdt.UInt8
# Model with one scalar column and one list column per supported element type,
# plus a NamedTuple column. `test_to_arrow` expects the generated fields in
# exactly this declaration order.
class TodoList(pa.DataFrameModel):
    # TODO(sgoodwin): remove the redundant Series wrapper when fixed:
    # https://github.com/unionai-oss/pandera/issues/1546
    bool_: Series[pdt.Bool] = pa.Field()
    bool_list: Series[list[pdt.Bool]] = pa.Field()
    float32_list: Series[list[pdt.Float32]] = pa.Field()
    float32: Series[pdt.Float32] = pa.Field()
    float64_list: Series[list[pdt.Float64]] = pa.Field()
    float64: Series[pdt.Float64] = pa.Field()
    int_list: Series[list[int]] = pa.Field()
    # This column uses `typing.List` where the other list columns use the
    # builtin `list`.
    int16_List: Series[List[pdt.Int16]] = pa.Field()
    int16: Series[pdt.Int16] = pa.Field()
    int32_list: Series[list[pdt.Int32]] = pa.Field()
    int32: Series[pdt.Int32] = pa.Field()
    int64_list: Series[list[pdt.Int64]] = pa.Field()
    int64: Series[pdt.Int64] = pa.Field()
    int8_list: Series[list[pdt.Int8]] = pa.Field()
    int8: Series[pdt.Int8] = pa.Field()
    str_list: Series[list[str]] = pa.Field()
    string_list: Series[list[pdt.String]] = pa.Field()
    string: Series[pdt.String] = pa.Field()
    uint16_list: Series[list[pdt.UInt16]] = pa.Field()
    uint16: Series[pdt.UInt16] = pa.Field()
    uint32_list: Series[list[pdt.UInt32]] = pa.Field()
    uint32: Series[pdt.UInt32] = pa.Field()
    uint64_list: Series[list[pdt.UInt64]] = pa.Field()
    uint64: Series[pdt.UInt64] = pa.Field()
    uint8_list: Series[list[pdt.UInt8]] = pa.Field()
    uint8: Series[pdt.UInt8] = pa.Field()
    named_tuple: Series[TodoItem] = pa.Field()
def test_to_arrow():
    """Check that `to_pyarrow_schema(TodoList)` produces the expected schema."""
    schema = to_pyarrow_schema(TodoList)

    def required(name, arrow_type):
        # Every expected field is declared non-nullable.
        return pyarrow.field(name, arrow_type, nullable=False)

    list_of = pyarrow.list_

    todo_item_struct = pyarrow.struct(
        [
            required("name", pyarrow.string()),
            required("priority", pyarrow.int64()),
            required("pd_uint8", pyarrow.uint8()),
        ],
    )

    # Same order as the column declarations on `TodoList`.
    expected_fields = [
        required("bool_", pyarrow.bool_()),
        required("bool_list", list_of(pyarrow.bool_())),
        required("float32_list", list_of(pyarrow.float32())),
        required("float32", pyarrow.float32()),
        required("float64_list", list_of(pyarrow.float64())),
        required("float64", pyarrow.float64()),
        required("int_list", list_of(pyarrow.int64())),
        required("int16_List", list_of(pyarrow.int16())),
        required("int16", pyarrow.int16()),
        required("int32_list", list_of(pyarrow.int32())),
        required("int32", pyarrow.int32()),
        required("int64_list", list_of(pyarrow.int64())),
        required("int64", pyarrow.int64()),
        required("int8_list", list_of(pyarrow.int8())),
        required("int8", pyarrow.int8()),
        required("str_list", list_of(pyarrow.string())),
        required("string_list", list_of(pyarrow.string())),
        required("string", pyarrow.string()),
        required("uint16_list", list_of(pyarrow.uint16())),
        required("uint16", pyarrow.uint16()),
        required("uint32_list", list_of(pyarrow.uint32())),
        required("uint32", pyarrow.uint32()),
        required("uint64_list", list_of(pyarrow.uint64())),
        required("uint64", pyarrow.uint64()),
        required("uint8_list", list_of(pyarrow.uint8())),
        required("uint8", pyarrow.uint8()),
        required("named_tuple", todo_item_struct),
    ]
    expected_schema = pyarrow.schema(expected_fields)

    assert schema == expected_schema, "Generated schema does not match expected schema"
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment