Created
March 16, 2017 16:55
-
-
Save ikhlestov/e11f042aaf9118d6309f5026e7967fcd to your computer and use it in GitHub Desktop.
TensorFlow multithreading data provider
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf | |
class MultithreadedTensorProvider:
    """Provide input tensors from a queue that is filled by background threads.

    A user-supplied callable produces the data; it is wrapped in
    ``tf.py_func`` and enqueued by ``number_of_threads`` queue-runner threads,
    while the training graph consumes the tensors returned by ``get_input``.
    """

    def __init__(self, capacity, sess, dtypes, shuffle_queue=False,
                 number_of_threads=1, min_after_dequeue=0):
        """Create the underlying queue.

        Args:
            capacity: maximum queue size measured in examples.
            sess: a tensorflow session used to run the queue ops.
            dtypes: list of data types, one per queue component.
            shuffle_queue: if True use RandomShuffleQueue, otherwise
                FIFOQueue.
            number_of_threads: how many enqueue threads to start in
                `set_data_provider`; must be at least 1.
            min_after_dequeue: only used when `shuffle_queue` is True; the
                minimum number of elements kept in the queue after a
                dequeue (passed straight to RandomShuffleQueue).

        Raises:
            ValueError: if `number_of_threads` is smaller than 1.
        """
        # With zero enqueue ops nothing would ever fill the queue and every
        # dequeue would block forever, so reject that configuration early.
        if number_of_threads < 1:
            raise ValueError('number_of_threads should be at least 1.')
        self.dtypes = dtypes
        self.sess = sess
        self.number_of_threads = number_of_threads
        # Populated by `set_data_provider`; defined here so that the
        # attributes always exist.
        self.coord = None
        self.threads = None
        if shuffle_queue:
            # RandomShuffleQueue has no default for `min_after_dequeue`;
            # omitting it raises a TypeError at construction time.
            self.queue = tf.RandomShuffleQueue(
                capacity=capacity,
                min_after_dequeue=min_after_dequeue,
                dtypes=dtypes)
        else:
            self.queue = tf.FIFOQueue(
                capacity=capacity,
                dtypes=dtypes)
        # Build the size op once so `get_queue_size` does not add a new op
        # to the graph on every call.
        self.q_size = self.queue.size()

    def get_input(self):
        """Return the dequeued input tensor(s).

        Each call creates a new dequeue op on the queue and stores the
        result in `self.batch`.
        """
        self.batch = self.queue.dequeue()
        return self.batch

    def get_queue_size(self):
        """Return how many elements are currently left in the queue."""
        return self.sess.run(self.q_size)

    def set_data_provider(self, data_provider):
        """Set the data provider and start the enqueue threads.

        Args:
            data_provider: a callable taking no arguments that produces a
                tuple of inputs to be placed into the queue.

        Raises:
            TypeError: if data provider is not a callable.
        """
        if not callable(data_provider):
            raise TypeError('Data provider should be a callable.')
        data = tf.py_func(data_provider, [], self.dtypes)
        enqueue_op = self.queue.enqueue(data)
        # One copy of the enqueue op per thread: the QueueRunner starts a
        # thread for every op in the list.
        qr = tf.train.QueueRunner(
            self.queue, [enqueue_op] * self.number_of_threads)
        tf.train.add_queue_runner(qr)
        self.coord = tf.train.Coordinator()
        self.threads = tf.train.start_queue_runners(
            sess=self.sess, coord=self.coord)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import warnings | |
import tensorflow as tf | |
import numpy as np | |
from multithreaded_data_provider import MultithreadedTensorProvider | |
def test_data_provider():
    """Return one random (images, labels) pair as float32 arrays.

    The images array has shape (1, 784) and the labels array has
    shape (1, 10).
    """
    images = np.random.random([1, 28 * 28]).astype(np.float32)
    labels = np.random.random([1, 10]).astype(np.float32)
    return (images, labels)
def fake_data_provider():
    """Return a fake batch as a dict.

    'data' is a random float32 array of shape (5, 784), 'labels' is a random
    float32 array of shape (5, 10); 'key1' and 'key2' are None placeholders.
    """
    batch = {}
    batch['data'] = np.random.random([5, 28 * 28]).astype(np.float32)
    batch['labels'] = np.random.random([5, 10]).astype(np.float32)
    batch['key1'] = None
    batch['key2'] = None
    return batch
def dp():
    """Produce one (data, labels) tuple for the queue.

    Pulls a batch from `fake_data_provider`, performs some dummy work to
    stand in for preprocessing, and returns only the two array entries.
    """
    batch = fake_data_provider()
    # some preprocessing (simulated: the generated arrays are discarded)
    for _ in range(4):
        np.random.random([100, 100])
    data, labels = batch['data'], batch['labels']
    return (data, labels)
# Demo training script: build a single-layer softmax model whose inputs come
# from a MultithreadedTensorProvider, then run 100 training steps while
# reporting the queue size.
sess = tf.Session()
tensor_provider = MultithreadedTensorProvider(
    capacity=10, sess=sess, dtypes=[tf.float32, tf.float32],
    shuffle_queue=False, number_of_threads=3)
images_batch, labels_batch = tensor_provider.get_input()

w = tf.get_variable("w1", [28*28, 10])
y_pred = tf.matmul(images_batch, w)
loss = tf.nn.softmax_cross_entropy_with_logits(
    logits=y_pred, labels=labels_batch)
loss_mean = tf.reduce_mean(loss)
# NOTE(review): the optimizer minimizes the per-example `loss` vector, not
# `loss_mean`; kept as in the original - confirm this is intended.
train_op = tf.train.AdamOptimizer(0.005).minimize(loss)

# Created after the optimizer so that its slot variables are initialized too.
init = tf.global_variables_initializer()
sess.run(init)

print('!!!!!!!!!!!!!!!!!!!!!!!!!!start train')
# Starts the background enqueue threads.
tensor_provider.set_data_provider(dp)
try:
    for i in range(100):
        print()
        print('q_size', tensor_provider.get_queue_size())
        # Running `labels_batch` on its own dequeues one element just to
        # show its shape; the queue size printed next reflects that.
        print(sess.run(labels_batch).shape)
        print('q_size',tensor_provider.get_queue_size())
        _, loss_val = sess.run([train_op, loss_mean])
        if tensor_provider.get_queue_size() == 0:
            warnings.warn(
                "Queue is empty! "
                "Try to increase queue capacity and/or number_of_threads")
    print('!!!!!!!!!!!!!!!!!!!!!!!Train finish')
finally:
    # Stop the enqueue threads and release the session even when training
    # fails; otherwise the threads keep running against the session.
    tensor_provider.coord.request_stop()
    tensor_provider.coord.join(tensor_provider.threads)
    sess.close()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment