Created
June 8, 2018 15:42
-
-
Save mrajchl/3f779bb857dd8daecdac182ad1029fb8 to your computer and use it in GitHub Desktop.
Load data from a TFRecords database and feed into a graph
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def decode(serialized_example):
    """Parse one serialized TFRecord example into an (image, label) pair.

    NOTE: the image shape in the feature spec below must match the
    dimensions of the images that were written to the TFRecord file.
    """
    feature_spec = {
        'train/image': tf.FixedLenFeature([128, 224, 224, 1], tf.float32),
        'train/label': tf.FixedLenFeature([], tf.int64),
    }
    parsed = tf.parse_single_example(serialized_example, features=feature_spec)
    # The image is parsed as `tf.float32` already, so no cast is applied.
    image = parsed['train/image']
    label = parsed['train/label']
    return image, label
# Input pipeline: decode each record, repeat indefinitely, group into
# batches, and prefetch one batch ahead of the consumer.
dataset = tf.data.TFRecordDataset(train_filename).map(decode)
dataset = dataset.repeat()
dataset = dataset.batch(batch_size)
dataset = dataset.prefetch(1)

iterator = dataset.make_initializable_iterator()
# Call get_next() exactly once: every call adds another op to the graph,
# so the resulting tensors are created here and re-run inside the loop.
features, labels = iterator.get_next()

with tf.train.MonitoredTrainingSession() as sess_rec:
    # An initializable iterator must be initialized before its first use.
    sess_rec.run(iterator.initializer)
    for i in range(iterations):
        try:
            # Get next features-labels pair
            rec_batch_feat, rec_batch_lbl = sess_rec.run([features, labels])
        except tf.errors.OutOfRangeError:
            # The iterator is exhausted; further runs would only raise
            # again, so stop instead of silently spinning through the
            # remaining iterations.
            break
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment