Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save leriomaggio/f436dc9cbfd5ee0abc099ded09109591 to your computer and use it in GitHub Desktop.
Save leriomaggio/f436dc9cbfd5ee0abc099ded09109591 to your computer and use it in GitHub Desktop.
PyTrousse: a Protocol-based data pipeline with a simple validation algorithm
from typing import Tuple, List, Optional, Union, NewType
from typing import Protocol, runtime_checkable
from abc import ABC, abstractmethod
import pandas as pd
try:
import torch as t
from torch import Tensor as Tensor
except ImportError: # fallback to Numpy, if torch not available - this is just as a POP!
import numpy as t
from numpy import ndarray as Tensor
# Distinct static type names for the two dataset flavours used in this module.
# NewType only matters to static type checkers: at runtime, calling
# PandasDataset(x) or TensorDataset(x) simply returns x unchanged.
PandasDataset = NewType("PandasDataset", pd.DataFrame)
# `Tensor` is torch.Tensor when torch is importable, numpy.ndarray otherwise
# (see the try/except import above).
TensorDataset = NewType("TensorDataset", Tensor)
# Any dataset accepted by the feature operations and by the pipeline.
Dataset = Union[PandasDataset, TensorDataset]
# Protocols
# =========
@runtime_checkable
class DataFrameProtocol(Protocol):
    """Structural interface for operations that can process pandas datasets.

    Being ``runtime_checkable``, ``isinstance(obj, DataFrameProtocol)`` is
    true for any object exposing an ``apply_pandas`` attribute; the check
    looks at member presence only, not at the method signature.
    """

    def apply_pandas(self, ds: PandasDataset) -> PandasDataset:
        """Process a pandas dataset and return the resulting dataset."""
        ...
@runtime_checkable
class TensorProtocol(Protocol):
    """Structural interface for operations that can process tensor datasets.

    Being ``runtime_checkable``, ``isinstance(obj, TensorProtocol)`` is true
    for any object exposing an ``apply_torch`` attribute; the check looks at
    member presence only, not at the method signature.
    """

    def apply_torch(self, ds: TensorDataset) -> TensorDataset:
        """Process a tensor dataset and return the resulting dataset."""
        ...
# ---------------------------------------------------------------
# Feature Operation ABC
# =====================
class FeatureOperation(ABC):
    """Abstract base class for one callable step of a data pipeline.

    Subclasses must implement :meth:`apply`. Calling an instance routes the
    dataset through :meth:`apply` and returns whatever it produces.
    """

    @abstractmethod
    def apply(self, ds: Dataset) -> Dataset:
        """Transform ``ds`` and return the result (implemented by subclasses)."""
        ...

    def __call__(self, ds: Dataset) -> Dataset:
        """Run the operation on ``ds`` and return the transformed dataset."""
        # Placeholder: work to perform before the operation goes here.
        transformed = self.apply(ds)
        # Placeholder: work to perform after the operation goes here.
        return transformed

    def __str__(self) -> str:
        """Return the name of the concrete operation class."""
        return type(self).__name__
# ---------------------------------------------------------------
# Concrete Feature Operation Implementation w/ Protocol Adherence
# ===============================================================
class DataFrameIdentityOperation(FeatureOperation):
    """Identity (no-op) operation for pandas datasets.

    It exposes ``apply_pandas``, which is the member ``DataFrameProtocol``
    requires, so instances pass the runtime protocol check for pandas data.
    """

    def apply_pandas(self, ds: PandasDataset) -> PandasDataset:
        """Print a marker line to stdout and return ``ds`` unchanged."""
        print("Pandas Dataset")
        return ds

    def apply(self, ds: Dataset) -> Dataset:
        """Delegate to :meth:`apply_pandas`."""
        pandas_result = self.apply_pandas(ds)
        return pandas_result
class TensorIdentityOperation(FeatureOperation):
    """Identity (no-op) operation for tensor datasets.

    It exposes ``apply_torch``, which is the member ``TensorProtocol``
    requires, so instances pass the runtime protocol check for tensor data.
    """

    def apply_torch(self, ds: TensorDataset) -> TensorDataset:
        """Print a marker line to stdout and return ``ds`` unchanged."""
        print("Tensor Dataset")
        return ds

    def apply(self, ds: Dataset) -> Dataset:
        """Delegate to :meth:`apply_torch`."""
        tensor_result = self.apply_torch(ds)
        return tensor_result
# ---------------------------------------------------------------
# Trousse Pipeline
# ================
class TrousseValidationError(RuntimeError):
    """Raised when a pipeline fails validation against the dataset it is given."""
class Trousse:
    """Sequential pipeline of feature operations with up-front validation.

    Before running, the pipeline checks that every registered operation
    structurally implements the protocol matching the dataset type:
    ``DataFrameProtocol`` for ``pd.DataFrame`` inputs and ``TensorProtocol``
    for ``Tensor`` inputs. Any other input type is reported as unsupported.
    """

    def __init__(self, *ops: FeatureOperation) -> None:
        """Create a pipeline that runs ``ops`` in the order given."""
        self._operations: List[FeatureOperation] = list(ops)

    def add_module(self, op: FeatureOperation) -> None:
        """Append ``op`` as the last step of the pipeline."""
        self._operations.append(op)

    @property
    def operations(self) -> List[FeatureOperation]:
        """The registered operations, in execution order.

        This is the pipeline's own list, not a copy.
        """
        return self._operations

    def __call__(self, ds: Dataset) -> Dataset:
        """Validate the pipeline against ``ds``, then apply every operation.

        Args:
            ds: The dataset to push through the pipeline.

        Returns:
            The dataset produced by the last operation (``ds`` itself when
            the pipeline is empty).

        Raises:
            TrousseValidationError: If ``is_valid`` reports the pipeline as
                invalid for ``ds``.
        """
        valid, errors = self.is_valid(ds, with_errors=True)
        if not valid:
            raise TrousseValidationError(f"Invalid Ops: {errors}")
        for op in self._operations:
            ds = op(ds)
        return ds

    def is_valid(self, ds: Dataset, with_errors: bool = False) -> Tuple[bool, Optional[List[str]]]:
        """Check whether every operation supports the type of ``ds``.

        Args:
            ds: The dataset the pipeline is about to process.
            with_errors: When true, also return the offending entries.

        Returns:
            A ``(valid, errors)`` pair. ``errors`` is ``None`` unless
            ``with_errors`` is true, in which case it is the list of
            ``str(op)`` for each operation lacking the required protocol, or
            ``["Unsupported Dataset"]`` when ``ds`` is neither a DataFrame
            nor a Tensor.
        """
        # Pick the protocol that matches the dataset type.
        if isinstance(ds, pd.DataFrame):  # subs w/ real PandasDataset
            required: Optional[type] = DataFrameProtocol
        elif isinstance(ds, Tensor):  # subs w/ real TensorDataset
            required = TensorProtocol
        else:
            required = None

        # Materialise the offenders once so both return paths share the result.
        if required is None:
            invalid: list = ["Unsupported Dataset"]
        else:
            invalid = [op for op in self._operations if not isinstance(op, required)]

        if with_errors:
            return not invalid, [str(entry) for entry in invalid]
        return not invalid, None

    def __str__(self) -> str:
        """Return ``"Trousse: [...]"`` listing the operations by class name."""
        return f"Trousse: {[str(op) for op in self._operations]}"
if __name__ == "__main__":
    from string import punctuation

    # Part 1: show which protocol each identity operation satisfies.
    pandas_op = DataFrameIdentityOperation()
    tensor_op = TensorIdentityOperation()
    for op in (pandas_op, tensor_op):
        print(f"{op} implements PandasProtocol: {isinstance(op, DataFrameProtocol)}")
        print(f"{op} implements TorchProtocol: {isinstance(op, TensorProtocol)}")
    for op in (pandas_op, tensor_op):
        print(f"{op} is Callable: {callable(op)}")

    # Part 2: build one sample dataset of each flavour.
    n_chars = len(punctuation)
    frame = pd.DataFrame({"Char": list(punctuation), "No": list(range(n_chars))})
    ds_pd = PandasDataset(frame)
    ds_to = TensorDataset(t.ones((3, 4)))

    # Part 3: a pandas-only pipeline validated against both datasets.
    trousse = Trousse(DataFrameIdentityOperation())
    print(f"Pipeline: {trousse}")
    print("Trousse is valid w/ Pandas Dataset: ", trousse.is_valid(ds_pd))
    print("Trousse is valid w/ Tensor Dataset: ", trousse.is_valid(ds_to))

    print("Run Pipeline")
    trousse(ds_pd)

    # Running the pandas-only pipeline on a tensor must fail validation.
    try:
        print("Raise Error: ")
        trousse(ds_to)
    except TrousseValidationError as err:
        print(f"Exception ==> {err}")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment