Skip to content

Instantly share code, notes, and snippets.

@bodokaiser
Created October 20, 2017 09:50
Show Gist options
  • Save bodokaiser/91a450af64c1dc2160c66b870eb629d6 to your computer and use it in GitHub Desktop.
Save bodokaiser/91a450af64c1dc2160c66b870eb629d6 to your computer and use it in GitHub Desktop.
import numpy as np
import tensorflow as tf
# Number of int32 elements written to the test record: 64 million values,
# i.e. 256,000,000 bytes of raw payload at 4 bytes per np.int32.
SIZE = 64 * 10 ** 6
def _bytes_feature(value):
    """Wrap a list of byte strings in a ``tf.train.Feature``.

    Args:
        value: iterable of ``bytes`` objects to store in the feature's
            ``bytes_list`` (``encode`` passes a one-element list).

    Returns:
        A ``tf.train.Feature`` whose ``bytes_list`` holds ``value``.
    """
    bytes_list = tf.train.BytesList(value=value)
    return tf.train.Feature(bytes_list=bytes_list)
def decode(example):
    """Parse one serialized Example and reinterpret its payload as int32.

    Args:
        example: scalar string tensor holding a serialized
            ``tf.train.Example`` with a bytes feature named ``'encoded'``.

    Returns:
        A 1-D ``tf.int32`` tensor decoded from the raw bytes of
        ``'encoded'``. ``tf.decode_raw`` requires the byte length to be a
        multiple of 4 for int32.
    """
    parsed = tf.parse_single_example(
        example,
        features={'encoded': tf.FixedLenFeature((), tf.string)},
    )
    raw_bytes = parsed['encoded']
    return tf.decode_raw(raw_bytes, tf.int32)
def encode(array):
    """Serialize an array as a ``tf.train.Example`` with one bytes feature.

    The values are cast to 32-bit integers and stored as raw little-endian
    bytes under the key ``'encoded'``. This is the layout that ``decode``
    reads back with ``tf.decode_raw(..., tf.int32)``.

    Args:
        array: numeric array-like (an ``np.ndarray`` or anything
            ``np.asarray`` accepts). Values are cast to int32 with NumPy's
            default (unsafe) casting, so floats are truncated.

    Returns:
        The serialized ``tf.train.Example`` protobuf as ``bytes``.
    """
    # tf.decode_raw assumes little-endian input by default, whereas
    # ndarray.tobytes() emits the host's native byte order. Pinning '<i4'
    # makes records round-trip on big-endian hosts too; on little-endian
    # hosts the output bytes are unchanged.
    raw = np.asarray(array).astype('<i4', copy=False).tobytes()
    example = tf.train.Example(features=tf.train.Features(feature={
        'encoded': _bytes_feature([raw]),
    }))
    return example.SerializeToString()
# Write a single record holding SIZE int32 ones, then read it back through a
# tf.data pipeline and print the decoded shape.
# NOTE(review): one ~256 MB Example per record looks intended to probe
# record/protobuf size limits - confirm before changing SIZE.
tfrecord_path = 'foo.tfrecord'

with tf.python_io.TFRecordWriter(tfrecord_path) as writer:
    writer.write(encode(np.ones((SIZE,), dtype=np.int32)))

dataset = tf.data.TFRecordDataset(tfrecord_path)
dataset = dataset.map(decode)
data = dataset.make_one_shot_iterator().get_next()

with tf.Session() as session:
    decoded = session.run(data)
    print(decoded.shape)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment