Forked from angeligareta/get_dataset_partitions_tf.py
Last active
May 14, 2022 18:26
-
-
Save markub3327/3ab35c8039d4c9c9eb932943aa36b24f to your computer and use it in GitHub Desktop.
Method to split a TensorFlow dataset (`tf.data.Dataset`) into train, validation, and test splits.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def get_dataset_partitions_tf(ds, ds_size, train_split=0.7, val_split=0.15, test_split=0.15, shuffle=True, shuffle_size=10000, batch_size=32):
    """Split a dataset into batched train, validation and test partitions.

    Args:
        ds: Dataset object exposing ``shuffle``, ``take``, ``skip`` and
            ``batch`` (presumably a ``tf.data.Dataset``, per the function
            name).
        ds_size: Number of elements in ``ds``. It is used only to compute the
            partition sizes, so an inaccurate value skews the split.
        train_split: Fraction of ``ds_size`` assigned to the training set.
        val_split: Fraction of ``ds_size`` assigned to the validation set.
        test_split: Fraction nominally assigned to the test set. It is only
            used to validate that the fractions sum to 1; the test set
            actually receives every element left after the train and
            validation elements are skipped.
        shuffle: If true, shuffle ``ds`` once with a fixed seed before
            splitting.
        shuffle_size: Buffer size for that initial shuffle.
        batch_size: Batch size applied to each partition. Each partition is
            also shuffled with a buffer of ``batch_size * 8`` before batching.

    Returns:
        A ``(train_ds, val_ds, test_ds)`` tuple of shuffled, batched datasets.

    Raises:
        AssertionError: If the three split fractions do not sum to 1 (within
            floating-point tolerance).
    """
    import math

    # Compare with a tolerance: exact equality rejects valid fractions such as
    # 0.7 + 0.2 + 0.1, which evaluates to 0.9999999999999999 in binary floats.
    assert math.isclose(train_split + test_split + val_split, 1.0), (
        "train_split, val_split and test_split must sum to 1"
    )

    if shuffle:
        # Specify seed to always have the same split distribution between runs
        ds = ds.shuffle(shuffle_size, seed=12, reshuffle_each_iteration=False)

    train_size = int(train_split * ds_size)
    val_size = int(val_split * ds_size)

    # The per-partition shuffles below are unseeded, so only partition
    # membership (not element order within a partition) is pinned by the seed
    # used above.
    partition_buffer = batch_size * 8

    train_ds = ds.take(train_size).shuffle(buffer_size=partition_buffer).batch(batch_size)
    val_ds = ds.skip(train_size).take(val_size).shuffle(buffer_size=partition_buffer).batch(batch_size)
    # Whatever remains after train and validation goes to the test partition.
    test_ds = ds.skip(train_size).skip(val_size).shuffle(buffer_size=partition_buffer).batch(batch_size)

    return train_ds, val_ds, test_ds
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
@angeligareta
What do you think about this solution?