Skip to content

Instantly share code, notes, and snippets.

@MLWhiz
Last active July 28, 2019 18:46
Show Gist options
  • Save MLWhiz/c6e91fb6e50a980bf2101885afbc6749 to your computer and use it in GitHub Desktop.
Save MLWhiz/c6e91fb6e50a980bf2101885afbc6749 to your computer and use it in GitHub Desktop.
# taken from https://medium.com/@pouryaayria/k-fold-target-encoding-dfe9a594874b
import numpy as np
from sklearn import base
from sklearn.model_selection import KFold
class KFoldTargetEncoderTrain(base.BaseEstimator,
                              base.TransformerMixin):
    """K-fold target (mean) encoder for a single column of a training frame.

    For every row, the value of ``colnames`` is replaced by the mean of
    ``targetName`` computed over the *other* folds only, so a row's own
    target never contributes to its encoding. The result is written to a new
    column named ``<colnames>_Kfold_Target_Enc``.

    Parameters
    ----------
    colnames : str
        Name of the column to encode.
    targetName : str
        Name of the target column; it must be present in the frame passed to
        ``transform``.
    n_fold : int, default 5
        Number of folds handed to ``KFold``.
    verbosity : bool, default True
        When true, print the correlation between the encoded feature and the
        target.
    discardOriginal_col : bool, default False
        When true, drop the original ``colnames`` column from the result.
    random_state : int or None, default 2019
        Seed handed to ``KFold`` for the shuffled split.
    """

    def __init__(self, colnames, targetName,
                 n_fold=5, verbosity=True,
                 discardOriginal_col=False,
                 random_state=2019):
        self.colnames = colnames
        self.targetName = targetName
        self.n_fold = n_fold
        self.verbosity = verbosity
        self.discardOriginal_col = discardOriginal_col
        self.random_state = random_state

    def fit(self, X, y=None):
        """Do nothing and return ``self``; all work happens in ``transform``."""
        return self

    def transform(self, X):
        """Return a copy of ``X`` with the out-of-fold target encoding added.

        Parameters
        ----------
        X : pandas.DataFrame
            Frame containing both ``colnames`` and ``targetName``. It is not
            modified.

        Returns
        -------
        pandas.DataFrame
            New frame with the extra ``<colnames>_Kfold_Target_Enc`` column
            and, if ``discardOriginal_col`` is set, without ``colnames``.

        Raises
        ------
        AssertionError
            If the configured names are not strings or are missing from ``X``.
        """
        # AssertionError is kept (rather than TypeError/KeyError) so callers
        # that relied on the original failure mode are unaffected.
        assert isinstance(self.targetName, str), 'targetName must be a str'
        assert isinstance(self.colnames, str), 'colnames must be a str'
        assert self.colnames in X.columns, \
            '{!r} is not a column of X'.format(self.colnames)
        assert self.targetName in X.columns, \
            '{!r} is not a column of X'.format(self.targetName)

        mean_of_target = X[self.targetName].mean()
        kf = KFold(n_splits=self.n_fold,
                   shuffle=True, random_state=self.random_state)
        col_mean_name = self.colnames + '_' + 'Kfold_Target_Enc'

        # Fill by position instead of by label: label-based assignment
        # misbehaves when X has a non-unique index.
        encoded_feature = np.full(len(X), np.nan, dtype=float)
        for tr_ind, val_ind in kf.split(X):
            X_tr, X_val = X.iloc[tr_ind], X.iloc[val_ind]
            fold_means = X_tr.groupby(self.colnames)[self.targetName].mean()
            encoded_feature[val_ind] = (
                X_val[self.colnames].map(fold_means).to_numpy(dtype=float)
            )

        # Categories absent from a fold's training part have no mean there;
        # fall back to the overall target mean.
        encoded_feature[np.isnan(encoded_feature)] = mean_of_target

        # assign() builds a new frame, so the caller's X is left untouched.
        X = X.assign(**{col_mean_name: encoded_feature})

        if self.verbosity:
            print('Correlation between the new feature, {} and, {} is {}.'.format(col_mean_name,self.targetName,
                  np.corrcoef(X[self.targetName].values,
                              encoded_feature)[0][1]))

        if self.discardOriginal_col:
            # Drop the encoded source column; the original code dropped the
            # target column here by mistake.
            X = X.drop(self.colnames, axis=1)
        return X
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment