Created
March 13, 2018 13:12
-
-
Save haoliplus/8f4c7044a6ce161db1b69b23f8326dc7 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# coding: utf-8 | |
# In[13]: | |
import tensorflow as tf | |
import numpy as np | |
def make_record(item_list, rating_list, label):
    """Pack one item/rating sequence and its label into a SequenceExample.

    Context (per-example) features written:
      * "item_len": ``len(item_list)``
      * "labels": ``label``

    Feature lists written (one int64 Feature per element):
      * "items": the values of ``item_list``
      * "ratings": the values of ``rating_list``

    Returns:
        The populated ``tf.train.SequenceExample``.
    """
    example = tf.train.SequenceExample()

    # Non-sequential (context) features of the example.
    context = example.context.feature
    context["item_len"].int64_list.value.append(len(item_list))
    context["labels"].int64_list.value.append(label)

    # Sequential features: each element becomes its own Feature entry.
    feature_lists = example.feature_lists.feature_list
    for name, values in (("items", item_list), ("ratings", rating_list)):
        steps = feature_lists[name]
        for value in values:
            steps.feature.add().int64_list.value.append(value)

    return example
def _save(file_name, items, ratings):
    """Write one SequenceExample per (item_list, rating_list) pair to a TFRecord file.

    Args:
        file_name: Path of the TFRecord file to write.
        items: Iterable of item-id sequences.
        ratings: Iterable of rating sequences, paired with ``items`` via
            ``zip`` (extra entries in the longer iterable are ignored).

    The label stored with each record is the last element of its item
    sequence, so every item sequence must be non-empty.
    """
    # TFRecordWriter opens the target path itself; wrapping it in a
    # text-mode ``open(file_name, 'w')`` only created a second, unused
    # handle on the same file.
    writer = tf.python_io.TFRecordWriter(file_name)
    try:
        for item_list, rating_list in zip(items, ratings):
            label = item_list[-1]
            ex = make_record(item_list, rating_list, label)
            writer.write(ex.SerializeToString())
    finally:
        # Close (and flush) the writer even if building/writing a record fails.
        writer.close()
# Write four small example sequences (two of length 4, two of length 3)
# to the TFRecord file that the reading code below consumes.
_save(
    'test.tfrecord',
    items=[[1, 2, 3, 4], [1, 2, 3], [1, 2, 3, 4], [1, 2, 3]],
    ratings=[[5, 5, 5, 5], [6, 6, 6], [5, 5, 5, 5], [6, 6, 6]],
)
# In[14]: | |
def batch_with_dynamic_pad(sequence_parsed, context_parsed, tensor_name_to_mask, batch_size, queue_capacity):
    """Batch parsed SequenceExample tensors with ``dynamic_pad=True``.

    Args:
        sequence_parsed: Dict of feature name -> parsed sequence tensor.
        context_parsed: Dict of feature name -> parsed context tensor.
        tensor_name_to_mask: Dict mapping a key of ``sequence_parsed`` to the
            name under which a mask for that tensor is returned.
        batch_size: Passed to ``tf.train.batch_join`` as ``batch_size``.
        queue_capacity: Passed to ``tf.train.batch_join`` as ``capacity``.

    Returns:
        Dict of name -> batched tensor, containing every sequence feature,
        every context feature, and one mask per entry of
        ``tensor_name_to_mask``.
    """
    def _ones_along_first_dim(tensor):
        # 1-D int32 vector of ones with the same length as the tensor's
        # first dimension; it is batched (and padded) with the features.
        first_dim = tf.shape(tensor)[0]
        return tf.ones(tf.expand_dims(first_dim, 0), dtype=tf.int32)

    # Names and tensors are kept in lockstep: sequence features first,
    # then context features, then the masks.
    names = list(sequence_parsed) + list(context_parsed)
    tensors = [sequence_parsed[key] for key in sequence_parsed]
    tensors += [context_parsed[key] for key in context_parsed]
    for source_name, mask_name in tensor_name_to_mask.items():
        names.append(mask_name)
        tensors.append(_ones_along_first_dim(sequence_parsed[source_name]))

    batched = tf.train.batch_join(
        [tensors],
        batch_size=batch_size,
        capacity=queue_capacity,
        dynamic_pad=True,
        name="batch_and_pad")
    return dict(zip(names, batched))
def read_tfrecord(tfrecord_path, batch_size):
    """Build a queue-based pipeline yielding padded batches from a TFRecord file.

    Parses the int64 context feature 'labels' and the int64 sequence
    features 'items' and 'ratings' from each serialized SequenceExample,
    then batches them through ``batch_with_dynamic_pad`` (queue capacity
    200), requesting a mask named 'item_mask' for 'items'.

    Args:
        tfrecord_path: Path of the TFRecord file to read.
        batch_size: Number of examples per batch.

    Returns:
        The dict of name -> batched tensor produced by
        ``batch_with_dynamic_pad``.
    """
    tf.logging.info('Read tfrecord {} batch_size {}'.format(tfrecord_path, batch_size))

    # num_epochs=None: the filename queue is not limited to a fixed
    # number of passes over the file.
    filename_queue = tf.train.string_input_producer([tfrecord_path], num_epochs=None)
    record = tf.TFRecordReader().read(filename_queue)

    context_spec = {
        name: tf.FixedLenFeature([], dtype=tf.int64)
        for name in ('labels',)
    }
    sequence_spec = {
        name: tf.FixedLenSequenceFeature([], dtype=tf.int64)
        for name in ('items', 'ratings')
    }
    context_parsed, sequence_parsed = tf.parse_single_sequence_example(
        serialized=record.value,
        context_features=context_spec,
        sequence_features=sequence_spec,
    )

    mask_names = {'items': 'item_mask'}
    return batch_with_dynamic_pad(
        sequence_parsed,
        context_parsed,
        mask_names,
        batch_size=batch_size,
        queue_capacity=200,
    )
# Build the input pipeline (batches of 2) and pick out the tensors to fetch.
tensors = read_tfrecord('test.tfrecord', 2)
items, labels, ratings = tensors['items'], tensors['labels'], tensors['ratings']
# In[22]: | |
# Initialise both global and local variables before running the pipeline.
init_op = tf.group(tf.global_variables_initializer(),
                   tf.local_variables_initializer())
with tf.Session() as sess:
    sess.run(init_op)
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    try:
        # Let's read off 3 batches
        for _ in range(3):
            item_batch, label_batch, rating_batch = sess.run([items, labels, ratings])
            print(item_batch)
            print(label_batch)
            print(rating_batch)
            print("---------")
    finally:
        # Stop and join the queue-runner threads before the session closes;
        # otherwise they are left running against a closed session.
        coord.request_stop()
        coord.join(threads)
# In[ ]: | |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment