Skip to content

Instantly share code, notes, and snippets.

@dsalaj
Created March 23, 2020 18:05
Show Gist options
  • Save dsalaj/c29bc12c0c1b0c70b638c5d40fcbaac8 to your computer and use it in GitHub Desktop.
Save dsalaj/c29bc12c0c1b0c70b638c5d40fcbaac8 to your computer and use it in GitHub Desktop.
Different ways of splitting a TensorFlow dataset
def split_dataset(ds, version=1, *, num_shards=4, val_every=20, num_val=1000):
    """Split ``ds.dataset`` into a training part and a validation part.

    ``ds`` is assumed to be a holder exposing a ``dataset`` attribute
    (presumably a ``tf.data.Dataset`` -- TODO confirm against callers) and,
    for ``version=3``, an ``n_samples`` attribute with the element count.

    Args:
        ds: Holder object with ``.dataset`` (and ``.n_samples`` for version 3).
        version: Splitting strategy:
            1 -- split into ``num_shards`` interleaved shards; the last shard
                 is validation and the others are concatenated for training.
                 The training elements come out reordered (all of shard 0,
                 then all of shard 1, ...).
            2 -- every ``val_every``-th element by position (starting at
                 position 0) goes to validation, all others to training.
            3 -- the last ``num_val`` elements go to validation, the rest to
                 training. If there are fewer than ``num_val`` elements, the
                 training split is empty.
        num_shards: Number of shards for version 1 (must be >= 2).
        val_every: Validation stride for version 2 (must be >= 1).
        num_val: Number of trailing validation elements for version 3.

    Returns:
        A ``(train_ds, valid_ds)`` tuple.

    Raises:
        ValueError: If ``version`` is not 1, 2 or 3, or if the strategy
            parameter used by that version is out of range.
    """
    if version not in (1, 2, 3):
        raise ValueError(f"unknown split version: {version!r}")
    data = ds.dataset

    if version == 1:
        if num_shards < 2:
            raise ValueError(f"num_shards must be >= 2, got {num_shards}")
        last = num_shards - 1
        train_ds = data.shard(num_shards=num_shards, index=0)
        for index in range(1, last):
            # concatenate() returns a new dataset instead of modifying the
            # receiver, so the result must be re-assigned.
            train_ds = train_ds.concatenate(
                data.shard(num_shards=num_shards, index=index))
        valid_ds = data.shard(num_shards=num_shards, index=last)
        return train_ds, valid_ds

    if version == 2:
        if val_every < 1:
            raise ValueError(f"val_every must be >= 1, got {val_every}")

        # enumerate() yields (position, element) pairs, so filter/map
        # callbacks receive the position first and the original element second.
        def is_val(index, element):
            return index % val_every == 0

        def is_train(index, element):
            # Direct comparison rather than `not is_val(...)`, so this does not
            # rely on AutoGraph rewriting Python's `not` on a symbolic tensor.
            return index % val_every != 0

        def recover(index, element):
            # Drop the position added by enumerate().
            return element

        indexed = data.enumerate()
        train_ds = indexed.filter(is_train).map(recover)
        valid_ds = indexed.filter(is_val).map(recover)
        return train_ds, valid_ds

    # version == 3
    # Clamp at 0: negative counts have special meaning in tf.data (e.g.
    # take(-1) keeps every element, skip(-1) drops every element).
    num_train = max(ds.n_samples - num_val, 0)
    train_ds = data.take(num_train)
    valid_ds = data.skip(num_train)
    return train_ds, valid_ds
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment