Skip to content

Instantly share code, notes, and snippets.

@joyhuang9473
Created January 17, 2018 06:51
Show Gist options
  • Save joyhuang9473/b5aecd403bf18ccd77c2cee4a186fdde to your computer and use it in GitHub Desktop.
Save joyhuang9473/b5aecd403bf18ccd77c2cee4a186fdde to your computer and use it in GitHub Desktop.
import tensorflow as tf
import cv2
def int64_feature(value):
    """Wrap one integer in a tf.train.Feature holding a single-element Int64List."""
    int64_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int64_list)
def int64_list_feature(value):
    """Wrap an iterable of integers in a tf.train.Feature holding an Int64List."""
    int64_list = tf.train.Int64List(value=value)
    return tf.train.Feature(int64_list=int64_list)
def bytes_feature(value):
    """Wrap a single bytes value in a tf.train.Feature holding a BytesList.

    Args:
        value: A bytes object, or a text string which is UTF-8 encoded first.

    Returns:
        A tf.train.Feature whose bytes_list contains exactly one element.
    """
    # On Python 3, BytesList rejects str with a TypeError, so encode text
    # here. (On Python 2, str *is* bytes and is passed through unchanged.)
    if isinstance(value, str) and not isinstance(value, bytes):
        value = value.encode('utf-8')
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
def create_tf_example(filename, image_data, label):
    """Build a tf.train.Example holding a raw (undecoded) pixel array.

    Args:
        filename: Source file name, stored under 'image/filename'. A text
            string is UTF-8 encoded, since BytesList only accepts bytes on
            Python 3.
        image_data: Image ndarray of shape (height, width, channel), e.g. as
            returned by cv2.imread. A 2-D (height, width) array is treated as
            a single-channel image.
        label: Integer class label, stored under 'image/class/label'.

    Returns:
        A tf.train.Example whose 'image/encoded' feature is the row-major
        flattened pixel values as an Int64List. Height, width and channel are
        stored alongside, so the reader can reshape the flat list back into an
        image.
    """
    if image_data.ndim == 2:
        # Grayscale arrays have no channel axis. Record them as 1 channel so
        # the (height, width, channel) reshape on the read side still works.
        height, width = image_data.shape
        channel = 1
    else:
        (height, width, channel) = image_data.shape
    # Flatten to a 1-D list of Python ints. The previous (N, 1) reshape handed
    # Int64List N one-element arrays rather than N integers, which relied on
    # protobuf coercing them.
    pixels = image_data.reshape(-1).tolist()
    if isinstance(filename, str) and not isinstance(filename, bytes):
        # BytesList rejects str on Python 3.
        filename = filename.encode('utf-8')
    tf_example = tf.train.Example(features=tf.train.Features(feature={
        'image/height': int64_feature(height),
        'image/width': int64_feature(width),
        'image/channel': int64_feature(channel),
        'image/filename': bytes_feature(filename),
        'image/encoded': int64_list_feature(pixels),
        'image/class/label': int64_feature(label),
    }))
    return tf_example
def parse_tf_example(serialized_example):
    """Decode one serialized Example in the layout written by create_tf_example.

    Args:
        serialized_example: Scalar string tensor holding a serialized
            tf.train.Example (e.g. one element of a TFRecordDataset).

    Returns:
        A dict with 'image', the int64 pixel tensor reshaped to
        [height, width, channel], and 'label', the int64 class label.
        The stored filename is parsed but not returned.
    """
    feature_spec = {
        'image/height': tf.FixedLenFeature([], tf.int64),
        'image/width': tf.FixedLenFeature([], tf.int64),
        'image/channel': tf.FixedLenFeature([], tf.int64),
        'image/filename': tf.FixedLenFeature([], tf.string),
        # The pixel count varies per image, so the list is read as sparse.
        'image/encoded': tf.VarLenFeature(tf.int64),
        'image/class/label': tf.FixedLenFeature([], tf.int64),
    }
    parsed = tf.parse_single_example(serialized_example, features=feature_spec)

    dims = []
    for key in ('image/height', 'image/width', 'image/channel'):
        dims.append(tf.cast(parsed[key], tf.int64))

    flat_pixels = tf.sparse_tensor_to_dense(parsed['image/encoded'], default_value=0)
    return {
        'image': tf.reshape(flat_pixels, tf.stack(dims)),
        'label': tf.cast(parsed['image/class/label'], tf.int64),
    }
if __name__ == '__main__':
    # --- Create tfrecord ---
    filename = 'test.jpg'
    # Flag 1 (cv2.IMREAD_COLOR) always decodes to a 3-channel image.
    image_data = cv2.imread(filename, 1)
    if image_data is None:
        # cv2.imread reports a missing or unreadable file by returning None
        # rather than raising. Without this check, the failure would surface
        # later as an AttributeError on `.shape`.
        raise SystemExit("could not read image: {}".format(filename))
    label = 1

    # Write the image ndarray to a tfrecord. The context manager closes the
    # writer even if building or serializing the example raises.
    with tf.python_io.TFRecordWriter('test.tfrecord') as writer:
        example = create_tf_example(filename, image_data, label)
        writer.write(example.SerializeToString())

    # --- Read tfrecord ---
    dataset = tf.data.TFRecordDataset('test.tfrecord')
    dataset = dataset.map(parse_tf_example)
    iterator = dataset.make_initializable_iterator()
    data = iterator.get_next()
    with tf.Session() as sess:
        sess.run(iterator.initializer)
        example = sess.run(data)
        print(example['image'])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment