Skip to content

Instantly share code, notes, and snippets.

@ikhlestov
Created March 16, 2017 16:55
Show Gist options
  • Save ikhlestov/e11f042aaf9118d6309f5026e7967fcd to your computer and use it in GitHub Desktop.
Tensorflow multithreading data provider
import tensorflow as tf
class MultithreadedTensorProvider():
    """Provide input tensors from a TF queue filled by background threads.

    Typical usage:
        provider = MultithreadedTensorProvider(capacity, sess, dtypes)
        inputs = provider.get_input()          # build the graph on these
        provider.set_data_provider(callable_)  # start the filler threads
    """

    def __init__(self, capacity, sess, dtypes, shuffle_queue=False,
                 number_of_threads=1, min_after_dequeue=0):
        """Initialize a class to provide tensors with input data.

        Args:
            capacity: maximum queue size measured in examples.
            sess: a tensorflow session.
            dtypes: list of data types, one per component of an example.
            shuffle_queue: either to use RandomShuffleQueue or FIFOQueue.
            number_of_threads: number of enqueue threads to launch in
                set_data_provider().
            min_after_dequeue: minimum elements kept in the queue after a
                dequeue (shuffling quality); only used when shuffle_queue
                is True.
        """
        self.dtypes = dtypes
        self.sess = sess
        self.number_of_threads = number_of_threads
        if shuffle_queue:
            # BUG FIX: tf.RandomShuffleQueue requires min_after_dequeue;
            # the original call omitted it and raised a TypeError whenever
            # shuffle_queue=True was requested.
            self.queue = tf.RandomShuffleQueue(
                capacity=capacity,
                min_after_dequeue=min_after_dequeue,
                dtypes=dtypes)
        else:
            self.queue = tf.FIFOQueue(
                capacity=capacity,
                dtypes=dtypes)
        # Create the size op once so get_queue_size() only has to run it.
        self.q_size = self.queue.size()

    def get_input(self):
        """Return the dequeue tensor(s) to build the model graph on."""
        self.batch = self.queue.dequeue()
        return self.batch

    def get_queue_size(self):
        """Return how many batches are left in the queue."""
        return self.sess.run(self.q_size)

    def set_data_provider(self, data_provider):
        """Set data provider to generate input tensors and start threads.

        Args:
            data_provider: a callable producing a tuple of values (matching
                `dtypes`) to be placed into the queue.

        Raises:
            TypeError: if data_provider is not a callable.
        """
        if not callable(data_provider):
            raise TypeError('Data provider should be a callable.')
        # py_func wraps the Python callable as a graph op; every run of
        # enqueue_op invokes it once and pushes the result into the queue.
        data = tf.py_func(data_provider, [], self.dtypes)
        enqueue_op = self.queue.enqueue(data)
        qr = tf.train.QueueRunner(
            self.queue, [enqueue_op] * self.number_of_threads)
        tf.train.add_queue_runner(qr)
        self.coord = tf.train.Coordinator()
        self.threads = tf.train.start_queue_runners(
            sess=self.sess, coord=self.coord)
import warnings
import tensorflow as tf
import numpy as np
from multithreaded_data_provider import MultithreadedTensorProvider
def test_data_provider():
    """Produce one (image, label) pair of random float32 arrays."""
    images = np.random.random([1, 28 * 28]).astype(np.float32)
    labels = np.random.random([1, 10]).astype(np.float32)
    return images, labels
def fake_data_provider():
    """Return a dict mimicking a real data source's batch payload."""
    batch = {
        'data': np.random.random([5, 28 * 28]).astype(np.float32),
        'labels': np.random.random([5, 10]).astype(np.float32),
    }
    # Extra keys present in the real payload but unused by the demo.
    batch['key1'] = None
    batch['key2'] = None
    return batch
def dp():
    """Fetch a fake batch, simulate preprocessing, return (data, labels)."""
    batch = fake_data_provider()
    # Simulate some CPU-bound preprocessing work.
    for _ in range(4):
        np.random.random([100, 100])
    return batch['data'], batch['labels']
# --- Demo: train a linear softmax model on queue-fed data ---------------
sess = tf.Session()
tensor_provider = MultithreadedTensorProvider(
    capacity=10, sess=sess, dtypes=[tf.float32, tf.float32],
    shuffle_queue=False, number_of_threads=3)
images_batch, labels_batch = tensor_provider.get_input()

# Simple linear model: logits = x @ W.
w = tf.get_variable("w1", [28 * 28, 10])
y_pred = tf.matmul(images_batch, w)
loss = tf.nn.softmax_cross_entropy_with_logits(
    logits=y_pred, labels=labels_batch)
loss_mean = tf.reduce_mean(loss)
# BUG FIX: minimize the scalar mean loss; the original passed `loss`,
# the per-example loss vector, even though loss_mean was computed for
# exactly this purpose.
train_op = tf.train.AdamOptimizer(0.005).minimize(loss_mean)

init = tf.global_variables_initializer()
sess.run(init)

print('!!!!!!!!!!!!!!!!!!!!!!!!!!start train')
tensor_provider.set_data_provider(dp)
for i in range(100):
    print()
    # Queue depth before and after pulling a batch, to show thread health.
    print('q_size', tensor_provider.get_queue_size())
    print(sess.run(labels_batch).shape)
    print('q_size', tensor_provider.get_queue_size())
    _, loss_val = sess.run([train_op, loss_mean])
    if tensor_provider.get_queue_size() == 0:
        # BUG FIX: typo "capasity" -> "capacity"; also replaced the
        # backslash string continuation that embedded raw source
        # indentation into the warning text.
        warnings.warn("Queue is empty! Try to increase queue capacity "
                      "and/or number_of_threads")
print('!!!!!!!!!!!!!!!!!!!!!!!Train finish')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment