Skip to content

Instantly share code, notes, and snippets.

@himanshurawlani
Last active March 21, 2019 22:09
Show Gist options
  • Save himanshurawlani/95ec9abbcf0c1382318d6162a42e2d8d to your computer and use it in GitHub Desktop.
Save himanshurawlani/95ec9abbcf0c1382318d6162a42e2d8d to your computer and use it in GitHub Desktop.
This gist shows two ways to fetch a batch of data loaded with `tfds.load()` and plot its first nine images with their labels.
# Approach 1: iterate the tf.data.Dataset eagerly and convert each tensor
# with `.numpy()` before handing it to matplotlib.
# NOTE(review): assumes `train` is a batched dataset whose elements are
# (images, labels) pairs, and that `get_label_name` (defined elsewhere) maps an
# integer label to a display string — confirm against the loading code.
plt.figure(figsize=(12, 12))
# take(1) yields only the first batch, so the loop body runs exactly once.
for batch in train.take(1):
    images, labels = batch[0], batch[1]
    # Plot at most 9 examples on a 3x3 grid; a short (e.g. final) batch may
    # hold fewer, and indexing past its end would raise.
    for i in range(min(9, len(images))):
        plt.subplot(3, 3, i + 1)
        plt.imshow(images[i].numpy())
        plt.title(get_label_name(labels[i].numpy()))
        plt.grid(False)
# OR
# Approach 2: let tfds.as_numpy convert the dataset to NumPy arrays, so no
# per-tensor `.numpy()` calls are needed.
# NOTE(review): assumes `train` is a batched dataset whose elements are
# (images, labels) pairs, and that `get_label_name` (defined elsewhere) maps an
# integer label to a display string — confirm against the loading code.
# Start a fresh figure so these subplots do not draw over any earlier figure.
plt.figure(figsize=(12, 12))
# take(1) restricts iteration to the first batch; without it the loop would
# walk every batch in the training set (previously prevented by a manual break).
for batch in tfds.as_numpy(train.take(1)):
    images, labels = batch[0], batch[1]
    # Plot at most 9 examples on a 3x3 grid; a short (e.g. final) batch may
    # hold fewer, and indexing past its end would raise IndexError.
    for i in range(min(9, len(images))):
        plt.subplot(3, 3, i + 1)
        plt.imshow(images[i])
        plt.title(get_label_name(labels[i]))
        plt.grid(False)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment