Last active
January 17, 2022 17:57
-
-
Save ianliu/9c14ae6fa6a786e0e49a7ab9c540892a to your computer and use it in GitHub Desktop.
Trickery to operate on dataframe's columns without lambdas
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import pandas as pd | |
import operator as op | |
from functools import partial | |
from datetime import datetime | |
def eval(node, df):  # noqa: A001 - shadows the builtin; name kept for callers
    """Evaluate an expression node against ``df``.

    Node forms:

    * any value that is neither a ``tuple`` nor a ``str`` -- returned as is;
    * a ``str`` -- used as a key: the result is ``df[node]``;
    * ``("lit", value)`` -- the literal ``value``, not interpreted further;
    * ``("call", fn, args, kwargs)`` -- ``fn`` is itself a node; it is
      evaluated and the result is called. Positional and keyword arguments
      that are ``col`` instances are evaluated against ``df`` first; every
      other argument is passed through untouched (so a plain string
      argument stays a string);
    * ``(operator, *operands)`` -- each operand node is evaluated and the
      callable ``operator`` is applied to the results.

    Args:
        node: The expression node to evaluate.
        df: Object supporting ``df[name]`` lookups.

    Returns:
        The value the node evaluates to.
    """
    if isinstance(node, str):
        return df[node]
    if not isinstance(node, tuple):
        return node

    operand = node[0]
    if operand == "lit":
        return node[1]

    if operand == "call":
        fn, args, kwargs = node[1:]

        def resolve(value):
            # Arguments hold ``col`` wrappers rather than raw nodes, so the
            # wrapped AST has to be pulled out before evaluating it.
            return eval(value.ast, df) if isinstance(value, col) else value

        # Keyword arguments are unwrapped exactly like positional ones;
        # handing the ``col`` object itself to ``eval`` would return it
        # unevaluated.
        return eval(fn, df)(
            *[resolve(arg) for arg in args],
            **{key: resolve(value) for key, value in kwargs.items()},
        )

    return operand(*(eval(child, df) for child in node[1:]))
class col:
    """Deferred expression over the columns of a dataframe.

    An instance wraps an expression node in ``self.ast``. Operators and
    attribute access do not compute anything; they return a new ``col``
    whose node records the operation. Calling the instance with a
    dataframe evaluates the recorded expression via the module-level
    ``eval``.

    ``col("weight")`` wraps the plain string ``"weight"``, which ``eval``
    resolves as ``df["weight"]``.
    """

    @classmethod
    def new(cls, fn, *args, **kwargs):
        """Build a deferred call of the plain callable ``fn``.

        ``fn`` is stored as a literal; ``args`` and ``kwargs`` are stored
        unchanged and may contain ``col`` instances.
        """
        return cls(("call", ("lit", fn), args, kwargs))

    def __init__(self, ast):
        self.ast = ast

    def __repr__(self):
        return f"{type(self).__name__}({self.ast!r})"

    @staticmethod
    def _operand(value):
        """Return the node for ``value``: its AST if it is a ``col``, else itself."""
        return value.ast if isinstance(value, col) else value

    def _binary(self, operator, rhs):
        """Record ``operator(self, rhs)``."""
        return col((operator, self.ast, self._operand(rhs)))

    def _reflected(self, operator, lhs):
        """Record ``operator(lhs, self)`` for a non-``col`` left operand."""
        return col((operator, self._operand(lhs), self.ast))

    def call(self, *args, **kwargs):
        """Record a call of this expression's value with the given arguments."""
        return col(("call", self.ast, args, kwargs))

    def __call__(self, df):
        """Evaluate the recorded expression against ``df``."""
        return eval(self.ast, df)

    def __sub__(self, rhs):
        return self._binary(op.sub, rhs)

    def __add__(self, rhs):
        return self._binary(op.add, rhs)

    def __mul__(self, rhs):
        return self._binary(op.mul, rhs)

    def __pow__(self, rhs):
        return self._binary(op.pow, rhs)

    def __truediv__(self, rhs):
        return self._binary(op.truediv, rhs)

    # Reflected variants let a plain value sit on the left-hand side,
    # e.g. ``2 * col("x")`` or ``1 / col("x")``.
    def __rsub__(self, lhs):
        return self._reflected(op.sub, lhs)

    def __radd__(self, lhs):
        return self._reflected(op.add, lhs)

    def __rmul__(self, lhs):
        return self._reflected(op.mul, lhs)

    def __rpow__(self, lhs):
        return self._reflected(op.pow, lhs)

    def __rtruediv__(self, lhs):
        return self._reflected(op.truediv, lhs)

    def __neg__(self):
        return col((op.neg, self.ast))

    def __getattr__(self, name):
        """Record attribute access ``<expr>.<name>`` as a new expression.

        Only invoked when normal lookup fails. Dunder names are refused so
        that protocol probes (``copy``, ``pickle``, ``hasattr`` checks) see
        a missing attribute instead of a bogus expression, and ``ast`` is
        refused so an instance created without ``__init__`` cannot recurse
        forever through ``self.ast``.
        """
        if name == "ast" or (name.startswith("__") and name.endswith("__")):
            raise AttributeError(name)
        return col((getattr, self.ast, ("lit", name)))
# Sample records: one mapping per person.
data = [
    {
        "name": "Ian",
        "height": 1.82,
        "weight": 79,
        "birth": datetime(1988, 3, 1),
    },
    {
        "name": "Giu",
        "height": 1.70,
        "weight": 58,
        "birth": datetime(1987, 2, 28),
    },
    {
        "name": "Nana",
        "height": 1.90,
        "weight": 120,
        "birth": datetime(1970, 1, 1),
    },
]

# Each keyword passed to ``assign`` is a ``col`` expression, i.e. a callable
# that receives the dataframe. ``category`` reads the ``bmi`` column created
# earlier in the same ``assign`` call.
print(
    pd.DataFrame(data).assign(
        bmi=col("weight") / col("height") ** 2,
        name=col("name").str.lower.call(),
        year=col("birth").dt.year,
        category=col.new(
            pd.cut,
            col("bmi"),
            bins=[0, 18.4, 24.9, 100],
            labels=["underweight", "normal", "overweight"],
        ),
    )
)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment