Skip to content

Instantly share code, notes, and snippets.

@mrajchl
Created June 8, 2018 15:42
Show Gist options
  • Save mrajchl/3f779bb857dd8daecdac182ad1029fb8 to your computer and use it in GitHub Desktop.
Save mrajchl/3f779bb857dd8daecdac182ad1029fb8 to your computer and use it in GitHub Desktop.
Load data from a TFRecord file and feed it into a TensorFlow graph
def decode(serialized_example, image_shape=(128, 224, 224, 1)):
    """Parse one serialized TFRecord example into an (image, label) pair.

    Args:
        serialized_example: Scalar string tensor holding one serialized
            example, e.g. an element of a `tf.data.TFRecordDataset`.
        image_shape: Shape of the value stored under 'train/image'. It must
            match the dimensions the record was written with; the default
            keeps the original [128, 224, 224, 1].

    Returns:
        A tuple `(image, label)`: the 'train/image' feature parsed as
        `tf.float32` with shape `image_shape`, and the scalar 'train/label'
        feature parsed as `tf.int64`.
    """
    feature_spec = {
        'train/image': tf.FixedLenFeature(list(image_shape), tf.float32),
        'train/label': tf.FixedLenFeature([], tf.int64),
    }
    features = tf.parse_single_example(serialized_example,
                                       features=feature_spec)
    # NOTE: No cast is needed; the parsed tensors already have the dtypes
    # declared in the feature spec (float32 image, int64 label).
    return features['train/image'], features['train/label']
# Input pipeline: read serialized examples from the TFRecord file, decode
# each one, then repeat, batch and prefetch.
dataset = tf.data.TFRecordDataset(train_filename).map(decode)
dataset = dataset.repeat(None)  # count=None: repeat indefinitely
dataset = dataset.batch(batch_size)
dataset = dataset.prefetch(1)  # buffer one batch ahead of the consumer

iterator = dataset.make_initializable_iterator()
# Call get_next() once and reuse its tensors: every call adds another op to
# the graph, so an extra, never-run call is just dead weight.
features, labels = iterator.get_next()
with tf.train.MonitoredTrainingSession() as sess_rec:
    # The iterator is initializable, so run its initializer before the
    # first fetch of `features` / `labels`.
    sess_rec.run(iterator.initializer)
    for i in range(iterations):
        try:
            # Get next features-labels pair
            rec_batch_feat, rec_batch_lbl = sess_rec.run([features, labels])
        except tf.errors.OutOfRangeError:
            # The input is exhausted; any further run() would raise the same
            # error again, so stop instead of silently spinning through the
            # remaining iterations.
            break
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment