Created June 1, 2021 22:12
Gist: angeligareta/93a0f0a13d0c1de0655ca339a1b7f6d9
Method to split a Pandas dataset into train, validation and test splits
import numpy as np


def get_dataset_partitions_pd(df, train_split=0.8, val_split=0.1, test_split=0.1):
    assert (train_split + test_split + val_split) == 1
    # Only allows for equal validation and test splits
    assert val_split == test_split
    # Specify seed to always have the same split distribution between runs
    df_sample = df.sample(frac=1, random_state=12)
    # Row positions at which np.split cuts df_sample into three DataFrames
    indices_or_sections = [int(train_split * len(df)), int((1 - val_split - test_split) * len(df))]
    train_ds, val_ds, test_ds = np.split(df_sample, indices_or_sections)
    return train_ds, val_ds, test_ds
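np.split treats its second argument as a list of cut points: with indices [a, b] it returns the pieces [:a], [a:b] and [b:]. A minimal illustration on a 10-element NumPy array (the array is example data standing in for shuffled rows, not something from the gist):

```python
import numpy as np

rows = np.arange(10)  # stand-in for 10 shuffled rows

# Two distinct cut points give three non-empty pieces: [:8], [8:9], [9:]
train, val, test = np.split(rows, [8, 9])
print(train, val, test)  # [0 1 2 3 4 5 6 7] [8] [9]

# Repeating a cut point leaves the middle piece empty: [:8], [8:8], [8:]
train, val, test = np.split(rows, [8, 8])
print(len(train), len(val), len(test))  # 8 0 2
```

So the size of the middle (validation) piece is simply the difference between the two cut points.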
Hi Angel,
I think there is a little mistake in your code:
Using the default values, the expression
[int(train_split * len(df)), int((1 - val_split - test_split) * len(df))]
evaluates to
[int(len(df) * 0.8), int(len(df) * (1 - 0.1 - 0.1))],
which is [int(len(df) * 0.8), int(len(df) * 0.8)],
so val_ds would not contain any data. I think the correct line would be:
indices_or_sections = [int(train_split * len(df)), int((1 - test_split) * len(df))]
Please correct me if I'm wrong. Thank you for your work; I came here by following this site.
Cheers, Marcell
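A minimal sketch of the function with the cut points from the corrected line above, [train_split * n, (1 - test_split) * n], assuming pandas is available. The name split_train_val_test and the seed parameter are hypothetical. It slices with DataFrame.iloc instead of np.split (a swapped-in choice: np.split on a DataFrame goes through DataFrame.swapaxes, which pandas 2.1 deprecated). Because the second cut point depends only on test_split, the validation size is val_split * n on its own, so this sketch drops the equal-splits assertion and uses a tolerance for the sum check.

```python
import pandas as pd


def split_train_val_test(df: pd.DataFrame, train_split: float = 0.8,
                         val_split: float = 0.1, test_split: float = 0.1,
                         seed: int = 12):
    """Shuffle df and return (train, val, test) DataFrames (hypothetical helper)."""
    # Allow for floating-point rounding when checking that the fractions sum to 1
    if abs(train_split + val_split + test_split - 1.0) > 1e-9:
        raise ValueError("train_split + val_split + test_split must equal 1")
    # Shuffle all rows reproducibly
    shuffled = df.sample(frac=1, random_state=seed)
    n = len(shuffled)
    train_end = int(train_split * n)     # first cut point
    val_end = int((1 - test_split) * n)  # second cut point
    # Positional slices: [:train_end], [train_end:val_end], [val_end:]
    train = shuffled.iloc[:train_end]
    val = shuffled.iloc[train_end:val_end]
    test = shuffled.iloc[val_end:]
    return train, val, test


df = pd.DataFrame({"x": range(100)})
train, val, test = split_train_val_test(df)
print(len(train), len(val), len(test))  # 80 10 10
```

With 100 rows and the default fractions, the cut points are 80 and 90, so each row lands in exactly one of the three partitions.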