Created January 23, 2018 07:21
A minimal example of how you can accumulate gradients across batches, allowing you to train using much larger batch sizes than can fit in memory at the cost of speed.
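For a sense of scale: with the defaults in the script below (n_pseudo_batches=128, actual_batch_size=32), each weight update is computed from 128 × 32 = 4096 examples, while only 32 examples need to be held in memory at any one time.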
import numpy as np
import tensorflow as tf
import sys
from tensorflow.examples.tutorials.mnist import input_data

n_pseudo_batches = int(sys.argv[1]) if len(sys.argv) > 1 else 128
actual_batch_size = int(sys.argv[2]) if len(sys.argv) > 2 else 32
iterations = int(sys.argv[3]) if len(sys.argv) > 3 else 10

tf.set_random_seed(147258)
np.random.seed(123456)
def simple_model(input):
    # These initializers ensure that the model will always be instantiated the same, for comparison.
    hidden_initializer = tf.constant_initializer(np.random.uniform(-0.025, 0.025, size=[784, 100]))
    hidden = tf.layers.dense(input, 100, kernel_initializer=hidden_initializer)
    out_initializer = tf.constant_initializer(np.random.uniform(-0.025, 0.025, size=[100, 10]))
    return tf.layers.dense(tf.nn.relu(hidden), 10, kernel_initializer=out_initializer)
inp = tf.placeholder(tf.float32, [None, 784])
targ = tf.placeholder(tf.float32, [None, 10])
# Build our model, optimizer and gradients
out = simple_model(inp)
opt = tf.train.GradientDescentOptimizer(learning_rate=1e-1)
loss = tf.losses.mean_squared_error(labels=targ, predictions=out)
# Standard gradients for a single batch
t_vars = tf.trainable_variables()
grads, graph_vars = zip(*opt.compute_gradients(loss, t_vars))
# IMPORTANT: Make sure you call the zero_ops to reset the accumulated gradients
# tl;dr - Add the below section to your code to accumulate gradients
# -----------------------------------------------------------------------------------------------------
# Define our divisor, used to normalise gradients across pseudo batches
divisor = tf.Variable(0, trainable=False)
div_fl = tf.to_float(divisor)
zero_divisor = divisor.assign(0)
inc_divisor = divisor.assign(divisor + 1)
# Accumulation ops and variables
# Create a copy of all trainable variables with `0` as initial values
accum_grads = [tf.Variable(tf.zeros_like(t_var.initialized_value()), trainable=False) for t_var in t_vars]
# Create an op to zero all accumulator variables (and zero the divisor again)
with tf.control_dependencies([zero_divisor]):
    zero_ops = [tv.assign(tf.zeros_like(tv)) for tv in accum_grads]
# Create ops for accumulating the gradients (each run also adds one to the divisor)
with tf.control_dependencies([inc_divisor]):
    accum_ops = [accum_grad.assign_add(grad) for (accum_grad, grad) in zip(accum_grads, grads)]
# Create the op that updates the weights (dividing the accumulated gradients by the number of steps)
normalised_accum_grads = [accum_grad / div_fl for accum_grad in accum_grads]
# ------------------------------------------------------------------------------------------------------
train_op = opt.apply_gradients(zip(normalised_accum_grads, graph_vars))
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True, seed=764847)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for x in range(iterations):
        iteration_loss = 0
        for y in range(n_pseudo_batches):
            inp_, targ_ = mnist.train.next_batch(actual_batch_size)
            # Accumulate gradients for this small batch; the weights are not updated yet
            _, loss_ = sess.run((accum_ops, loss), {inp: inp_, targ: targ_})
            iteration_loss += loss_
        # To get the actual average loss, divide by the number of pseudo batches
        print(iteration_loss / n_pseudo_batches)
        # Apply the averaged gradients as a single update
        sess.run(train_op)
        # vvvv --- VERY IMPORTANT! --- vvvv
        sess.run(zero_ops)
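For reference, the three positional command-line arguments map to n_pseudo_batches, actual_batch_size and iterations. Assuming the gist is saved as, say, accumulate_gradients.py (the filename here is only illustrative), a run that accumulates 128 small batches of 32 examples for 10 weight updates would look like:

    python accumulate_gradients.py 128 32 10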
That's a very nice implementation for something that I needed. Thanks for sharing!
See my other gist for a thorough testing of this idea. The losses printed when running are almost exactly the same, but please refer to the other gist to convince yourself that the accumulated gradients are accurate.
Be warned: it's a LOT slower. A batch of 16 takes the same amount of time as a batch of 32, so actual_batch_size=16, n_pseudo_batches=2 takes twice as long as actual_batch_size=32, n_pseudo_batches=1 (the normal case). This is purely for when you absolutely need more examples per update than you can fit in memory.
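If you are on TensorFlow 2.x, where the tf.train optimizers and tf.placeholder used above no longer exist, the same accumulate-then-average pattern can be sketched with tf.GradientTape. This is only an illustration of the idea under that assumption, not part of the original gist; names such as accumulate_step and apply_and_reset are made up for the example:

import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.layers.Dense(100, activation="relu", input_shape=(784,)),
    tf.keras.layers.Dense(10),
])
opt = tf.keras.optimizers.SGD(learning_rate=1e-1)
loss_fn = tf.keras.losses.MeanSquaredError()

# One non-trainable accumulator variable per trainable variable, initialised to zero
accum_grads = [tf.Variable(tf.zeros_like(v), trainable=False) for v in model.trainable_variables]

def accumulate_step(x, y):
    # Compute gradients for one small batch and add them to the accumulators
    with tf.GradientTape() as tape:
        loss = loss_fn(y, model(x, training=True))
    grads = tape.gradient(loss, model.trainable_variables)
    for acc, grad in zip(accum_grads, grads):
        acc.assign_add(grad)
    return loss

def apply_and_reset(n_pseudo_batches):
    # Apply the averaged gradients once, then zero the accumulators again
    opt.apply_gradients(
        [(acc / n_pseudo_batches, var)
         for acc, var in zip(accum_grads, model.trainable_variables)])
    for acc in accum_grads:
        acc.assign(tf.zeros_like(acc))

Here the explicit n_pseudo_batches argument plays the same role as the divisor variable in the 1.x graph code above.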