Last active
June 19, 2020 14:02
-
-
Save ma7555/121d82d14b785270854b31610faf88d3 to your computer and use it in GitHub Desktop.
Create stratified train/test/validation splits for a pandas dataframe
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import pandas as pd | |
from sklearn.model_selection import StratifiedShuffleSplit | |
def get_stf_ttv(data, targets, train_size=0.8, random_state=555):
    """Return stratified train/test/validation row positions.

    The rows are first split into a train part and a hold-out part according
    to ``train_size``; the hold-out part is then split in half into test and
    validation. With ``train_size=0.8`` this gives roughly 80% train,
    10% test and 10% validation.

    All three returned arrays contain *positional* row numbers into
    ``data`` / ``targets`` (use them with ``.iloc``), independent of whatever
    pandas index the inputs carry. For inputs with a default ``RangeIndex``
    the positions and the labels are the same numbers.

    Parameters:
        data (pd.DataFrame): rows to split.
        targets (pd.Series): class labels used for stratification; expected
            to have one entry per row of ``data``.
        train_size (float): fraction of rows assigned to the train split.
        random_state (int): seed passed to both shuffles.

    Returns:
        train_index (np.array)
        test_index (np.array)
        val_index (np.array)
    """
    # First pass: carve the train rows off from the combined test+validation rows.
    outer_split = StratifiedShuffleSplit(
        n_splits=1, train_size=train_size, random_state=random_state
    )
    train_index, test_valid_index = next(outer_split.split(data, targets))

    # Second pass: halve the hold-out rows, still stratified on the labels.
    inner_split = StratifiedShuffleSplit(
        n_splits=1, test_size=0.5, random_state=random_state
    )
    test_pos, val_pos = next(
        inner_split.split(
            data.iloc[test_valid_index], targets.iloc[test_valid_index]
        )
    )

    # split() reports positions relative to the hold-out subset; translate
    # them back to positions in the full data so that all three results live
    # in the same (positional) index space as train_index.
    test_index = test_valid_index[test_pos]
    val_index = test_valid_index[val_pos]
    return train_index, test_index, val_index
# Example usage: `data` and `targets` are expected to be defined by the caller
# before this line runs (they are not created anywhere in this snippet).
train_index, test_index, val_index = get_stf_ttv(
    data,
    targets,
    random_state=555,
)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment