Created
February 7, 2023 20:09
-
-
Save NickCrews/02738992caa52e6075c3974ea8c57ebf to your computer and use it in GitHub Desktop.
Round-tripping Pandas -> Ibis -> Pandas
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Type variable for the decorated callable: whichever of these four kinds
# goes in is the kind that comes back out.
AnyColOrTable = TypeVar("AnyColOrTable", Column, Table, pd.Series, pd.DataFrame)


def convert_to_ibis(
    func: Callable[[ColOrTable], ColOrTable]
) -> Callable[[AnyColOrTable], AnyColOrTable]:
    """Let a function written against Ibis also accept pandas objects.

    The returned wrapper inspects its first positional argument:

    * ``pd.DataFrame`` -- converted with ``df_to_memtable`` before the call.
    * ``pd.Series`` -- placed in a one-column frame under the name
      ``"value"``, converted with ``df_to_memtable``, and that column is
      passed to ``func``.
    * anything else -- passed to ``func`` untouched, and the result is
      returned untouched.

    For pandas input, the result of ``func`` is converted back with
    ``expr_to_pandas`` and the input's index is assigned onto it.
    Remaining positional and keyword arguments are forwarded unchanged.
    """

    @functools.wraps(func)
    def wrapper(inp, *args, **kwargs):
        came_from_pandas = isinstance(inp, (pd.Series, pd.DataFrame))
        if not came_from_pandas:
            # Already an Ibis object (or something else): plain pass-through.
            return func(inp, *args, **kwargs)

        if isinstance(inp, pd.DataFrame):
            ibis_inp = df_to_memtable(inp)
        else:
            # A Series has to travel inside a table; pull the column back out.
            ibis_inp = df_to_memtable(pd.DataFrame({"value": inp}))["value"]

        out = expr_to_pandas(func(ibis_inp, *args, **kwargs))
        # Re-attach the caller's index to the converted result.
        out.index = inp.index
        return out

    return wrapper
def _get_pandas_ibis_dt(pdt): | |
m = { | |
pd.UInt8Dtype(): (float, "uint8"), | |
pd.UInt16Dtype(): (float, "uint16"), | |
pd.UInt32Dtype(): (float, "uint32"), | |
pd.UInt64Dtype(): (float, "uint64"), | |
pd.Int8Dtype(): (float, "int8"), | |
pd.Int16Dtype(): (float, "int16"), | |
pd.Int32Dtype(): (float, "int32"), | |
pd.Int64Dtype(): (float, "int64"), | |
} | |
default = (pdt, None) | |
return m.get(pdt, default) | |
def df_to_memtable(df: pd.DataFrame) -> Table:
    """Build an Ibis memtable from ``df``, working around nullable ints.

    Int64Dtype (and similar) is not supported by Ibis (it crashes), so
    those columns are converted to float in pandas, the frame is handed to
    ``ibis.memtable``, and the columns are cast back to their integer type
    on the Ibis side.
    https://github.com/ibis-project/ibis/issues/5343

    NOTE(review): the float detour presumably cannot represent every
    64-bit integer exactly -- confirm whether that matters for callers.
    """
    # One lookup per column: (dtype to use in pandas, dtype to restore in Ibis).
    conversions = {col: _get_pandas_ibis_dt(dt) for col, dt in df.dtypes.items()}
    staging_dtypes = {col: pandas_dt for col, (pandas_dt, _) in conversions.items()}
    ibis_casts = {
        col: ibis_dt
        for col, (_, ibis_dt) in conversions.items()
        if ibis_dt is not None
    }

    table = ibis.memtable(df.astype(staging_dtypes))
    restored = {col: table[col].cast(ibis_dt) for col, ibis_dt in ibis_casts.items()}
    return table.mutate(**restored)
def _ibis_to_pandas_dt(dt):
    """Return the pandas nullable dtype matching the Ibis dtype ``dt``.

    String, boolean and the eight fixed-width integer types are mapped to
    their pandas extension dtypes.  The lookup keys are default-constructed
    Ibis dtype instances; when ``dt`` does not match any of them, ``dt``
    itself is returned unchanged.
    """
    pairs = (
        (idt.String, pd.StringDtype),
        (idt.Boolean, pd.BooleanDtype),
        (idt.UInt8, pd.UInt8Dtype),
        (idt.UInt16, pd.UInt16Dtype),
        (idt.UInt32, pd.UInt32Dtype),
        (idt.UInt64, pd.UInt64Dtype),
        (idt.Int8, pd.Int8Dtype),
        (idt.Int16, pd.Int16Dtype),
        (idt.Int32, pd.Int32Dtype),
        (idt.Int64, pd.Int64Dtype),
    )
    table = {ibis_cls(): pandas_cls() for ibis_cls, pandas_cls in pairs}
    return table[dt] if dt in table else dt
def expr_to_pandas(t: Expr) -> pd.DataFrame | pd.Series:
    """Execute an Ibis expression and return it with pandas nullable dtypes.

    By default strings are converted to object, and any int columns with
    nulls are converted to float.  This re-casts the executed result using
    the dtypes from ``_ibis_to_pandas_dt``.
    https://github.com/ibis-project/ibis/issues/5316

    Columns whose Ibis type has no pandas counterpart in
    ``_ibis_to_pandas_dt`` are left exactly as ``execute()`` produced them.

    Args:
        t: The Ibis expression to execute.

    Returns:
        A Series when ``t.execute()`` yields a Series, otherwise the
        executed result re-cast column by column.
    """
    raw = t.execute()

    if isinstance(raw, pd.Series):
        ibis_dt = t.type()
        pandas_dt = _ibis_to_pandas_dt(ibis_dt)
        # _ibis_to_pandas_dt hands back its own argument when it has no
        # mapping.  An Ibis dtype is not something pandas ``astype`` accepts,
        # so in that case keep the executed result untouched.
        if pandas_dt is ibis_dt:
            return raw
        return raw.astype(pandas_dt)

    casts = {}
    for col, ibis_dt in t.schema().items():
        pandas_dt = _ibis_to_pandas_dt(ibis_dt)
        if pandas_dt is not ibis_dt:  # only cast where a pandas dtype was found
            casts[col] = pandas_dt
    return raw.astype(casts)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment