@angeligareta
Created June 1, 2021 22:12
Method to split a Pandas dataset into train, validation and test splits
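import math

import numpy as np

def get_dataset_partitions_pd(df, train_split=0.8, val_split=0.1, test_split=0.1):
    # Compare with a tolerance, since a sum of floats may not equal 1 exactly
    assert math.isclose(train_split + val_split + test_split, 1.0)
    # Only allows for equal validation and test splits
    assert val_split == test_split
    # Specify seed to always have the same split distribution between runs
    df_sample = df.sample(frac=1, random_state=12)
    # Cut points: end of train, then end of validation (train + val).
    # Using 1 - val_split - test_split for the second cut would equal train_split and leave val_ds empty.
    indices_or_sections = [int(train_split * len(df)), int((1 - test_split) * len(df))]
    train_ds, val_ds, test_ds = np.split(df_sample, indices_or_sections)
    return train_ds, val_ds, test_ds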
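
One way to sanity-check a split like this is to compare the partition sizes against the requested fractions. The sketch below is not the gist author's code: the helper name `split_by_fractions` and the 100-row toy frame are hypothetical, it slices by position with `iloc` rather than calling `np.split`, and it assumes the second cut belongs at train + val.

```python
import pandas as pd


def split_by_fractions(df, train_split=0.8, val_split=0.1, seed=12):
    """Hypothetical helper: shuffle rows, then slice into train/val/test by position."""
    shuffled = df.sample(frac=1, random_state=seed)
    train_end = int(train_split * len(df))
    val_end = int((train_split + val_split) * len(df))
    # Whatever remains after the validation slice becomes the test split
    return (
        shuffled.iloc[:train_end],
        shuffled.iloc[train_end:val_end],
        shuffled.iloc[val_end:],
    )


# Toy data, for illustration only
df = pd.DataFrame({"x": range(100)})
train_ds, val_ds, test_ds = split_by_fractions(df)
sizes = (len(train_ds), len(val_ds), len(test_ds))
print(sizes)
```

Checking that the three sizes add up to `len(df)` and that no partition is empty is a cheap guard against misplaced cut points.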

588chm commented Dec 13, 2023

Hi Angel,
I think there is a little mistake in your code:
Using the default values, the expression [int(train_split * len(df)), int((1 - val_split - test_split) * len(df))]
evaluates to [len(df)*0.8, len(df)*(1-0.1-0.1)], which is [len(df)*0.8, len(df)*0.8], so val_ds would not have any data in it.
I think the correct line would be:
indices_or_sections = [int(train_split * len(df)), int((1 - test_split) * len(df))]
Please correct me if I'm wrong. Thank you for your work; I came here by following this site.
Cheers, Marcell
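
The arithmetic in the comment above can be checked without pandas. This is a minimal sketch, assuming a dataset of 100 rows; the row count, the variable names, and the helper `segment_sizes` are illustrative and do not come from the gist.

```python
n_rows = 100  # assumed dataset size, for illustration only
train_split, val_split, test_split = 0.8, 0.1, 0.1

# Cut points from the expression quoted in the comment: both land on the same index
original_cuts = [int(train_split * n_rows), int((1 - val_split - test_split) * n_rows)]

# Cut points with the proposed correction: the second cut moves to train + val
fixed_cuts = [int(train_split * n_rows), int((1 - test_split) * n_rows)]


def segment_sizes(cuts, total):
    """Sizes of the three segments produced by cutting range(total) at `cuts`."""
    first, second = cuts
    return first, second - first, total - second


print(segment_sizes(original_cuts, n_rows))
print(segment_sizes(fixed_cuts, n_rows))
```

When the two cut points coincide, the middle segment has size zero, which is the empty val_ds described in the comment.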
