Created December 30, 2019 04:16
-
-
Save djsegal/237d96c0eef26002c4aaf6c6ef957374 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.exceptions import NotFittedError
class PassThroughTransformer(BaseEstimator, TransformerMixin):
    """Transformer that returns its input unchanged.

    While fitting it records the column labels of the input (when the input
    has any), so that they can later be reported by ``get_feature_names``.

    Attributes:
        input_features: ``None`` until ``fit`` is called; afterwards the
            list of column labels of the fitted input, or ``[]`` when the
            input carries no ``columns`` attribute (e.g. a NumPy array or a
            nested list).
    """

    def __init__(self):
        # No hyperparameters; ``input_features`` doubles as the fitted flag.
        self.input_features = None

    def fit(self, X, y=None):
        """Record the column labels of ``X``.

        Calling ``fit`` again simply overwrites the recorded labels, as
        scikit-learn estimators are expected to support refitting.

        Args:
            X: Training data. Anything exposing a ``columns`` index (e.g. a
                pandas DataFrame) contributes its labels; any other input is
                treated as unlabelled and records an empty list.
            y: Ignored; accepted for scikit-learn API compatibility.

        Returns:
            self, so calls can be chained.
        """
        # Duck-type on ``columns`` instead of testing the exact type, so
        # arrays, array subclasses and plain lists are all accepted.
        columns = getattr(X, "columns", None)
        self.input_features = [] if columns is None else columns.tolist()
        return self

    def transform(self, X):
        """Return ``X`` unchanged.

        Raises:
            NotFittedError: if ``fit`` has not been called yet.
        """
        self._check_fitted()
        return X

    def get_feature_names(self, input_features=None):
        """Return the names of the output features.

        Because this is a passthrough, the output features are exactly the
        input features.

        Args:
            input_features: Optional names of the input features. If the
                transformer was fitted on labelled input, they must equal
                the recorded labels. If it was fitted on unlabelled input
                (recorded as ``[]``), there is nothing to check against and
                they are returned as the output names.

        Returns:
            A new list of feature names (a copy, so callers cannot mutate
            the fitted state).

        Raises:
            NotFittedError: if ``fit`` has not been called yet.
            ValueError: if ``input_features`` disagrees with the labels
                recorded during ``fit``.
        """
        self._check_fitted()
        if input_features is None:
            return list(self.input_features)
        requested = list(input_features)
        if self.input_features and requested != self.input_features:
            raise ValueError(
                "input_features {!r} do not match the columns seen during "
                "fit {!r}".format(requested, self.input_features)
            )
        return requested

    def _check_fitted(self):
        """Raise ``NotFittedError`` unless ``fit`` has been called."""
        # Explicit raise instead of ``assert`` so the check survives ``-O``.
        if self.input_features is None:
            raise NotFittedError(
                "This PassThroughTransformer instance is not fitted yet. "
                "Call 'fit' before using this transformer."
            )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.