Last active: October 11, 2018 15:23
-
-
Save rsepassi/68a443713ce58d07055dd3a76bb19cc9 to your computer and use it in GitHub Desktop.
tensorflow/datasets
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Minimal end-to-end example of loading MNIST with tensorflow_datasets:
# list the available builders, construct the MNIST DatasetBuilder, download
# and prepare the data, then build a batched tf.data input pipeline and pull
# a single batch out of it.
import os

import tensorflow as tf
import tensorflow_datasets as tfds

# tfds works with Eager and Graph modes. This script iterates the dataset
# directly in Python (step 5), which requires eager execution. TF 2.x runs
# eagerly by default and no longer provides `tf.enable_eager_execution`,
# so only call it when it exists (TF 1.x).
if hasattr(tf, "enable_eager_execution"):
    tf.enable_eager_execution()

# 0. Select the dataset you'd like to use
print(tfds.list_builders())

# 1. Construct the DatasetBuilder
# Each dataset is implemented as a DatasetBuilder and can be fetched by
# string name. "~" is expanded explicitly in case the underlying file APIs
# do not expand it themselves; on an already-absolute path this is a no-op.
mnist_builder = tfds.builder(
    name="mnist", data_dir=os.path.expanduser("~/tfds/data"))

# 2. Download and prepare the dataset into a format ready for a tf.data pipeline
mnist_builder.download_and_prepare()

# 3. Build a tf.data.Dataset from the prepared data
train_dataset = mnist_builder.as_dataset(split=tfds.Split.TRAIN)

# 4. Build the rest of your input pipeline using the tf.data API.
# Shuffle *before* repeat so epoch boundaries are preserved: every training
# example is seen once per epoch before any example is seen again, instead
# of examples from adjacent epochs being mixed inside one shuffle buffer.
train_dataset = train_dataset.shuffle(1024).repeat().batch(32).prefetch(100)

# 5. Pull a single batch. `take(1)` yields exactly one element, so the
# one-element unpacking below binds it directly (this Python-level iteration
# over a Dataset is what requires eager mode).
# The batch is a features dictionary with keys "input" and "target".
# NOTE(review): released tfds versions name the MNIST features "image" and
# "label"; confirm the keys against the installed tfds version.
features, = train_dataset.take(1)
images, labels = features["input"], features["target"]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.