Last active
October 23, 2020 10:47
-
-
Save leriomaggio/f436dc9cbfd5ee0abc099ded09109591 to your computer and use it in GitHub Desktop.
PyTrousse Protocol-based Data Pipeline with Simple Validation Algorithm
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import Tuple, List, Optional, Union, NewType
from typing import Protocol, runtime_checkable
from abc import ABC, abstractmethod

import pandas as pd

# Tensor backend: prefer torch; fall back to NumPy when torch is not installed
# so the example still runs (proof-of-concept fallback only).
try:
    import torch as t
    from torch import Tensor as Tensor
except ImportError:
    import numpy as t
    from numpy import ndarray as Tensor

# Distinct static types for the two supported dataset flavours. NewType is an
# identity function at runtime, so values stay plain DataFrame / Tensor objects.
PandasDataset = NewType("PandasDataset", pd.DataFrame)
TensorDataset = NewType("TensorDataset", Tensor)

# Any dataset a pipeline operation may receive.
Dataset = Union[PandasDataset, TensorDataset]
# Protocols
# =========
@runtime_checkable
class DataFrameProtocol(Protocol):
    """Structural interface for operations that can process a pandas dataset.

    Any object exposing an ``apply_pandas`` method satisfies this protocol; no
    inheritance is required. The protocol is ``runtime_checkable``, so it can
    be used with ``isinstance``; that check only looks for the presence of the
    method, not its signature.
    """

    def apply_pandas(self, ds: PandasDataset) -> PandasDataset:
        """Apply the operation to ``ds`` and return the resulting dataset."""
        ...
@runtime_checkable
class TensorProtocol(Protocol):
    """Structural interface for operations that can process a tensor dataset.

    Any object exposing an ``apply_torch`` method satisfies this protocol; no
    inheritance is required. The protocol is ``runtime_checkable``, so it can
    be used with ``isinstance``; that check only looks for the presence of the
    method, not its signature.
    """

    def apply_torch(self, ds: TensorDataset) -> TensorDataset:
        """Apply the operation to ``ds`` and return the resulting dataset."""
        ...
# ---------------------------------------------------------------
# Feature Operation ABC
# =====================
class FeatureOperation(ABC):
    """Abstract base class for a single callable step of a data pipeline.

    Subclasses implement :meth:`apply`. Calling an instance routes through
    ``__call__``, which wraps :meth:`apply` and leaves room for shared logic
    before and after it.
    """

    def __call__(self, ds: Dataset) -> Dataset:
        """Run the operation on ``ds`` and return what :meth:`apply` returns."""
        # Placeholder: shared pre-processing would go here.
        ds = self.apply(ds)
        # Placeholder: shared post-processing would go here.
        return ds

    @abstractmethod
    def apply(self, ds: Dataset) -> Dataset:
        """Transform ``ds`` and return the result; must be overridden."""
        ...

    def __str__(self) -> str:
        """Return the name of the concrete operation class."""
        return type(self).__name__
# ---------------------------------------------------------------
# Concrete Feature Operation Implementation w/ Protocol Adherence
# ===============================================================
class DataFrameIdentityOperation(FeatureOperation):
    """No-op feature operation for pandas datasets.

    It defines ``apply_pandas`` (and no ``apply_torch``), so it structurally
    provides only the pandas-side method.
    """

    def apply(self, ds: Dataset) -> Dataset:
        """Delegate to :meth:`apply_pandas` and return its result."""
        return self.apply_pandas(ds)

    def apply_pandas(self, ds: PandasDataset) -> PandasDataset:
        """Print a marker message and return ``ds`` unchanged."""
        print("Pandas Dataset")
        return ds
class TensorIdentityOperation(FeatureOperation):
    """No-op feature operation for tensor datasets.

    It defines ``apply_torch`` (and no ``apply_pandas``), so it structurally
    provides only the tensor-side method.
    """

    def apply(self, ds: Dataset) -> Dataset:
        """Delegate to :meth:`apply_torch` and return its result."""
        return self.apply_torch(ds)

    def apply_torch(self, ds: TensorDataset) -> TensorDataset:
        """Print a marker message and return ``ds`` unchanged."""
        print("Tensor Dataset")
        return ds
# ---------------------------------------------------------------
# Trousse Pipeline
# ================
class TrousseValidationError(RuntimeError):
    """Error signalling that a pipeline failed validation against a dataset."""
class Trousse:
    """Sequential pipeline of feature operations with up-front validation.

    Before running, the pipeline checks that every operation implements the
    protocol matching the dataset type (``DataFrameProtocol`` for a
    ``pd.DataFrame``, ``TensorProtocol`` for a ``Tensor``).
    """

    def __init__(self, *ops: FeatureOperation) -> None:
        """Create a pipeline holding ``ops`` in the order given."""
        self._operations: List[FeatureOperation] = list(ops)

    def add_module(self, op: FeatureOperation) -> None:
        """Append ``op`` to the end of the pipeline."""
        self._operations.append(op)

    @property
    def operations(self) -> List[FeatureOperation]:
        """Operations in execution order (the internal list itself, not a copy)."""
        return self._operations

    def __call__(self, ds: Dataset) -> Dataset:
        """Validate the pipeline against ``ds``, then apply each operation in turn.

        Returns:
            The dataset produced by the last operation (``ds`` itself when the
            pipeline is empty).

        Raises:
            TrousseValidationError: if ``is_valid`` reports the pipeline as
                invalid for ``ds``; the message lists the offending entries.
        """
        valid, errors = self.is_valid(ds, with_errors=True)
        if not valid:
            raise TrousseValidationError(f"Invalid Ops: {errors}")
        for op in self._operations:
            ds = op(ds)
        return ds

    def is_valid(self, ds: Dataset, with_errors: bool = False) -> Tuple[bool, Optional[List[str]]]:
        """Check whether every operation supports the type of ``ds``.

        Args:
            ds: Dataset the pipeline is about to process.
            with_errors: When true, also return the names of the invalid
                operations.

        Returns:
            ``(ok, errors)`` where ``ok`` is true when nothing is invalid and
            ``errors`` is the list of ``str(op)`` for invalid operations when
            ``with_errors`` is true, otherwise ``None``. For a dataset that is
            neither a ``pd.DataFrame`` nor a ``Tensor`` the only entry is
            ``"Unsupported Dataset"``.
        """
        if isinstance(ds, pd.DataFrame):  # subs w/ real PandasDataset
            invalid_ops = [op for op in self._operations if not isinstance(op, DataFrameProtocol)]
        elif isinstance(ds, Tensor):  # subs w/ real TensorDataset
            invalid_ops = [op for op in self._operations if not isinstance(op, TensorProtocol)]
        else:
            invalid_ops = ["Unsupported Dataset"]

        is_ok = not invalid_ops
        if with_errors:
            return is_ok, [str(iop) for iop in invalid_ops]
        return is_ok, None

    def __str__(self) -> str:
        """Return ``"Trousse: [...]"`` listing ``str(op)`` for each operation."""
        return f"Trousse: {[str(op) for op in self._operations]}"
def main() -> None:
    """Demonstrate protocol checks and pipeline validation on sample data.

    Prints which protocols each identity operation satisfies, then builds a
    pandas-only pipeline, validates it against a pandas and a tensor dataset,
    runs it on the pandas dataset, and shows the validation error raised for
    the tensor dataset.
    """
    from string import punctuation

    di = DataFrameIdentityOperation()
    ti = TensorIdentityOperation()

    print(f"{di} implements PandasProtocol: {isinstance(di, DataFrameProtocol)}")
    print(f"{di} implements TorchProtocol: {isinstance(di, TensorProtocol)}")
    print(f"{ti} implements PandasProtocol: {isinstance(ti, DataFrameProtocol)}")
    print(f"{ti} implements TorchProtocol: {isinstance(ti, TensorProtocol)}")
    print(f"{di} is Callable: {callable(di)}")
    print(f"{ti} is Callable: {callable(ti)}")

    # Sample data: one row per punctuation character, plus a small tensor of ones.
    df = pd.DataFrame({"Char": list(punctuation), "No": list(range(len(punctuation)))})
    ds_pd = PandasDataset(df)
    ds_to = TensorDataset(t.ones((3, 4)))

    trousse = Trousse(DataFrameIdentityOperation())
    print(f"Pipeline: {trousse}")
    print("Trousse is valid w/ Pandas Dataset: ", trousse.is_valid(ds_pd))
    print("Trousse is valid w/ Tensor Dataset: ", trousse.is_valid(ds_to))

    print("Run Pipeline")
    trousse(ds_pd)

    # The pipeline only contains a pandas operation, so the tensor dataset
    # is expected to fail validation.
    print("Raise Error: ")
    try:
        trousse(ds_to)
    except TrousseValidationError as e:
        print(f"Exception ==> {e}")


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment