Created
May 3, 2017 14:05
-
-
Save mmourafiq/88e1cb5c391f6a6a7df6371fe89f3880 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding: utf-8 -*- | |
from __future__ import absolute_import | |
from __future__ import division | |
from __future__ import print_function | |
import tensorflow as tf | |
class FeederConfig(object):
    """Settings used to create data feeders for training and evaluating.

    Arguments:
        num_threads: `int`. Total number of simultaneous threads to process data.
        max_queue: `int`. Maximum number of data stored in a queue.
        capacity: `int`. Capacity of the queues built from this config.
        shuffle: `bool`. If True, data will be shuffled.
        ensure_data_order: `bool`. Ensure that data order is kept when using
            'next' to retrieve data (processing will be slower). When set,
            `num_threads` and `max_queue` are both forced to 1.
    """

    def __init__(self, num_threads=4, max_queue=32, capacity=2000, shuffle=True,
                 ensure_data_order=False):
        # Keeping the data order overrides whatever thread count / queue size
        # the caller asked for: both collapse to 1.
        self.num_threads = 1 if ensure_data_order else num_threads
        self.max_queue = 1 if ensure_data_order else max_queue
        self.capacity = capacity
        self.shuffle = shuffle
class Feeder(object):
    """Manage the TensorFlow queues used to feed batches of data.

    Three queues are created (train, validation and test) together with a
    switchable queue, `self.queue`, which resolves to one of the three at run
    time through the `queue_index` placeholder.

    Arguments:
        inputs: tensor of input samples. Its first dimension is dropped when
            the per-element queue shapes are computed.
        outputs: tensor of targets, or `None` when only inputs are fed.
        config: object exposing `shuffle` and `capacity` attributes.
    """

    class QueueIndex(object):
        """Integer ids selecting one of the three queues."""
        TRAIN = 0
        VAL = 1
        TEST = 2

    def __init__(self, inputs, outputs, config):
        self.inputs = inputs
        self.outputs = outputs

        # Scalars fed at run time: which queue to read and how many elements.
        self.queue_index = tf.placeholder(dtype=tf.int32, shape=[])
        self.batch_size = tf.placeholder(dtype=tf.int32, shape=[])
        self.config = config

        # Progress counters; `num_samples` is never set by this class and
        # stays 0 until something else assigns it.
        self.num_samples = 0
        self.step = 0
        self.epoch = 0
        self.current_iter = 0

        self.queue_train = None
        self.queue_val = None
        self.queue_test = None
        self.queue = None
        self.set_queue()

        # Every queue is built with the same components (see `set_queue`), so
        # every enqueue op must feed the same list of tensors.
        self.enqueue_train_op = self.queue_train.enqueue_many(self.placeholders)
        self.enqueue_val_op = self.queue_val.enqueue_many(self.placeholders)
        self.enqueue_test_op = self.queue_test.enqueue_many(self.placeholders)

        self.dequeue_op = self.queue.dequeue_many(self.batch_size)
        tf.add_to_collection(name='queues', value=self.dequeue_op)

    @property
    def placeholders(self):
        """List of tensors stored in the queues: `[inputs, outputs]` or `[inputs]`."""
        if self.outputs is None:
            return [self.inputs]
        return [self.inputs, self.outputs]

    @staticmethod
    def get_shape(x):
        """Return the static shape of tensor `x` as a Python list."""
        return x.get_shape().as_list()

    def set_queue(self):
        """Create the train/val/test queues and the index-switched queue."""
        components = self.placeholders
        dtypes = [x.dtype for x in components]
        # Drop the leading dimension: queue elements are single samples.
        shapes = [self.get_shape(x)[1:] for x in components]
        capacity = self.config.capacity

        if self.config.shuffle:
            # A shuffle queue keeps `min_after_dequeue` elements back, so that
            # value has to stay below the capacity or nothing can ever be
            # dequeued. 1000 is kept whenever the capacity allows it.
            min_after_dequeue = min(1000, max(capacity - 1, 0))
            self.queue_train = tf.RandomShuffleQueue(
                dtypes=dtypes, shapes=shapes, capacity=capacity,
                min_after_dequeue=min_after_dequeue)
        else:
            self.queue_train = tf.FIFOQueue(dtypes=dtypes, shapes=shapes, capacity=capacity)

        self.queue_val = tf.FIFOQueue(dtypes=dtypes, shapes=shapes, capacity=capacity)
        self.queue_test = tf.FIFOQueue(dtypes=dtypes, shapes=shapes, capacity=capacity)
        self.queue = tf.QueueBase.from_list(
            index=self.queue_index,
            queues=[self.queue_train, self.queue_val, self.queue_test])

    def _update_counters(self, index, batch_size):
        """Advance step/epoch counters after a training batch.

        Batches read from the validation or test queue leave the counters
        untouched.
        """
        if index != self.QueueIndex.TRAIN:
            return

        self.step += 1
        if self.num_samples <= 0:
            # Dataset size unknown: count samples seen, but do not report a
            # finished epoch on every single batch.
            self.current_iter = self.step * batch_size
            return

        self.current_iter = min(self.step * batch_size, self.num_samples)
        if self.current_iter == self.num_samples:
            self.epoch += 1
            self.step = 0

    def get_inputs(self, session, queue_index, batch_size):
        """Dequeue one batch from the selected queue.

        Arguments:
            session: session used to run the dequeue op.
            queue_index: one of the `QueueIndex` values.
            batch_size: number of elements to dequeue.

        Returns:
            `(X_batch, y_batch)` when the feeder was built with outputs,
            otherwise the batch of X alone.
        """
        batch = session.run(self.dequeue_op,
                            {self.queue_index: queue_index, self.batch_size: batch_size})
        # Only count the batch once it has actually been dequeued.
        self._update_counters(queue_index, batch_size)

        if self.outputs is None:
            # Single-component queue: the result is the X batch itself and
            # must not be unpacked along its first axis.
            return batch
        X_batch, y_batch = batch
        return X_batch, y_batch

    def get_enqueue_op(self, queue_index):
        """Return the pre-built enqueue op for `queue_index`, or `None` if unknown."""
        return {
            self.QueueIndex.TRAIN: self.enqueue_train_op,
            self.QueueIndex.VAL: self.enqueue_val_op,
            self.QueueIndex.TEST: self.enqueue_test_op,
        }.get(queue_index)

    def get_queue(self, queue_index):
        """Return the queue for `queue_index`, or `None` if unknown."""
        return {
            self.QueueIndex.TRAIN: self.queue_train,
            self.QueueIndex.VAL: self.queue_val,
            self.QueueIndex.TEST: self.queue_test,
        }.get(queue_index)

    def close_queue(self, session):
        """Close the three queues, cancelling any pending enqueue requests."""
        close_ops = [queue.close(cancel_pending_enqueues=True)
                     for queue in (self.queue_train, self.queue_val, self.queue_test)]
        session.run(close_ops)

    def enqueue(self, queue_index, X, y=None):
        """Register a queue runner that enqueues `X` (and `y`) into one queue.

        Arguments:
            queue_index: one of the `QueueIndex` values.
            X: tensor of inputs to enqueue.
            y: optional tensor of targets to enqueue alongside `X`.

        Raises:
            ValueError: if `queue_index` does not name a known queue.
        """
        queue = self.get_queue(queue_index)
        if queue is None:
            raise ValueError('Unknown queue index: {}'.format(queue_index))

        enqueue_op = queue.enqueue_many([X, y] if y is not None else [X])
        # The runner is attached to the concrete queue it fills, not to the
        # index-switched `self.queue`, which needs `queue_index` fed to resolve.
        queue_runner = tf.train.QueueRunner(queue=queue, enqueue_ops=[enqueue_op])
        tf.train.add_queue_runner(queue_runner)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment