Created
September 26, 2024 15:03
-
-
Save cbowdon/c0a93b026f40f9fe6afacc7d13eddb40 to your computer and use it in GitHub Desktop.
Static typing for Polars dataframes
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
######### | |
# Static typing for Polars dataframes | |
# | |
# Motivation: to have some visibility and static checks on the schema of DataFrames at boundaries | |
# Approach: subclass of DataFrame with schema type parameter | |
# Benefit: structural typing, runtime checks for agreement with declared types | |
# Tradeoffs: types not inferred from arbitrary transforms, must always declare | |
from polars import DataFrame, col, lit | |
from typing import TypedDict, TypeVar, Mapping | |
from polars._typing import PythonDataType, PolarsDataType, SchemaDict, Orientation | |
from polars._typing import FrameInitTypes # pyright: ignore[reportUnknownVariableType] | |
from polars.datatypes import N_INFER_DEFAULT | |
########################## | |
#### Type definitions #### | |
########################## | |
class STypedDict(TypedDict):
    """Empty TypedDict base class, used as the bound of the schema type variable.

    It exists because ``TypedDict`` itself is a function rather than a type,
    so a real (field-less) class is needed to serve as a bound.
    """
# A type arg for generic classes, bound to STypedDict (i.e. to be a subclass
# of TypedDict).
# NOTE(review): `class DF[TSchema](DataFrame)` in this file uses PEP 695 syntax,
# which declares its own class-scoped type parameter of the same name, so this
# module-level TypeVar (and its bound) is probably not what DF uses - confirm.
TSchema = TypeVar("TSchema", bound=STypedDict)
class DF[TSchema](DataFrame): | |
"A subclass of Polars DataFrame but with a generic type parameter that is the schema." | |
# Polars doesn't support subclassing, so perhaps redesign this as a wrapper | |
def __init__( | |
self, | |
data: FrameInitTypes | None, # pyright: ignore[reportUnknownParameterType] | |
schema: type[TSchema], | |
*, | |
schema_overrides: SchemaDict | None = None, | |
strict: bool = True, | |
orient: Orientation | None = None, | |
infer_schema_length: int | None = N_INFER_DEFAULT, | |
nan_to_null: bool = False, | |
) -> None: | |
"As Polars constructor, but data and schema are mandatory, and schema must be type paramter." | |
# Now construct a Polars DataFrame from the data and our typed schema. | |
polars_schema: Mapping[str, PythonDataType | PolarsDataType] = ( | |
schema.__annotations__ | |
) | |
# Polars does a runtime check here | |
super().__init__(data, polars_schema) # type: ignore | |
################## | |
#### Examples #### | |
################## | |
class Event(TypedDict):
    """Example schema: an event with two string fields."""

    ident: str
    summary: str
class ScoredEvent(TypedDict):
    """Example schema: the fields of an event plus a float ``score``.

    We could also inherit from Event, but it works structurally, so the two
    string fields are simply declared again here.
    """

    ident: str
    summary: str
    score: float
# One example frame per schema.
# Passing {"x": [1, 2, 3], "y": [0.1, 0.2, 0.3]} as the data for event_df
# type checks, but blows up with "column names don't match data dictionary".
event_df = DF({"ident": ["123"], "summary": ["blah"]}, schema=Event)
scored_event_df = DF(
    {"ident": ["123"], "summary": ["blah"], "score": [1.0]},
    schema=ScoredEvent,
)

# A frame built with the ScoredEvent schema, annotated with the smaller Event
# schema (presumably accepted structurally - confirm with the type checker).
df_down: DF[Event] = scored_event_df

# The other direction:
# type DF[Event] is not assignable to type DF[ScoredEvent]
df_up: DF[ScoredEvent] = event_df
def f_add_col(df: DF[Event]) -> DF[ScoredEvent]:
    """Return a ScoredEvent-typed frame: ``df`` plus a literal ``score`` of 1.0.

    The type checker is not smart enough to recognise the transform itself,
    and Polars intentionally does not support subclassing, so the result is
    re-wrapped in ``DF`` with the target schema by hand.
    """
    with_score = df.with_columns(score=lit(1.0))
    return DF(with_score, schema=ScoredEvent)


f_add_col(event_df)
def f_add_wrong_col(df: DF[Event]) -> DF[ScoredEvent]:
    """Counter-example: declares a ScoredEvent result but adds a string ``score``.

    If you declare something untrue it blows up at runtime, when the frame is
    re-wrapped with the ScoredEvent schema.
    """
    mistyped = df.with_columns(score=lit("piano"))
    return DF(mistyped, schema=ScoredEvent)  # runtime error


f_add_wrong_col(event_df)  # error!
def f_structural(df: DF[ScoredEvent]) -> DF[Event]:
    """Keep rows whose ``score`` is above 0 and return them as an Event frame.

    Only the ``ident`` and ``summary`` columns are selected before the result
    is re-wrapped with the Event schema.
    """
    positive = df.filter(col("score") > 0)
    trimmed = positive.select("ident", "summary")
    return DF(trimmed, schema=Event)


# type not ok (event_df was built without a "score" column, so this
# presumably fails at runtime as well - confirm)
f_structural(event_df)
f_structural(scored_event_df)  # type is ok
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment