Created
October 28, 2017 13:38
-
-
Save bzamecnik/05ff3bda1f82b16e325a3a1857fd2994 to your computer and use it in GitHub Desktop.
TF StagingArea with variable batch size
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Does StagingArea support batches of variable size? Yes.
#
# The training or validation set might not be exactly divisible by the batch
# size, so the last batch may be smaller. We can either ignore that batch
# (which is incorrect with respect to the loss) or provide batches of
# variable size. On the other hand, we'd like to ensure the data points all
# have the same shape.
#
# It turns out that StagingArea accepts batches of variable size, and with a
# partially known shape such as (None, 2) it still enforces the per-sample shape.
import numpy as np
import tensorflow as tf
from tensorflow.python.ops.data_flow_ops import StagingArea

# NOTE(review): this script uses the TensorFlow 1.x graph-mode API
# (tf.Session); the exact ValueError message texts asserted below may differ
# between TensorFlow versions.

# Basic StagingArea with constant batch size.
with tf.Session() as sess:
    # Batches of 4 samples, each sample with shape (1,).
    area = StagingArea(dtypes=[tf.int32], shapes=[(4, 1)])
    sess.run(area.put([tf.constant([[1], [2], [3], [4]])]))

    # A smaller batch must be rejected with a ValueError.
    try:
        sess.run(area.put([tf.constant([[5], [6]])]))
    except ValueError as ex:
        assert ex.args[0] == "Shapes (2, 1) and (4, 1) are incompatible"
    else:
        # Without this, the check would pass silently if put() accepted it.
        raise AssertionError("put() accepted a batch of 2 samples")

    # A different data sample shape must be rejected as well.
    try:
        sess.run(area.put([tf.constant([[1, 1], [2, 2], [3, 3], [4, 4]])]))
    except ValueError as ex:
        assert ex.args[0] == "Shapes (4, 2) and (4, 1) are incompatible"
    else:
        raise AssertionError("put() accepted samples of shape (2,)")
# StagingArea supports batches of variable size when no `shapes` are given.
with tf.Session() as sess:
    area = StagingArea(dtypes=[tf.int32])
    sess.run(area.put([tf.constant([1, 2, 3, 4])]))
    # Smaller batch.
    sess.run(area.put([tf.constant([5, 6])]))
    # Batches come back out in the order they were put.
    # np.array_equal also compares shapes; `np.all(a == b)` would broadcast
    # and could accept an array of the wrong shape.
    assert np.array_equal(sess.run(area.get()), [1, 2, 3, 4])
    assert np.array_equal(sess.run(area.get()), [5, 6])
    # Without `shapes`, the data sample shape (even the rank) is not enforced :/
    sess.run(area.put([tf.constant([[5, 5], [6, 6]])]))
    assert np.array_equal(sess.run(area.get()), [[5, 5], [6, 6]])
# StagingArea supports batches with a variable number of samples, each with the
# same shape: only the batch dimension is left unspecified (None).
with tf.Session() as sess:
    area = StagingArea(dtypes=[tf.int32], shapes=[(None, 2)])
    # Normal batch.
    sess.run(area.put([tf.constant([[1, 1], [2, 2], [3, 3], [4, 4]])]))
    # Smaller batch (e.g. at the end of the training set).
    sess.run(area.put([tf.constant([[5, 5], [6, 6]])]))
    # np.array_equal also compares shapes, unlike `np.all(a == b)`.
    assert np.array_equal(
        sess.run(area.get()), [[1, 1], [2, 2], [3, 3], [4, 4]]
    )
    assert np.array_equal(sess.run(area.get()), [[5, 5], [6, 6]])
    # Here the data sample shape IS enforced: samples of shape (3,) are rejected.
    try:
        sess.run(area.put([tf.constant([[5, 5, 5], [6, 6, 6]])]))
    except ValueError as ex:
        assert ex.args[0] == "Shapes (2, 3) and (?, 2) are incompatible"
    else:
        # Without this, the check would pass silently if put() accepted it.
        raise AssertionError("put() accepted samples of shape (3,)")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment