Last active
August 29, 2015 14:27
-
-
Save hcarvalhoalves/993f313d3b0a78fadc8f to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import collections | |
import toolz | |
import decorator | |
## same boilerplate as in common_models | |
class Stage(object):
    """Base unit of work in a pipeline.

    Keyword arguments given at construction time are stored unchanged on
    ``self.constants``. Subclasses implement :meth:`do` to perform the
    actual computation.
    """

    def __init__(self, **constants):
        # Values fixed when the stage is built, as opposed to values that
        # are handed to ``do`` when the stage is executed.
        self.constants = constants

    def do(self, *args, **kwargs):
        """Execute the stage; must be overridden by subclasses.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        # Name the concrete class so a missing override is easy to locate.
        raise NotImplementedError(
            "{}.do() must be implemented by a subclass".format(
                type(self).__name__))
class Combinator(Stage):
    """Container stage holding several sub-stages in ``self.stages``.

    ``run`` evaluates every sub-stage against the same input and collects
    the results in order.
    """

    def __init__(self, *stages):
        # Initialise the base class so ``self.constants`` exists (as an
        # empty dict) like on every other Stage instance.
        super(Combinator, self).__init__()
        self.stages = stages
class Pipeline(Stage):
    """Container stage holding an ordered sequence in ``self.stages``.

    ``run`` evaluates the sub-stages one after another, feeding each one
    the result of the previous one.
    """

    def __init__(self, *stages):
        # Initialise the base class so ``self.constants`` exists (as an
        # empty dict) like on every other Stage instance.
        super(Pipeline, self).__init__()
        self.stages = stages
def run(stage, **initial):
    """Evaluate a stage tree and return the value it produces.

    Args:
        stage: a ``Stage``, ``Combinator`` or ``Pipeline`` instance.
        **initial: keyword arguments used as the input of the outermost
            stage.

    Returns:
        For a ``Combinator``, a list with the result of each sub-stage.
        For a ``Pipeline``, the result of its last sub-stage (or the input
        itself when it has no sub-stages). For a plain ``Stage``, whatever
        its ``do`` method returns.

    Raises:
        NotImplementedError: if a node is not a ``Stage`` instance.
    """
    # Local imports keep the function working on both Python 2 and 3:
    # ``reduce`` is only a builtin on Python 2, and the ``Mapping`` alias in
    # ``collections`` was removed in Python 3.10.
    from functools import reduce
    try:
        from collections.abc import Mapping
    except ImportError:  # Python 2
        from collections import Mapping

    def evaluate(node, prev):
        # Combinator and Pipeline subclass Stage, so they must be tested
        # before the generic Stage branch.
        if isinstance(node, Combinator):
            # Build a real list: on Python 3 ``map`` would be lazy.
            return [evaluate(member, prev) for member in node.stages]
        if isinstance(node, Pipeline):
            return reduce(lambda acc, step: evaluate(step, acc),
                          node.stages, prev)
        if isinstance(node, Stage):
            # A mapping is spread as keyword arguments, anything else as
            # positional arguments.
            if isinstance(prev, Mapping):
                return node.do(**prev)
            return node.do(*prev)
        raise NotImplementedError(type(node))

    return evaluate(stage, initial)
## decorator for lazy me ,,_o.O_,, | |
def stage(f):
    """Decorator turning a plain function into a lazy pipeline factory.

    Calling the decorated function does not call ``f``. It returns a
    ``Pipeline`` made of a ``Combinator`` over every keyword argument that
    is a ``Stage`` instance, followed by a stage that calls ``f``. That
    final stage passes the remaining keyword arguments as constants and
    the combinator's results under the names of the stage arguments.

    Args:
        f: function that accepts keyword arguments.

    Returns:
        A factory accepting keyword arguments only and returning a
        ``Pipeline``.
    """
    import functools  # local import keeps this block self-contained

    @functools.wraps(f)  # keep f's name and docstring on the factory
    def factory(**kwargs):
        # Stage-valued arguments are resolved when the pipeline runs;
        # everything else is bound now.
        constants = {name: value for name, value in kwargs.items()
                     if not isinstance(value, Stage)}
        upstream = {name: value for name, value in kwargs.items()
                    if isinstance(value, Stage)}
        # Names and stages are taken from the same dict so that the i-th
        # result is paired with the i-th argument name.
        arg_names = list(upstream.keys())
        members = list(upstream.values())

        class _S(Stage):
            def __repr__(self):
                # ``__name__`` exists on Python 2 and 3; ``func_name`` is
                # Python 2 only.
                return "{}({})".format(f.__name__, repr(self.constants))

            def do(self, *args):
                # Positional results from the upstream stages are mapped
                # back to keyword names and merged over the constants.
                resolved = dict(zip(arg_names, args))
                return f(**toolz.merge(self.constants, resolved))

        return Pipeline(Combinator(*members), _S(**constants))

    return factory
## test transparent composition for the win | |
@stage
def operate(x=0, y=0):
    """Return ``x + y``; both operands default to 0."""
    total = x + y
    return total
@stage
def integrate(**kwargs):
    """Return the sum of all keyword argument values."""
    values = kwargs.values()
    return sum(values)
def test_compose():
    """Check that decorated stages compose and evaluate via ``run``."""
    # A single stage built from constants only.
    single = operate(x=1)
    assert run(single) == 1

    # A stage used as an argument of another stage: 2 + (1 + 0).
    pipeline = operate(x=2, y=operate(x=1))
    assert run(pipeline) == 3

    # Pipelines reused as arguments several levels deep: 3 + (3 + 3).
    another_pipeline = operate(x=3, y=pipeline)
    all_pipes = operate(x=pipeline, y=another_pipeline)
    assert run(all_pipes) == 9

    # A stage taking arbitrary keyword arguments yields the same total.
    summary = integrate(x=pipeline, y=another_pipeline)
    assert run(summary) == 9
## mimic an imperative API for great success
# @stage | |
# def train_X(self): | |
# return {'dataframe': [['foo', 1], ['bar', 2]]} | |
# @stage | |
# def train_y(self): | |
# return {'dataframe': [True, False]} | |
# @stage | |
# def actual_X(self): | |
# return {'dataframe': [['foo', 3], ['baz', 0]]} | |
# @stage | |
# def learner(self, X, y): | |
# return {'rules': {fst: lst | |
# for (fst, _), lst in zip(X['dataframe'], y['dataframe'])}} | |
# @stage | |
# def predictor(self, X, classifier): | |
# missing = self.constants['missing'] | |
# return {'dataframe': [classifier['rules'].get(fst, missing) | |
# for (fst, _) in X['dataframe']]} | |
# @stage | |
# def wrapper(self): | |
# return self.constants | |
# class ShittyClassifier(object): | |
# def __init__(self, missing=None): | |
# self.missing = missing | |
# def fit(self, X, y): | |
# self.classifier = (X + y) | learner() # WTF? | |
# return self.classifier | |
# def predict(self, X): | |
# return (X + self.classifier) | predictor(missing=self.missing) | |
# def load(self, **kwargs): | |
# self.classifier = wrapper(**kwargs) # Double-WTF? | |
# ## thanks to voodoo, everything lazy until `run` | |
# def test_scikit_like_api(): | |
# sc = ShittyClassifier(missing='NaNaNaN') | |
# X, y = train_X(), train_y() | |
# classifier = sc.fit(X, y) | |
# assert run(classifier) == {'rules': {'foo': True, 'bar': False}} | |
# test_X = actual_X() | |
# final = sc.predict(test_X) | |
# assert run(final) == {'dataframe': [True, 'NaNaNaN']} | |
# computed_rules = run(classifier) | |
# other_sc = ShittyClassifier(missing='Batman!') | |
# other_sc.load(**computed_rules) | |
# more = other_sc.predict(test_X) | |
# assert run(more) == {'dataframe': [True, 'Batman!']} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment