Skip to content

Instantly share code, notes, and snippets.

@djsegal
Created December 30, 2019 04:16
Show Gist options
  • Select an option

  • Save djsegal/237d96c0eef26002c4aaf6c6ef957374 to your computer and use it in GitHub Desktop.

Select an option

Save djsegal/237d96c0eef26002c4aaf6c6ef957374 to your computer and use it in GitHub Desktop.
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.exceptions import NotFittedError
class PassThroughTransformer(BaseEstimator, TransformerMixin):
    """Identity transformer that remembers the column labels it was fitted on.

    ``transform`` returns its input unchanged; ``get_feature_names`` reports
    the column labels recorded during ``fit``.

    Attributes:
        input_features: ``None`` until ``fit`` is called. Afterwards, the
            column labels of the fitted data (via ``X.columns.tolist()``) when
            ``X`` has a ``columns`` attribute, or an empty list otherwise
            (e.g. a NumPy array, whose columns carry no names).
    """

    def __init__(self):
        # Not a constructor parameter; initialised so the attribute exists
        # (as None, meaning "not fitted") before fit is called.
        self.input_features = None

    def fit(self, X, y=None):
        """Record the column labels of ``X``.

        Calling ``fit`` again overwrites the previously recorded labels, as
        the scikit-learn estimator contract requires.

        Args:
            X: Training data. If it exposes a ``columns`` attribute, its
                labels are recorded; otherwise no names are recorded.
            y: Ignored; accepted for scikit-learn API compatibility.

        Returns:
            self, to allow method chaining.
        """
        # Duck-type on ``columns`` rather than checking for np.ndarray: this
        # needs no numpy import and also handles lists, ndarray subclasses
        # and sparse matrices, which previously crashed on ``X.columns``.
        columns = getattr(X, "columns", None)
        self.input_features = [] if columns is None else columns.tolist()
        return self

    def transform(self, X):
        """Return ``X`` unchanged.

        Raises:
            NotFittedError: If called before ``fit``.
        """
        self._check_is_fitted()
        return X

    def get_feature_names(self, input_features=None):
        """Return the column labels recorded by ``fit``.

        Args:
            input_features: Optional iterable of expected feature names. When
                given, it must equal the fitted names exactly (same order).

        Returns:
            The list of labels stored in ``self.input_features`` (empty when
            ``fit`` saw data without a ``columns`` attribute).

        Raises:
            NotFittedError: If called before ``fit``.
            ValueError: If ``input_features`` differs from the fitted names.
        """
        self._check_is_fitted()
        if input_features is not None:
            expected = list(input_features)
            if expected != self.input_features:
                raise ValueError(
                    "input_features {!r} do not match the features seen "
                    "during fit {!r}".format(expected, self.input_features)
                )
        return self.input_features

    def _check_is_fitted(self):
        """Raise NotFittedError if ``fit`` has not been called yet."""
        # Explicit raise instead of ``assert``: asserts vanish under ``-O``.
        if self.input_features is None:
            raise NotFittedError(
                "This {} instance is not fitted yet; call 'fit' before "
                "using this transformer.".format(type(self).__name__)
            )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment