Created March 11, 2022 23:03
Save timothyslau/afe1b61e0d318185ab6b557149329bea to your computer and use it in GitHub Desktop.
Auto Downcast DataFrame dtypes
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import warnings

import numpy as np
import pandas as pd

# dtype names (as reported by ``str(series.dtype)``) accepted by the int/float branches.
_NUMERIC_DTYPES = ["int8", "int16", "int32", "int64", "float16", "float32", "float64", "float128"]

# Candidate target dtypes, smallest first. ``np.float128`` only exists on some
# platforms (it is absent e.g. on Windows builds of NumPy), so it is only
# offered when NumPy actually provides it.
# NOTE: parquet doesn't support float half-precision (float16).
_INT_CANDIDATES = ("int8", "int16", "int32", "int64")
_FLOAT_CANDIDATES = ("float16", "float32", "float64") + (
    ("float128",) if hasattr(np, "float128") else ()
)

_MISSING_VALUES_MSG = (
    "ERROR: missing values present: consider switching 'errors' kwarg \n"
    "(https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.astype.html)"
)

# kind -> (candidate dtypes, limits function, out-of-range message)
_NUMERIC_TIERS = {
    "integer": (
        _INT_CANDIDATES,
        np.iinfo,
        "ERROR: integer values do not fall within int limits \n(https://numpy.org/doc/stable/reference/generated/numpy.iinfo.html#numpy-iinfo)",
    ),
    "float": (
        _FLOAT_CANDIDATES,
        np.finfo,
        "ERROR: float values do not fall within float limits \n(https://numpy.org/doc/stable/reference/generated/numpy.finfo.html#numpy-finfo)",
    ),
}


def _warn_col(col, message):
    """Emit ``message`` prefixed with the column label as a UserWarning; return None."""
    warnings.warn("\n\n" + str(col) + " " + message)
    return None


def _astype_errors(errors):
    """Map this module's ``errors`` value onto one ``Series.astype`` accepts."""
    # Series.astype only understands "raise"/"ignore"; "coerce" is a to_numeric option.
    return "ignore" if errors == "ignore" else "raise"


def _smallest_fitting(inp, candidates, limits):
    """Return the first dtype name in ``candidates`` whose limits hold inp's min and max, else None."""
    lo, hi = inp.min(), inp.max()
    for name in candidates:
        info = limits(name)
        # Inclusive bounds: a value equal to the dtype's min/max is representable.
        if info.min <= lo and hi <= info.max:
            return name
    return None


def _to_bool(inp, col, copy, errors, bool_true, fillna):
    """Convert ``inp`` to booleans (see ``type_cnvt`` for the rules)."""
    has_nulls = bool(inp.isnull().values.any())
    uniques = inp.unique()
    if has_nulls and errors == "raise":
        return _warn_col(col, _MISSING_VALUES_MSG)
    if len(uniques) != 2 and bool_true is None and errors == "raise":
        return _warn_col(col, "ERROR: to = 'bool' but missing 2 unique values to be converted" + str(uniques) + "\n(https://pandas.pydata.org/pandas-docs/stable/user_guide/basics.html#dtypes)")
    if bool_true is None and not all(v in [0, 1] for v in uniques) and errors == "raise":
        return _warn_col(col, "ERROR: unique binary values are not 0 & 1; must set kwarg 'bool_true'")
    if has_nulls and fillna is None and errors == "coerce":
        # 1 -> True, 0 -> False, anything else (including NaN) -> None.
        return np.select(condlist=[inp == 1, inp == 0], choicelist=[True, False], default=None)
    if bool_true is not None:
        # bool_true -> True, everything else -> False.
        return np.select(condlist=[inp == bool_true], choicelist=[True], default=False)
    return inp.astype(dtype="bool", copy=copy, errors=_astype_errors(errors))


def _downcast_numeric(inp, col, copy, errors, kind):
    """Downcast numeric ``inp`` to the smallest ``kind`` ("integer"/"float") dtype holding its range."""
    candidates, limits, range_msg = _NUMERIC_TIERS[kind]
    if str(inp.dtype) not in _NUMERIC_DTYPES:
        return _warn_col(col, "ERROR: col dtype is not in " + str(_NUMERIC_DTYPES))
    has_nulls = bool(inp.isnull().values.any())
    if has_nulls and errors == "raise":
        return _warn_col(col, _MISSING_VALUES_MSG)
    if has_nulls and errors == "coerce":
        # NaN cannot live in a NumPy int column; let to_numeric choose the dtype.
        return pd.to_numeric(arg=inp, downcast=kind, errors="coerce")
    target = _smallest_fitting(inp, candidates, limits)
    if target is None:
        return _warn_col(col, range_msg)
    return inp.astype(dtype=target, copy=copy, errors=_astype_errors(errors))


def type_cnvt(df, col, to, copy=True, errors="raise", bool_true=None, fillna=None):
    """Convert (and downcast) one DataFrame column to a boolean, integer or float dtype.

    Parameters
    ----------
    df : pandas.DataFrame
        Source frame; it is not assigned to by this function.
    col : column label
        Column of ``df`` to convert.
    to : {"boolean", "bool", "integer", "int", "float", "flt"}
        Target kind. Integer/float conversions pick the smallest dtype whose
        limits contain the column's min and max (bounds inclusive).
    copy : bool
        Forwarded to ``Series.astype``.
    errors : {"raise", "coerce", "ignore"}
        "raise": missing values or unsuitable boolean input are reported via a
        UserWarning and None is returned. "coerce": columns with missing values
        go through ``pd.to_numeric`` (numeric) or become True/False/None
        (boolean). "ignore": forwarded to ``Series.astype``.
    bool_true : scalar, optional
        For boolean conversion: values equal to it become True, all others False.
    fillna : scalar, optional
        When given, missing values are filled with it before conversion.

    Returns
    -------
    pandas.Series, numpy.ndarray or None
        The converted values, or None (after a UserWarning) when the
        conversion is refused.
    """
    inp = df.loc[:, col]
    if fillna is not None:
        inp = inp.fillna(value=fillna)
    if to in ("boolean", "bool"):
        return _to_bool(inp, col, copy, errors, bool_true, fillna)
    if to in ("integer", "int"):
        return _downcast_numeric(inp, col, copy, errors, "integer")
    if to in ("float", "flt"):
        return _downcast_numeric(inp, col, copy, errors, "float")
    return _warn_col(col, "ERROR: to not in [boolean, integer, float]")
# EXAMPLE: downcast every column of an existing DataFrame ``d1`` whose name
# contains "_flg" (``d1`` is assumed to be defined by the caller).
# ``to`` must be one of the accepted keywords ("int"/"integer"); a dtype name
# such as "int8" is not accepted. type_cnvt signals a refused conversion by
# returning the result of ``warnings.warn`` (None), so only overwrite the
# column when a converted result came back.
for flag_col in [c for c in d1.columns if "_flg" in c]:
    converted = type_cnvt(df=d1, col=flag_col, to="int", errors="raise")
    if converted is not None:
        d1[flag_col] = converted
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.