Last active
May 9, 2021 04:01
-
-
Save wassname/13af904e117fdec775446fedb559c57d to your computer and use it in GitHub Desktop.
split_by_unique_col
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import pandas as pd
from sklearn.model_selection import train_test_split
def shuffle_df(df, random_seed=42):
    """
    Return every row of `df` in a reproducible random order.

    Args:
        df: dataframe to shuffle; it is not modified.
        random_seed: seed passed to `DataFrame.sample` as `random_state`,
            so the same seed always yields the same ordering.

    Returns:
        A new dataframe containing all rows of `df` exactly once, each row
        keeping its original index label.
    """
    # frac=1 without replacement draws each row once, i.e. a permutation.
    return df.sample(frac=1, random_state=random_seed, replace=False)
def split_by_unique_col(df, col='patient_id', stratify_cols=None, random_seed=42):
    """
    Split `df` into train/valid/test so that no `col` value spans two splits.

    Make a dataframe of unique ids, with our stratification data, split
    those ids (40% held out, then the held-out ids halved into valid and
    test), and finally send every row of `df` to the split its id landed in.
    The proportions apply to unique ids, not to rows.

    url: https://gist.github.com/wassname/13af904e117fdec775446fedb559c57d

    Args:
        df: dataframe containing `col` and every column in `stratify_cols`.
        col: name of the id column that must not be shared between splits.
        stratify_cols: column name, or list of column names, to stratify the
            id split on. `None` or an empty list disables stratification.
        random_seed: seed used for both id splits and for the final shuffles.

    Returns:
        Tuple `(train, valid, test)` of row-shuffled dataframes.
    """
    # Normalise without ever sharing or mutating a default/caller-owned list.
    if stratify_cols is None:
        stratify_cols = []
    elif isinstance(stratify_cols, str):
        stratify_cols = [stratify_cols]
    else:
        stratify_cols = list(stratify_cols)

    # One row per unique id (index = id), carrying the first value seen for
    # each stratification column.
    # NOTE(review): rows whose `col` is missing presumably end up in none of
    # the splits, since they never become a group key -- confirm if relevant.
    df_ids = df[[col] + stratify_cols].groupby(col).first()

    # split up the unique ids, stratifying
    df_ids_train, df_ids_other = train_test_split(
        df_ids,
        test_size=0.4,
        random_state=random_seed,
        stratify=df_ids[stratify_cols] if stratify_cols else None,
    )
    df_ids_vals, df_ids_test = train_test_split(
        df_ids_other,
        test_size=0.5,
        random_state=random_seed,
        stratify=df_ids_other[stratify_cols] if stratify_cols else None,
    )

    # Map the id-level split back onto the full set of rows.
    train = df[df[col].isin(df_ids_train.index)]
    valid = df[df[col].isin(df_ids_vals.index)]
    test = df[df[col].isin(df_ids_test.index)]

    # make sure there is no overlap
    train_ids = set(train[col])
    valid_ids = set(valid[col])
    test_ids = set(test[col])
    assert train_ids.isdisjoint(test_ids), "train and test share ids"
    assert train_ids.isdisjoint(valid_ids), "train and valid share ids"
    assert test_ids.isdisjoint(valid_ids), "test and valid share ids"

    # Shuffle so rows belonging to the same id are no longer contiguous.
    train = shuffle_df(train, random_seed=random_seed)
    test = shuffle_df(test, random_seed=random_seed)
    valid = shuffle_df(valid, random_seed=random_seed)
    return train, valid, test
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment