Created December 21, 2017 06:24
-
-
Save formigone/439a05d878222e104a17856b46dfbdf5 to your computer and use it in GitHub Desktop.
Creating and using TensorFlow TFRecords
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def list_to_tfrecord(list, tfrecord_filename):
    """
    Convert a list of (features, label) pairs to a TFRecord file.

    Each pair is written as one ``tf.train.Example`` with two fields:
    ``'x'`` (a float list holding the features) and ``'y'`` (a single
    int64 holding the label).

    Requires TensorFlow 1.x to be importable as ``tf`` at module level
    (the import is not visible in this snippet -- confirm).

    :param list: iterable of ``(features, label)`` tuples, where
        ``features`` is a flat sequence of floats and ``label`` an integer.
    :param tfrecord_filename: path of the TFRecord file to write.
    """
    # ``list`` shadows the builtin; the parameter name is kept so keyword
    # callers (``list=...``) keep working, but the body uses an alias.
    records = list
    with tf.python_io.TFRecordWriter(tfrecord_filename) as writer:
        for features, label in records:
            example = tf.train.Example()
            example.features.feature['x'].float_list.value.extend(features)
            example.features.feature['y'].int64_list.value.append(label)
            writer.write(example.SerializeToString())
def gen_input_fn(tfrecord, epochs=1, batch_size=16, buffer_size=64,
                 feature_shape=(299 * 299,), label_shape=()):
    """
    Return an input_fn that reads TFRecords for TensorFlow's Estimator API.

    Records are expected to carry an ``'x'`` float32 feature and a ``'y'``
    int64 label.

    :param tfrecord: path (or list of paths) of the TFRecord file(s) to read.
    :param epochs: number of passes over the data (``Dataset.repeat``).
    :param batch_size: number of examples per batch.
    :param buffer_size: shuffle buffer size; shuffling is skipped when <= 0.
    :param feature_shape: fixed shape of the ``'x'`` feature in each record.
    :param label_shape: fixed shape of the ``'y'`` label in each record.
    :return: a zero-argument function returning ``(features, label)``
        tensors for one batch.
    """
    def parse(example_proto):
        """Decode one serialized ``tf.train.Example`` into ``(x, y)``."""
        features = {
            'x': tf.FixedLenFeature(feature_shape, tf.float32),
            'y': tf.FixedLenFeature(label_shape, tf.int64),
        }
        parsed_features = tf.parse_single_example(example_proto, features)
        return parsed_features['x'], parsed_features['y']

    def input_fn():
        # Read the caller-supplied file(s) rather than a hard-coded path.
        # tf.data.TFRecordDataset supersedes tf.contrib.data.TFRecordDataset.
        dataset = tf.data.TFRecordDataset(tfrecord)
        dataset = dataset.map(parse)
        if buffer_size > 0:
            dataset = dataset.shuffle(buffer_size)
        dataset = dataset.repeat(epochs)
        dataset = dataset.batch(batch_size)
        features, label = dataset.make_one_shot_iterator().get_next()
        return features, label

    return input_fn
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.
There are some errors in your code, such as `python_io` and `example_proto`.
Also, please add a README file explaining how to use your Python file.