Skip to content

Instantly share code, notes, and snippets.

@dpoulopoulos
Created August 24, 2021 14:24
Show Gist options
  • Save dpoulopoulos/708c1405525e34747b9b518bd779dc55 to your computer and use it in GitHub Desktop.
Save dpoulopoulos/708c1405525e34747b9b518bd779dc55 to your computer and use it in GitHub Desktop.
# Training configuration for the TFRecord input pipeline.
# NOTE(review): `tf` (TensorFlow) and `tfrecords_dir` must already be defined
# earlier in the script; neither is created here.
train_filenames = tf.io.gfile.glob("{}/*.tfrec".format(tfrecords_dir))  # all .tfrec shards in the directory
batch_size, epochs, steps_per_epoch = 32, 1, 50
# Sentinel that lets tf.data choose parallelism / buffer sizes at runtime.
AUTOTUNE = tf.data.AUTOTUNE
def prepare_sample(features, image_size=(224, 224)):
    """Convert one parsed TFRecord example into an ``(image, label)`` pair.

    Args:
        features: Parsed example. It is indexed by key, so it is presumably the
            dict returned by ``parse_tfrecord_fn`` (verify against that
            function). It must contain an ``"image"`` tensor and a
            ``"category_id"`` label.
        image_size: ``(height, width)`` that the image is resized to. The
            default ``(224, 224)`` matches the previously hard-coded size, so
            existing callers such as ``dataset.map(prepare_sample)`` are
            unaffected.

    Returns:
        Tuple ``(image, category_id)``:
        - ``image`` is the result of ``tf.image.resize``. Pixel values are not
          rescaled or normalized here.
        - ``category_id`` is passed through unchanged.
    """
    # NOTE(review): tf.image.resize expects a decoded 3-D (H, W, C) or 4-D
    # batched image tensor; this assumes parse_tfrecord_fn already decoded it.
    image = tf.image.resize(features["image"], size=image_size)
    return image, features["category_id"]
def get_dataset(filenames, batch_size, shuffle_buffer_size=None):
    """Build a shuffled, batched, prefetched ``tf.data`` pipeline over TFRecords.

    Args:
        filenames: TFRecord file path(s), passed to ``tf.data.TFRecordDataset``.
        batch_size: Number of examples per batch.
        shuffle_buffer_size: Number of examples held in the shuffle buffer.
            ``None`` (the default) keeps the original ``batch_size * 10``.

    Returns:
        A ``tf.data.Dataset`` that yields batches of the ``(image, label)``
        pairs produced by ``prepare_sample``.
    """
    if shuffle_buffer_size is None:
        shuffle_buffer_size = batch_size * 10
    # NOTE(review): the pipeline is finite (no `.repeat()`). If it is consumed
    # with a fixed `steps_per_epoch`, confirm the shards hold enough examples.
    dataset = (
        # Read shards in parallel; records from different files are interleaved.
        tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTOTUNE)
        # Deserialize each record (parse_tfrecord_fn is defined elsewhere).
        .map(parse_tfrecord_fn, num_parallel_calls=AUTOTUNE)
        .map(prepare_sample, num_parallel_calls=AUTOTUNE)
        # Shuffle individual examples before batching so each batch is mixed.
        .shuffle(shuffle_buffer_size)
        .batch(batch_size)
        # Prepare upcoming batches while the current one is being consumed.
        .prefetch(AUTOTUNE)
    )
    return dataset
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment