Skip to content

Instantly share code, notes, and snippets.

@haoliplus
Created March 13, 2018 13:12
Show Gist options
  • Save haoliplus/8f4c7044a6ce161db1b69b23f8326dc7 to your computer and use it in GitHub Desktop.
Save haoliplus/8f4c7044a6ce161db1b69b23f8326dc7 to your computer and use it in GitHub Desktop.
# coding: utf-8
# In[13]:
import tensorflow as tf
import numpy as np
def make_record(item_list, rating_list, label):
    """Build one ``tf.train.SequenceExample`` from an item/rating sequence.

    Context (per-example) int64 features:
      * ``item_len`` -- ``len(item_list)``
      * ``labels``   -- the single value ``label``
    Sequence (per-step) int64 feature lists:
      * ``items``    -- one feature per element of ``item_list``
      * ``ratings``  -- one feature per element of ``rating_list``

    Args:
        item_list: sized iterable of ints (``len`` is taken on it).
        rating_list: iterable of ints; its length is not checked against
            ``item_list``.
        label: int stored as the ``labels`` context feature.

    Returns:
        The populated ``tf.train.SequenceExample``.
    """
    example = tf.train.SequenceExample()

    # Non-sequential (context) features.
    n_items = len(item_list)
    context = example.context.feature
    context["item_len"].int64_list.value.append(n_items)
    context["labels"].int64_list.value.append(label)

    # Sequential features: one Feature per time step, per list.
    steps = example.feature_lists.feature_list
    item_steps = steps["items"]
    rating_steps = steps["ratings"]
    for target, values in ((item_steps, item_list), (rating_steps, rating_list)):
        for value in values:
            target.feature.add().int64_list.value.append(value)

    return example
def _save(file_name, items, ratings):
    """Serialize parallel item/rating sequences into a TFRecord file.

    Each pair ``(item_list, rating_list)`` taken from ``zip(items, ratings)``
    becomes one ``SequenceExample`` built by ``make_record``. Its ``labels``
    context feature is the last element of ``item_list``.

    Args:
        file_name: output path; the writer creates or overwrites it.
        items: iterable of non-empty int sequences.
        ratings: iterable of int sequences paired with ``items``. As with
            ``zip``, pairing stops at the shorter of the two inputs.

    Raises:
        IndexError: if an element of ``items`` is empty, because there is no
            last element to use as the label.
    """
    # TFRecordWriter opens (and truncates) file_name itself, so no separate
    # builtin open() is needed. The context manager flushes and closes the
    # writer even if building or serializing a record raises midway.
    with tf.python_io.TFRecordWriter(file_name) as writer:
        for item_list, rating_list in zip(items, ratings):
            # NOTE(review): the label is also still present inside item_list;
            # confirm that this overlap is intended for the demo data.
            label = item_list[-1]
            ex = make_record(item_list, rating_list, label)
            writer.write(ex.SerializeToString())
# Write four toy sequences to 'test.tfrecord'. Sequences of length 4 carry
# rating 5 and sequences of length 3 carry rating 6.
_save(
    'test.tfrecord',
    items=[
        [1, 2, 3, 4],
        [1, 2, 3],
        [1, 2, 3, 4],
        [1, 2, 3],
    ],
    ratings=[
        [5, 5, 5, 5],
        [6, 6, 6],
        [5, 5, 5, 5],
        [6, 6, 6],
    ],
)
# In[14]:
def batch_with_dynamic_pad(sequence_parsed, context_parsed, tensor_name_to_mask, batch_size, queue_capacity):
    """Batch parsed example tensors with dynamic padding and add ones-masks.

    All tensors in ``sequence_parsed`` and ``context_parsed`` are enqueued
    through ``tf.train.batch_join(..., dynamic_pad=True)``, so variable-length
    sequences in one batch are padded to a common length. For each key of
    ``tensor_name_to_mask`` an extra int32 vector of ones is enqueued as well.
    Its length is ``tf.shape(t)[0]`` for the sequence tensor ``t`` that the key
    names. After padding, this mask is intended to mark the real (unpadded)
    steps. The queue pads with zeros according to the tf.train.batch docs;
    confirm for your TF version.

    Args:
        sequence_parsed: dict of name -> sequence tensor.
        context_parsed: dict of name -> context tensor.
        tensor_name_to_mask: dict mapping a key of ``sequence_parsed`` to the
            name under which its mask is returned.
        batch_size: number of examples per batch.
        queue_capacity: capacity of the batching queue.

    Returns:
        dict mapping every sequence name, context name and mask name to its
        batched tensor. NOTE(review): a mask name that equals a feature name
        overwrites that feature's entry, because masks are zipped in last.
    """
    def _ones_along_axis0(tensor):
        # Shape vector [len(tensor along axis 0)] -> int32 ones of that length.
        return tf.ones(tf.expand_dims(tf.shape(tensor)[0], 0), dtype=tf.int32)

    sequence_names = list(sequence_parsed)
    context_names = list(context_parsed)
    masked_names = list(tensor_name_to_mask)

    # Tensors and output names are kept in the same order:
    # sequence features, then context features, then masks.
    to_batch = [sequence_parsed[name] for name in sequence_names]
    to_batch += [context_parsed[name] for name in context_names]
    to_batch += [_ones_along_axis0(sequence_parsed[name]) for name in masked_names]
    output_names = (
        sequence_names
        + context_names
        + [tensor_name_to_mask[name] for name in masked_names]
    )

    batched = tf.train.batch_join(
        [to_batch],
        batch_size=batch_size,
        capacity=queue_capacity,
        dynamic_pad=True,
        name="batch_and_pad",
    )
    return dict(zip(output_names, batched))
def read_tfrecord(tfrecord_path, batch_size, queue_capacity=200):
    """Build a graph pipeline that reads, parses and batches SequenceExamples.

    The file is cycled indefinitely (``num_epochs=None``). Each record is parsed
    for the int64 context feature ``labels`` and the int64 sequence features
    ``items`` and ``ratings``. The parsed tensors are batched with dynamic
    padding by ``batch_with_dynamic_pad``, which also adds an ``item_mask``
    for ``items``.

    Args:
        tfrecord_path: path of the TFRecord file to read.
        batch_size: number of examples per batch.
        queue_capacity: capacity of the batching queue (default 200).

    Returns:
        dict with keys 'items', 'ratings', 'labels' and 'item_mask', each mapped
        to a batched tensor.
    """
    # Lazy %-style args: the message is formatted only if INFO is enabled.
    tf.logging.info('Read tfrecord %s batch_size %s', tfrecord_path, batch_size)
    filename_queue = tf.train.string_input_producer([tfrecord_path], num_epochs=None)
    reader = tf.TFRecordReader()
    # reader.read returns a (key, value) pair; only the serialized value is used.
    _, serialized = reader.read(filename_queue)

    context_feature_names = ['labels']
    sequence_feature_names = ['items', 'ratings']
    tensor_name_to_mask = {'items': 'item_mask'}
    context_features = {
        k: tf.FixedLenFeature([], dtype=tf.int64) for k in context_feature_names
    }
    sequence_features = {
        k: tf.FixedLenSequenceFeature([], dtype=tf.int64) for k in sequence_feature_names
    }
    context_parsed, sequence_parsed = tf.parse_single_sequence_example(
        serialized=serialized,
        context_features=context_features,
        sequence_features=sequence_features,
    )
    return batch_with_dynamic_pad(
        sequence_parsed,
        context_parsed,
        tensor_name_to_mask,
        batch_size=batch_size,
        queue_capacity=queue_capacity,
    )
# Build the batched input graph over 'test.tfrecord', two examples per batch,
# and pull out the tensors fetched in the session below.
tensors = read_tfrecord('test.tfrecord', 2)
items, labels, ratings = (tensors[key] for key in ('items', 'labels', 'ratings'))
# In[22]:
# Initialize global and local variables, start the input queue runners, fetch
# a few padded batches, then shut the runner threads down cleanly.
init_op = tf.group(tf.global_variables_initializer(),
                   tf.local_variables_initializer())
with tf.Session() as sess:
    sess.run(init_op)
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    try:
        # Let's read off 3 batches
        for _ in range(3):
            batch_items, batch_labels, batch_ratings = sess.run([items, labels, ratings])
            print(batch_items)
            print(batch_labels)
            print(batch_ratings)
            print("---------")
    finally:
        # Stop and join the queue-runner threads before the session closes.
        # Otherwise they keep running against a closed session.
        coord.request_stop()
        coord.join(threads)
# In[ ]:
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment