Last active
June 30, 2019 22:19
-
-
Save groverpr/bb88d26d3133b3e9eff7c4d1303e8252 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np
from sklearn.model_selection import KFold
def _encode_with_fallback(keys, means, fallback):
    """Map ``keys`` through ``means``; return a float ndarray with misses set to ``fallback``."""
    # Converting to a plain float array (pd.NA -> NaN) keeps the result numeric
    # even when ``keys`` is categorical or the target uses a nullable dtype.
    # np.where builds a new array, so a read-only view from to_numpy() is fine.
    encoded = keys.map(means).to_numpy(dtype=float, na_value=np.nan)
    return np.where(np.isnan(encoded), fallback, encoded)


def target_encoder_kfold(train_data, test_data, cols_encode, target, folds=10):
    """
    Mean regularized target encoding based on kfold.

    For every column ``col`` in ``cols_encode`` a float column
    ``col + "_mean_enc"`` is written into both frames:

    * ``train_data``: out-of-fold encoding. Each row in a fold's held-out part
      gets the mean of ``target`` for its category, computed only on the rows
      used to fit that fold, so a row's own target does not feed its feature.
    * ``test_data``: each row gets the mean of ``target`` for its category over
      the full ``train_data``.

    Categories with no rows to average over (absent from a fold's fitting
    rows, or absent from ``train_data`` for the test encoding) fall back to
    the global mean of ``train_data[target]``.

    Parameters
    ----------
    train_data : pandas.DataFrame
        Training frame containing ``cols_encode`` and ``target``. Its index
        need not be a RangeIndex. Modified in place.
    test_data : pandas.DataFrame
        Frame containing ``cols_encode``, encoded with statistics from
        ``train_data``. Modified in place.
    cols_encode : iterable of str
        Names of the categorical columns to encode.
    target : str
        Name of the target column in ``train_data``.
    folds : int or splitter, default 10
        Either the number of contiguous, unshuffled ``KFold`` splits, or an
        object with a scikit-learn style ``split(X, y)`` method yielding
        ``(fit_positions, encode_positions)`` pairs (e.g.
        ``KFold(10, shuffle=True, random_state=1)``). Rows that never appear
        in an ``encode_positions`` set receive the global mean.

    Returns
    -------
    tuple of (pandas.DataFrame, pandas.DataFrame)
        ``(train_data, test_data)``: the same objects modified in place,
        returned for convenience.
    """
    # ``random_state`` is deliberately not passed: without shuffle=True it has
    # no effect, and scikit-learn >= 0.24 rejects that combination with a
    # ValueError. Pass a configured splitter via ``folds`` to shuffle.
    kf = folds if hasattr(folds, "split") else KFold(n_splits=folds)

    y = train_data[target]
    # Independent of the column being encoded, so compute it once.
    global_mean = y.mean()

    for col in cols_encode:
        enc_col = col + "_mean_enc"

        # Splitters yield *positions*, so fill a positional array instead of
        # using .loc (label-based), which misaligns or raises on frames whose
        # index is not 0..n-1.
        oof = np.full(len(train_data), global_mean, dtype=float)
        for fit_pos, enc_pos in kf.split(train_data, y):
            fold_means = train_data.iloc[fit_pos].groupby(col)[target].mean()
            oof[enc_pos] = _encode_with_fallback(
                train_data[col].iloc[enc_pos], fold_means, global_mean
            )
        # Whole-column assignment rather than fillna(inplace=True) on a column
        # selection, which is chained assignment and is a no-op under pandas
        # Copy-on-Write.
        train_data[enc_col] = oof

        # Test encoding uses category means from the full training data.
        col_means = train_data.groupby(col)[target].mean()
        test_data[enc_col] = _encode_with_fallback(
            test_data[col], col_means, global_mean
        )

    return train_data, test_data
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment