Skip to content

Instantly share code, notes, and snippets.

@securetorobert
Created September 28, 2019 17:23
Show Gist options
  • Save securetorobert/f18c18af38e968c139dc20b93805dd37 to your computer and use it in GitHub Desktop.
Save securetorobert/f18c18af38e968c139dc20b93805dd37 to your computer and use it in GitHub Desktop.
Create Image input pipeline
def preprocess_image(image):
    """Decode, resize and rescale a JPEG-encoded image.

    Args:
        image: Encoded JPEG contents, as accepted by ``tf.image.decode_jpeg``.

    Returns:
        The decoded image with ``NUM_CHANNELS`` channels, resized to
        ``[HEIGHT, WIDTH]`` and divided by 255.0.
    """
    image = tf.image.decode_jpeg(image, channels=NUM_CHANNELS)
    image = tf.image.resize(image, [HEIGHT, WIDTH])
    # Normalize to the [0, 1] range.
    return image / 255.0
def load_and_preprocess_image(path):
    """Read the file at *path* and run it through ``preprocess_image``.

    Args:
        path: Location of the image file, as accepted by ``tf.io.read_file``.

    Returns:
        Whatever ``preprocess_image`` returns for the file's raw contents.
    """
    raw_contents = tf.io.read_file(path)
    return preprocess_image(raw_contents)
# Dataset of file paths; each path is loaded and preprocessed via `map`.
path_ds = tf.data.Dataset.from_tensor_slices(files)
image_ds = path_ds.map(
    load_and_preprocess_image,
    num_parallel_calls=AUTOTUNE,
)

# Labels are cast to int64 and zipped with the images into (image, label) pairs.
label_ds = tf.data.Dataset.from_tensor_slices(tf.cast(categories, tf.int64))
image_label_ds = tf.data.Dataset.zip((image_ds, label_ds))

# Shuffle -> repeat -> batch -> prefetch, in that order.
# `prefetch` lets the dataset fetch batches in the background while the model
# is training.
ds = (
    image_label_ds
    .shuffle(buffer_size=1000 * BATCH_SIZE)
    .repeat()
    .batch(BATCH_SIZE)
    .prefetch(buffer_size=AUTOTUNE)
)
ds
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment