Created
July 2, 2021 17:18
-
-
Save tommct/ef7309d7162da649849af43f3cc3fe12 to your computer and use it in GitHub Desktop.
Allows for the retrieval of all or parts of the transformations in a sklearn Pipeline, as well as the ability to dynamically bypass parts of the pipeline.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import contextlib
from functools import partial
from typing import Optional

from sklearn.pipeline import Pipeline
@contextlib.contextmanager
def intermediate_transforms(pipe: Pipeline, keys: Optional[list] = None,
                            bypass_list: Optional[list] = None):
    """Capture intermediate outputs of a sklearn Pipeline and bypass steps.

    Within the context, intermediate results are available as a dict,
    `pipe.intermediate_results__`, with keys equal to the names of the
    transformers in the pipeline. The attribute is temporary and is removed
    from `pipe` when the context exits (keep your own reference to the dict if
    you need it afterwards). The pipeline is always restored on exit, even if
    the body of the `with` block raises.

    Args:
        pipe (Pipeline): [Pipeline](https://scikit-learn.org/stable/modules/generated/sklearn.pipeline.Pipeline.html)
            object.
        keys (list, optional): Names of the pipeline steps whose outputs should
            be retrieved. If empty or None, the outputs of all steps are
            captured. Otherwise only the listed names are captured and the
            transform stops as soon as the last of them has been computed.
            The list is never modified. Defaults to None.
        bypass_list (list, optional): Parameter names (as returned by
            `pipe.get_params()`) to bypass. A top-level step, or a step whose
            parent is a nested Pipeline, is replaced with "passthrough"; any
            other nested name (e.g. a FeatureUnion member) is replaced with
            "drop". Names that are not parameters of `pipe` are ignored.
            Defaults to None.

    Yields:
        Pipeline: the same `pipe` object, temporarily patched.

    Raises:
        ValueError: if `pipe` is not a Pipeline.

    Example:
        We instantiate a simple pipeline.

            import numpy as np
            import pandas as pd
            from sklearn.pipeline import Pipeline
            from sklearn.base import TransformerMixin

            df = pd.DataFrame({'name': ['Anne', 'Bob', 'Charlie', 'Bob'],
                               'age': [20, 21, 22, 23]})

            class LowercaseTransformer(TransformerMixin):
                def transform(self, X):
                    return X.apply(lambda x: x.lower())

            class UppercaseTransformer(TransformerMixin):
                def transform(self, X):
                    return X.apply(lambda x: x.upper())

            class CamelcaseTransformer(TransformerMixin):
                def transform(self, X):
                    return X.apply(lambda x: x[0].upper() + x[1:].lower())

            class ReverseTransformer(TransformerMixin):
                def transform(self, X):
                    try:
                        return X.applymap(lambda x: x[::-1])
                    except AttributeError:
                        return X.apply(lambda x: x[::-1])

            lct = LowercaseTransformer()
            uct = UppercaseTransformer()
            cct = CamelcaseTransformer()
            rt = ReverseTransformer()
            pipe = Pipeline([('lower', lct), ('upper', uct), ('reverse', rt),
                             ('camel', cct), ('last', 'passthrough')])

        Then we can execute some contexts.

            # To retrieve all intermediate results...
            with intermediate_transforms(pipe):
                Xt = pipe.transform(df['name'])
                intermediate_results = pipe.intermediate_results__

        Outputs:

            {'lower': 0       anne
                      1        bob
                      2    charlie
                      3        bob
                      Name: name, dtype: object,
             'upper': 0       ANNE
                      1        BOB
                      2    CHARLIE
                      3        BOB
                      Name: name, dtype: object,
             'reverse': 0       ENNA
                        1        BOB
                        2    EILRAHC
                        3        BOB
                        Name: name, dtype: object,
             'camel': 0       Enna
                      1        Bob
                      2    Eilrahc
                      3        Bob
                      Name: name, dtype: object}

        To retrieve the first few steps, we can execute the following. Note
        that in this case, the order of the keys does not matter, but the
        returned transform, `Xt`, will be the result of the last transformer
        in our keys. And `intermediate_results` contains only the keys of
        interest.

            with intermediate_transforms(pipe, keys=['upper', 'lower']):
                Xt = pipe.transform(df['name'])
                intermediate_results = pipe.intermediate_results__

        This provides:

            {'lower': 0       anne
                      1        bob
                      2    charlie
                      3        bob
                      Name: name, dtype: object,
             'upper': 0       ANNE
                      1        BOB
                      2    CHARLIE
                      3        BOB
                      Name: name, dtype: object}

        To bypass/passthrough/drop transformers, we can execute this context.
        This may be useful in FeatureUnions to ignore some paths. However,
        skipping transformations is likely to give unexpected final results.

            with intermediate_transforms(pipe, bypass_list=['camel']):
                Xt = pipe.transform(df['name'])
                intermediate_results = pipe.intermediate_results__

        This provides the following output:

            {'lower': 0       anne
                      1        bob
                      2    charlie
                      3        bob
                      Name: name, dtype: object,
             'upper': 0       ANNE
                      1        BOB
                      2    CHARLIE
                      3        BOB
                      Name: name, dtype: object,
             'reverse': 0       ENNA
                        1        BOB
                        2    EILRAHC
                        3        BOB
                        Name: name, dtype: object}
    """

    # Our temporary overload of Pipeline._transform() method.
    # https://github.com/scikit-learn/scikit-learn/blob/main/sklearn/pipeline.py
    # NOTE(review): this relies on the private `Pipeline._transform` and
    # `Pipeline._iter` hooks; confirm they exist in the installed scikit-learn.
    def _pipe_transform(self, X):
        Xt = X
        wanted = self.intermediate_keys__
        # Work on a per-call copy so repeated transform() calls inside one
        # context behave the same and the caller's list is left untouched.
        remaining = None if wanted is None else set(wanted)
        for _, name, transform in self._iter():
            Xt = transform.transform(Xt)
            if remaining is None:
                self.intermediate_results__[name] = Xt
            elif name in remaining:
                remaining.discard(name)
                self.intermediate_results__[name] = Xt
                if not remaining:
                    # Every requested step has been captured; stop early.
                    break
        return Xt

    if not isinstance(pipe, Pipeline):
        raise ValueError(f'"{pipe}" must be a Pipeline.')

    # Decide what each bypassed name is replaced with, remembering the
    # original objects so they can be put back on exit.
    before_list_objs = {}
    bypass_list_objs = {}
    if bypass_list:
        params = pipe.get_params()
        for k in bypass_list:
            if k not in params:
                continue
            before_list_objs[k] = params[k]
            nested = k.find('__') > 0
            parent = params.get(k.rpartition('__')[0]) if nested else None
            if nested and not isinstance(parent, Pipeline):
                bypass_list_objs[k] = 'drop'  # FeatureUnion bypass
            else:
                bypass_list_objs[k] = 'passthrough'  # Pipeline bypass

    # Fetch the current hook before touching the pipeline: if the attribute
    # is missing we fail here, with `pipe` still unmodified.
    _transform_before = pipe._transform
    had_instance_transform = '_transform' in vars(pipe)

    try:
        pipe.intermediate_results__ = {}
        pipe.intermediate_keys__ = list(keys) if keys else None
        if bypass_list_objs:
            pipe.set_params(**bypass_list_objs)
        # Monkey-patch our _pipe_transform method.
        pipe._transform = partial(_pipe_transform, pipe)
        yield pipe  # Release our patched object to the context
    finally:
        # Restore, whether or not the body of the context raised.
        if had_instance_transform:
            pipe._transform = _transform_before
        else:
            # Drop the instance-level patch so the class method is used again.
            vars(pipe).pop('_transform', None)
        if before_list_objs:
            pipe.set_params(**before_list_objs)
        for attr in ('intermediate_results__', 'intermediate_keys__'):
            if hasattr(pipe, attr):
                delattr(pipe, attr)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment