Created
March 28, 2019 02:48
-
-
Save yu-iskw/a75c6a04945f4b8da951edab23b48305 to your computer and use it in GitHub Desktop.
TFRecord for a list of lists (variable-length integer sequences)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def _int64_feature(value):
    """Wrap integer data in a ``tf.train.Feature`` holding an ``Int64List``.

    Args:
        value: A scalar, list, or numpy array of integers, of any shape.
            It is flattened to 1-D because ``Int64List`` stores only a flat
            sequence, so any original shape is not preserved.

    Returns:
        A ``tf.train.Feature`` whose ``int64_list.value`` holds the elements
        of ``value`` in row-major (C) order.
    """
    # Local import so this helper does not depend on a module-level ``np``.
    import numpy as np

    # np.asarray is a no-op for ndarrays. It also lets plain lists and Python
    # scalars through; the previous ``value.flatten()`` required an ndarray.
    flat = np.asarray(value).ravel()
    return tf.train.Feature(int64_list=tf.train.Int64List(value=flat))
# Write variable-length integer lists to a TFRecord file, one
# tf.train.Example per list.
# NOTE(review): relies on `np` (numpy) and `tf` (TensorFlow 1.x API:
# tf.python_io) being imported at file top; the snippet shown omits them.

# `a` holds lists of different lengths, so it is kept as a plain list of
# lists. np.array() on ragged nested lists raises ValueError on NumPy >= 1.24
# ("inhomogeneous shape"). Older versions only built an object array and
# emitted a deprecation warning.
a = [[0, 54, 91, 153, 177],
     [0, 50, 89, 147, 196],
     [0, 38, 79, 157],
     [0, 49, 89, 147, 177],
     [0, 32, 73, 145]]

# The `with` block closes (and flushes) the writer even if building or
# serializing an Example raises part-way through.
with tf.python_io.TFRecordWriter('file') as writer:
    for i, row in enumerate(a):  # i = 0 ~ 4
        x_train = np.array(row)
        feature = {'i': _int64_feature(np.array([i])),
                   'data': _int64_feature(x_train)}
        # Create an example protocol buffer.
        example = tf.train.Example(features=tf.train.Features(feature=feature))
        # Serialize to a byte string and append it to the file.
        writer.write(example.SerializeToString())

# Check the TFRecord file by decoding every record back into an Example.
record_iterator = tf.python_io.tf_record_iterator(path='file')
for string_record in record_iterator:
    example = tf.train.Example()
    example.ParseFromString(string_record)
    i = example.features.feature['i'].int64_list.value
    data = example.features.feature['data'].int64_list.value
    print(i, data)
# Read the TFRecord file back with the tf.data Dataset API; each dataset
# element is one serialized record (a scalar string tensor).
filenames = ['file']
dataset = tf.data.TFRecordDataset(filenames=filenames)
def _parse_function(example_proto):
    """Decode one serialized ``tf.train.Example`` into dense int64 tensors.

    Both 'i' and 'data' are stored as Int64List features whose length is not
    fixed, so they are parsed as ``VarLenFeature`` (yielding SparseTensors)
    and then converted to dense tensors.

    Args:
        example_proto: A scalar string tensor holding one serialized Example.

    Returns:
        Tuple ``(i, data)`` of dense int64 tensors.
    """
    feature_spec = {
        'i': tf.VarLenFeature(tf.int64),
        'data': tf.VarLenFeature(tf.int64),
    }
    sparse = tf.parse_single_example(example_proto, feature_spec)
    dense_i = tf.sparse_tensor_to_dense(sparse['i'])
    dense_data = tf.sparse_tensor_to_dense(sparse['data'])
    return dense_i, dense_data
# Parse each serialized record into (i, data) dense int64 tensors.
dataset = dataset.map(_parse_function)
# Shuffle the records. A buffer of 1 would be a no-op: it only ever holds
# the next element, so records would come out in file order. A buffer at
# least as large as the dataset (5 records are written above) gives a
# uniform shuffle.
SHUFFLE_BUFFER_SIZE = 100
dataset = dataset.shuffle(buffer_size=SHUFFLE_BUFFER_SIZE)
# Repeat the input indefinitely.
dataset = dataset.repeat()
# Generate batches. Rows of `data` have different lengths, so a plain
# batch(n) fails for n > 1. padded_batch zero-pads every row to the longest
# one in its batch; with BATCH_SIZE = 1 no padding is added.
BATCH_SIZE = 1
dataset = dataset.padded_batch(BATCH_SIZE, padded_shapes=([None], [None]))
# Create a one-shot iterator (TensorFlow 1.x graph-mode API).
iterator = dataset.make_one_shot_iterator()
i, data = iterator.get_next()
with tf.Session() as sess:
    # Each run pulls the next batch from the (infinite) iterator.
    for _ in range(3):
        print(sess.run([i, data]))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment