Skip to content

Instantly share code, notes, and snippets.

@tommct
Created July 2, 2021 17:18
Show Gist options
  • Save tommct/ef7309d7162da649849af43f3cc3fe12 to your computer and use it in GitHub Desktop.
Save tommct/ef7309d7162da649849af43f3cc3fe12 to your computer and use it in GitHub Desktop.
Allows for the retrieval of all or parts of the transformations in a sklearn Pipeline, as well as the ability to dynamically bypass parts of the pipeline.
import contextlib
from functools import partial
from typing import Optional

from sklearn.pipeline import Pipeline
@contextlib.contextmanager
def intermediate_transforms(pipe: Pipeline,
                            keys: Optional[list] = None,
                            bypass_list: Optional[list] = None):
    """Capture intermediate outputs of a sklearn Pipeline and/or bypass steps.

    Within the context, intermediate results are available as a dict,
    `pipe.intermediate_results__`, with keys equal to the names of the
    transformers in the pipeline. This dict is a temporary structure and is
    only available within the context, so copy a reference to it before
    leaving the `with` block. The pipeline is always restored on exit, even
    when the body of the `with` block raises.

    NOTE(review): this works by temporarily replacing the private
    `pipe._transform` attribute; confirm that the installed scikit-learn
    version still routes `Pipeline.transform` through `_transform`.

    Args:
        pipe (Pipeline): [Pipeline](https://scikit-learn.org/stable/modules/generated/sklearn.pipeline.Pipeline.html)
            object.
        keys (list, optional): Names of the pipeline steps to capture. If
            empty or None, the output of every step is captured. Otherwise
            only the listed steps are captured and the transformation stops
            as soon as all of them have been produced. The list is copied
            and never modified. Defaults to None.
        bypass_list (list, optional): Parameter names (as reported by
            `pipe.get_params()`) to bypass. Top-level step names are set to
            "passthrough"; nested names containing "__" (for example a
            transformer inside a FeatureUnion) are set to "drop". Names not
            present in `pipe.get_params()` are ignored. The original values
            are restored on exit. Defaults to None.

    Yields:
        Pipeline: The same `pipe` object, patched for the duration of the
        context.

    Raises:
        ValueError: If `pipe` is not a `Pipeline` instance.

    Example:
        We instantiate a simple pipeline.

            import numpy as np
            import pandas as pd
            from sklearn.pipeline import Pipeline
            from sklearn.base import TransformerMixin

            df = pd.DataFrame({'name': ['Anne', 'Bob', 'Charlie', 'Bob'],
                               'age': [20, 21, 22, 23]})

            class LowercaseTransformer(TransformerMixin):
                def transform(self, X):
                    return X.apply(lambda x: x.lower())

            class UppercaseTransformer(TransformerMixin):
                def transform(self, X):
                    return X.apply(lambda x: x.upper())

            class CamelcaseTransformer(TransformerMixin):
                def transform(self, X):
                    return X.apply(lambda x: x[0].upper() + x[1:].lower())

            class ReverseTransformer(TransformerMixin):
                def transform(self, X):
                    try:
                        return X.applymap(lambda x: x[::-1])
                    except AttributeError:
                        return X.apply(lambda x: x[::-1])

            lct = LowercaseTransformer()
            uct = UppercaseTransformer()
            cct = CamelcaseTransformer()
            rt = ReverseTransformer()
            pipe = Pipeline([('lower', lct), ('upper', uct), ('reverse', rt),
                             ('camel', cct), ('last', 'passthrough')])

        Then we can execute some contexts.

            # To retrieve all intermediate results...
            with intermediate_transforms(pipe):
                Xt = pipe.transform(df['name'])
                intermediate_results = pipe.intermediate_results__

        Outputs:

            {'lower': 0       anne
             1        bob
             2    charlie
             3        bob
             Name: name, dtype: object,
             'upper': 0       ANNE
             1        BOB
             2    CHARLIE
             3        BOB
             Name: name, dtype: object,
             'reverse': 0       ENNA
             1        BOB
             2    EILRAHC
             3        BOB
             Name: name, dtype: object,
             'camel': 0       Enna
             1        Bob
             2    Eilrahc
             3        Bob
             Name: name, dtype: object}

        To retrieve the first few steps, we can execute the following. Note
        that in this case, the order of the keys does not matter, but the
        returned transform, `Xt`, will be the result of the last transformer
        in our keys. And `intermediate_results` contains only the keys of
        interest.

            with intermediate_transforms(pipe, keys=['upper', 'lower']):
                Xt = pipe.transform(df['name'])
                intermediate_results = pipe.intermediate_results__

        This provides:

            {'lower': 0       anne
             1        bob
             2    charlie
             3        bob
             Name: name, dtype: object,
             'upper': 0       ANNE
             1        BOB
             2    CHARLIE
             3        BOB
             Name: name, dtype: object}

        To bypass/passthrough/drop transformers, we can execute this context.
        This may be useful in FeatureUnions to ignore some paths. However,
        skipping transformations is likely to give unexpected final results.

            with intermediate_transforms(pipe, bypass_list=['camel']):
                Xt = pipe.transform(df['name'])
                intermediate_results = pipe.intermediate_results__

        This provides the following output:

            {'lower': 0       anne
             1        bob
             2    charlie
             3        bob
             Name: name, dtype: object,
             'upper': 0       ANNE
             1        BOB
             2    CHARLIE
             3        BOB
             Name: name, dtype: object,
             'reverse': 0       ENNA
             1        BOB
             2    EILRAHC
             3        BOB
             Name: name, dtype: object}
    """
    # Our temporary overload of Pipeline._transform() method.
    # https://github.com/scikit-learn/scikit-learn/blob/main/sklearn/pipeline.py
    def _pipe_transform(self, X):
        # Track the outstanding keys on a per-call copy: every transform()
        # issued inside the context then captures the requested steps, and
        # neither the caller's list nor the stored keys are ever mutated.
        requested = self.intermediate_keys__
        pending = None if requested is None else set(requested)
        Xt = X
        for _, name, transform in self._iter():
            Xt = transform.transform(Xt)
            if pending is None:
                # No filter: record the output of every step.
                self.intermediate_results__[name] = Xt
            elif name in pending:
                pending.discard(name)
                self.intermediate_results__[name] = Xt
                if not pending:
                    # All requested steps captured; skip the remainder.
                    break
        return Xt

    if not isinstance(pipe, Pipeline):
        raise ValueError(f'"{pipe}" must be a Pipeline.')

    # Work out which parameters to bypass and remember their current values
    # before anything on the pipeline is modified.
    original_params = {}
    bypass_params = {}
    if bypass_list:
        params = pipe.get_params()
        for name in bypass_list:
            if name not in params:
                continue  # Unknown names are ignored.
            original_params[name] = params[name]
            if name.find('__') > 0:
                bypass_params[name] = 'drop'  # FeatureUnion bypass
            else:
                bypass_params[name] = 'passthrough'  # Pipeline bypass

    # Remember whether `_transform` was already overridden on the instance
    # (e.g. by an enclosing context) so the exact prior state is restored.
    had_instance_transform = '_transform' in vars(pipe)
    transform_before = pipe._transform

    pipe.intermediate_results__ = {}
    pipe.intermediate_keys__ = list(keys) if keys else None
    try:
        if bypass_params:
            pipe.set_params(**bypass_params)
        # Monkey-patch our _pipe_transform method.
        pipe._transform = partial(_pipe_transform, pipe)
        yield pipe  # Release our patched object to the context
    finally:
        # Restore, even if the body of the `with` block raised.
        if had_instance_transform:
            pipe._transform = transform_before
        else:
            # Drop the override so lookup falls back to the class method.
            vars(pipe).pop('_transform', None)
        try:
            if original_params:
                pipe.set_params(**original_params)
        finally:
            del pipe.intermediate_results__
            del pipe.intermediate_keys__
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment