Skip to content

Instantly share code, notes, and snippets.

@cozek
Created March 24, 2020 09:16
Show Gist options
  • Save cozek/519df210dd59dc9234ee67c68040e6af to your computer and use it in GitHub Desktop.
Save cozek/519df210dd59dc9234ee67c68040e6af to your computer and use it in GitHub Desktop.
Split dataframe into train and validation
def split_dataframe(df: pd.DataFrame, train_frac: float, shuffle: bool,
                    random_state=None) -> pd.DataFrame:
    """
    Split a DataFrame into stratified train and validation parts.

    Every distinct value of the ``label`` column is split separately, so
    each label contributes ``train_frac`` of its rows to the training part
    and the remaining rows to the validation part.

    Args:
        df: DataFrame to split; must contain a ``label`` column. The input
            is not modified. In the result the index holds each row's
            position in the input (0..len(df)-1).
        train_frac: fraction of each label's rows to use for training;
            must not exceed 1.0.
        shuffle: shuffles the rows before splitting if true.
        random_state: optional seed (or random state object) forwarded to
            ``DataFrame.sample``; pass a value for a reproducible split.
    Returns:
        split_df: a new DataFrame with an added ``split`` column whose
            value is ``'train'`` or ``'val'``. Rows are grouped by label,
            train rows before val rows within each label. Rows whose label
            is missing (NaN) are not included.
    Raises:
        AssertionError: if ``train_frac`` exceeds 1.0, or if
            ``train_frac < 1.0`` and the ``label`` column holds exactly one
            distinct value.
    """
    assert train_frac <= 1.0, "train_frac must not exceed 1.0"

    # reset_index returns a new frame, so the caller's DataFrame and its
    # index are left untouched.
    df = df.reset_index(drop=True)

    if train_frac == 1.0:
        # Nothing to hold out: every row belongs to the training part.
        df['split'] = 'train'
        return df

    assert df['label'].nunique(dropna=False) != 1, \
        "need more than one distinct label to split"

    if shuffle:
        df = df.sample(frac=1, random_state=random_state)

    # Collect the per-label pieces and concatenate once at the end.
    parts = []
    for _, group in df.groupby('label', sort=False):
        train_part = group.sample(frac=train_frac, random_state=random_state)
        # The index is unique after the reset above, so dropping by index
        # leaves exactly the rows that were not sampled for training.
        val_part = group.drop(train_part.index)
        parts.append(train_part.assign(split='train'))
        parts.append(val_part.assign(split='val'))

    if not parts:
        # Empty input (or no labelled rows): return an empty frame that
        # still carries the 'split' column.
        return df.iloc[:0].assign(split='train')

    split_df = pd.concat(parts)

    # Sanity check: every labelled row ended up in exactly one split.
    assert len(split_df) == int(df['label'].notna().sum()), \
        "train and val parts do not add up to the labelled rows"
    return split_df
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment