Created
December 29, 2019 04:25
-
-
Save djsegal/3db97acec43ccf75d6313db4aa5af15c to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import pandas as pd | |
from sklearn.base import BaseEstimator, TransformerMixin | |
class ReciprocalFeatures(BaseEstimator, TransformerMixin):
    """Append the element-wise reciprocal of every column of a DataFrame.

    ``transform`` returns the original columns followed by one ``1 / x``
    column per input column, each named by :meth:`rename_lambda`
    (``"_inv_<column>"``).

    Attributes
    ----------
    input_features : list or None
        ``None`` until :meth:`fit` is called. Afterwards it holds the fitted
        input column labels followed by their reciprocal-column names, in the
        same order that ``transform`` emits the columns.
    """

    def __init__(self):
        # Populated by fit(); None marks the estimator as not yet fitted.
        self.input_features = None

    def rename_lambda(self, input_feature):
        """Return the output column name for the reciprocal of *input_feature*.

        This is a regular method rather than a lambda stored on the instance.
        A stored lambda makes the estimator unpicklable, so it could not be
        saved with pickle/joblib.
        """
        return f"_inv_{input_feature}"

    def _check_is_fitted(self):
        """Raise ``NotFittedError`` if :meth:`fit` has not been called yet."""
        if getattr(self, "input_features", None) is None:
            # Function-scope import: sklearn is already a dependency of this
            # module, and this keeps the top-level import block unchanged.
            from sklearn.exceptions import NotFittedError

            raise NotFittedError(
                "This ReciprocalFeatures instance is not fitted yet. "
                "Call 'fit' before using this estimator."
            )

    def fit(self, X, y=None):
        """Record the column labels of *X* and the names of their reciprocals.

        Calling ``fit`` again overwrites any previously recorded names, as the
        scikit-learn estimator contract expects.

        Parameters
        ----------
        X : pandas.DataFrame
            Training data; only its column labels are used.
        y : ignored
            Present for scikit-learn API compatibility.

        Returns
        -------
        self
        """
        columns = X.columns.tolist()
        self.input_features = columns + [self.rename_lambda(c) for c in columns]
        return self

    def transform(self, X):
        """Return *X* with a reciprocal (``1 / x``) column appended per column.

        Parameters
        ----------
        X : pandas.DataFrame
            Data to transform.

        Returns
        -------
        pandas.DataFrame
            The original columns followed by the reciprocal columns.

        Raises
        ------
        sklearn.exceptions.NotFittedError
            If :meth:`fit` has not been called.
        ValueError
            If any entry of *X* equals zero (its reciprocal is undefined).
        """
        self._check_is_fitted()
        # Explicit check instead of ``assert``: asserts are stripped under
        # ``python -O``, which would silently let ``1 / 0`` produce inf.
        # NaN entries are treated as non-zero, as in the original check.
        if (X == 0).any().any():
            raise ValueError(
                "ReciprocalFeatures cannot transform data containing zeros."
            )
        reciprocal_X = 1 / X
        # Assign a concrete list (not a lazy map) so the new labels line up
        # one-to-one with X.columns.
        reciprocal_X.columns = [self.rename_lambda(c) for c in X.columns]
        return pd.concat([X, reciprocal_X], axis=1)

    def get_feature_names(self):
        """Return a copy of the output column names recorded by :meth:`fit`.

        A copy is returned so that callers cannot mutate the fitted state.
        """
        self._check_is_fitted()
        return list(self.input_features)

    def get_feature_names_out(self, input_features=None):
        """Return output feature names as an object ndarray (scikit-learn >= 1.0 API).

        Parameters
        ----------
        input_features : array-like or None
            If given, the names are derived from these labels. Otherwise the
            names recorded by :meth:`fit` are used.

        Returns
        -------
        numpy.ndarray of object
            Input labels followed by their reciprocal-column names.
        """
        if input_features is None:
            self._check_is_fitted()
            names = self.input_features
        else:
            base = list(input_features)
            names = base + [self.rename_lambda(f) for f in base]
        return np.asarray(names, dtype=object)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment