Implementation of the "Recurrent Batch Normalization" paper in Lasagne
from lasagne import nonlinearities, init
from lasagne.layers.normalization import BatchNormLayer
from lasagne.layers.recurrent import Gate, Layer, MergeLayer, LSTMLayer
from lasagne.utils import unroll_scan

import numpy as np
import theano
import theano.tensor as T
class BatchNormalizedLSTMLayer(LSTMLayer):
    def __init__(self, incoming, num_units,
                 ingate=Gate(),
                 forgetgate=Gate(),
                 cell=Gate(W_cell=None, nonlinearity=nonlinearities.tanh),
                 outgate=Gate(),
                 nonlinearity=nonlinearities.tanh,
                 cell_init=init.Constant(0.),
                 hid_init=init.Constant(0.),
                 backwards=False,
                 learn_init=False,
                 peepholes=True,
                 gradient_steps=-1,
                 grad_clipping=0,
                 unroll_scan=False,
                 precompute_input=True,
                 mask_input=None,
                 only_return_final=False,
                 batch_axes=(0,),
                 **kwargs):
        # Initialize the parent LSTMLayer
        super(BatchNormalizedLSTMLayer, self).__init__(
            incoming, num_units, ingate, forgetgate, cell, outgate,
            nonlinearity, cell_init, hid_init, backwards, learn_init,
            peepholes, gradient_steps, grad_clipping, unroll_scan,
            precompute_input, mask_input, only_return_final, **kwargs)

        input_shape = self.input_shapes[0]
        # Create a BN layer for the input-to-hidden pre-activations with input
        # shape (n_steps, batch_size, 4*num_units) and the given normalization axes
        self.bn_input = BatchNormLayer((input_shape[1], input_shape[0], 4*self.num_units),
                                       beta=None, gamma=init.Constant(0.1), axes=batch_axes)
        self.params.update(self.bn_input.params)  # register the BN parameters on this layer

        # Create batch normalization parameters for the hidden-to-hidden
        # pre-activations; the shape is (time_steps, 4*num_units)
        self.epsilon = np.float32(1e-4)
        self.alpha = theano.shared(np.float32(0.1))
        shape = (input_shape[1], 4*num_units)
        self.gamma = self.add_param(init.Constant(0.1), shape, 'gamma',
                                    trainable=True, regularizable=True)
        self.running_mean = self.add_param(init.Constant(0), (input_shape[1], 4*num_units),
                                           'running_mean', trainable=False, regularizable=False)
        self.running_inv_std = self.add_param(init.Constant(1), (input_shape[1], 4*num_units),
                                              'running_inv_std', trainable=False, regularizable=False)
        # Clones of the running statistics; their default_update is overwritten in the
        # step function so the running averages are updated as a side effect of training
        self.running_mean_clone = theano.clone(self.running_mean, share_inputs=False)
        self.running_inv_std_clone = theano.clone(self.running_inv_std, share_inputs=False)
        self.running_mean_clone.default_update = self.running_mean_clone
        self.running_inv_std_clone.default_update = self.running_inv_std_clone
    def get_output_for(self, inputs, deterministic=False, **kwargs):
        # Retrieve the layer input
        input = inputs[0]
        # Retrieve the mask and the initial states when they are supplied
        mask = None
        hid_init = None
        cell_init = None
        if self.mask_incoming_index > 0:
            mask = inputs[self.mask_incoming_index]
        if self.hid_init_incoming_index > 0:
            hid_init = inputs[self.hid_init_incoming_index]
        if self.cell_init_incoming_index > 0:
            cell_init = inputs[self.cell_init_incoming_index]

        # Treat all dimensions after the second as flattened feature dimensions
        if input.ndim > 3:
            input = T.flatten(input, 3)

        # Because scan iterates over the first dimension, we dimshuffle to
        # (n_time_steps, n_batch, n_features)
        input = input.dimshuffle(1, 0, 2)
        seq_len, num_batch, _ = input.shape

        # Stack input weight matrices into a (num_inputs, 4*num_units)
        # matrix, which speeds up computation
        W_in_stacked = T.concatenate(
            [self.W_in_to_ingate, self.W_in_to_forgetgate,
             self.W_in_to_cell, self.W_in_to_outgate], axis=1)

        # Same for hidden weight matrices
        W_hid_stacked = T.concatenate(
            [self.W_hid_to_ingate, self.W_hid_to_forgetgate,
             self.W_hid_to_cell, self.W_hid_to_outgate], axis=1)

        # Stack biases into a (4*num_units) vector
        b_stacked = T.concatenate(
            [self.b_ingate, self.b_forgetgate,
             self.b_cell, self.b_outgate], axis=0)

        # Batch-normalize the input-to-hidden pre-activations for all time steps at
        # once; `deterministic` is passed through so the running averages are used
        # at test time
        input = self.bn_input.get_output_for(T.dot(input, W_in_stacked),
                                             deterministic=deterministic) + b_stacked
        # At each call to scan, input_n will be (num_batch, 4*num_units).
        # We define a slicing function that extracts the input to each LSTM gate
        def slice_w(x, n):
            return x[:, n*self.num_units:(n+1)*self.num_units]

        # Create the single recurrent computation step function;
        # input_n is the n'th time slice of the input
        def step(input_n, gamma, time_step, cell_previous, hid_previous, *args):
            hidden = T.dot(hid_previous, W_hid_stacked)

            # Batch normalization of the hidden-to-hidden pre-activations
            if deterministic:
                mean = self.running_mean[time_step]
                inv_std = self.running_inv_std[time_step]
            else:
                mean = hidden.mean(0)
                inv_std = T.inv(T.sqrt(hidden.var(0) + self.epsilon))
                # Update the running statistics of this time step as a side effect
                self.running_mean_clone.default_update = T.set_subtensor(
                    self.running_mean_clone.default_update[time_step],
                    (1 - self.alpha) * self.running_mean_clone.default_update[time_step] +
                    self.alpha * mean)
                self.running_inv_std_clone.default_update = T.set_subtensor(
                    self.running_inv_std_clone.default_update[time_step],
                    (1 - self.alpha) * self.running_inv_std_clone.default_update[time_step] +
                    self.alpha * inv_std)
                # Make sure the clones end up in the graph so their default_update
                # is collected, without affecting the computed values
                mean += 0 * self.running_mean_clone[time_step]
                inv_std += 0 * self.running_inv_std_clone[time_step]

            gamma = gamma.dimshuffle('x', 0)
            mean = mean.dimshuffle('x', 0)
            inv_std = inv_std.dimshuffle('x', 0)

            # Normalize
            normalized = (hidden - mean) * (gamma * inv_std)

            # Calculate the gate pre-activations and slice
            gates = input_n + normalized

            # Clip gradients
            if self.grad_clipping:
                gates = theano.gradient.grad_clip(
                    gates, -self.grad_clipping, self.grad_clipping)

            # Extract the pre-activation gate values
            ingate = slice_w(gates, 0)
            forgetgate = slice_w(gates, 1)
            cell_input = slice_w(gates, 2)
            outgate = slice_w(gates, 3)

            if self.peepholes:
                # Compute peephole connections
                ingate += cell_previous*self.W_cell_to_ingate
                forgetgate += cell_previous*self.W_cell_to_forgetgate

            # Apply nonlinearities
            ingate = self.nonlinearity_ingate(ingate)
            forgetgate = self.nonlinearity_forgetgate(forgetgate)
            cell_input = self.nonlinearity_cell(cell_input)

            # Compute the new cell value
            cell = forgetgate*cell_previous + ingate*cell_input

            if self.peepholes:
                outgate += cell*self.W_cell_to_outgate
            outgate = self.nonlinearity_outgate(outgate)

            # Compute the new hidden unit activation
            hid = outgate*self.nonlinearity(cell)
            return [cell, hid]

        def step_masked(input_n, mask_n, gamma, time_step, cell_previous, hid_previous, *args):
            cell, hid = step(input_n, gamma, time_step, cell_previous, hid_previous, *args)
            # Skip over any input with mask 0 by copying the previous
            # hidden state; proceed normally for any input with mask 1.
            cell = T.switch(mask_n, cell, cell_previous)
            hid = T.switch(mask_n, hid, hid_previous)
            return [cell, hid]
        if mask is not None:
            # mask is given as (batch_size, seq_len). Because scan iterates
            # over the first dimension, we dimshuffle to (seq_len, batch_size)
            # and add a broadcastable dimension
            mask = mask.dimshuffle(1, 0, 'x')
            sequences = [input, mask]
            step_fun = step_masked
        else:
            sequences = [input]
            step_fun = step

        # The per-time-step gamma and the time-step index are passed to the
        # step function as additional sequences
        time_steps = np.asarray(np.arange(self.input_shapes[0][1]), dtype=np.int32)
        sequences.extend([self.gamma, time_steps])

        ones = T.ones((num_batch, 1))
        if not isinstance(self.cell_init, Layer):
            # Dot against a 1s vector to repeat to shape (num_batch, num_units)
            cell_init = T.dot(ones, self.cell_init)
        if not isinstance(self.hid_init, Layer):
            # Dot against a 1s vector to repeat to shape (num_batch, num_units)
            hid_init = T.dot(ones, self.hid_init)

        # The hidden-to-hidden weight matrix is always used in step
        non_seqs = [W_hid_stacked]
        # The "peephole" weight matrices are only used when self.peepholes=True
        if self.peepholes:
            non_seqs += [self.W_cell_to_ingate,
                         self.W_cell_to_forgetgate,
                         self.W_cell_to_outgate]
        non_seqs += [self.running_mean, self.running_inv_std]

        if self.unroll_scan:
            # Retrieve the dimensionality of the incoming layer
            input_shape = self.input_shapes[0]
            # Explicitly unroll the recurrence instead of using scan
            cell_out, hid_out = unroll_scan(
                fn=step_fun,
                sequences=sequences,
                outputs_info=[cell_init, hid_init],
                go_backwards=self.backwards,
                non_sequences=non_seqs,
                n_steps=input_shape[1])
        else:
            # The scan op iterates over the first dimension of the input and
            # repeatedly applies the step function
            cell_out, hid_out = theano.scan(
                fn=step_fun,
                sequences=sequences,
                outputs_info=[cell_init, hid_init],
                go_backwards=self.backwards,
                truncate_gradient=self.gradient_steps,
                non_sequences=non_seqs,
                strict=True)[0]

        # When it is requested that we only return the final sequence step,
        # we need to slice it out immediately after scan is applied
        if self.only_return_final:
            hid_out = hid_out[-1]
        else:
            # dimshuffle back to (n_batch, n_time_steps, n_features)
            hid_out = hid_out.dimshuffle(1, 0, 2)
            # if scan is backward, reverse the output
            if self.backwards:
                hid_out = hid_out[:, ::-1]

        return hid_out
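A minimal usage sketch (not part of the original gist), assuming the class above is saved in a hypothetical module named batch_normalized_lstm_layer; the layer sizes and variable names below are illustrative only. Because __init__ sizes gamma, running_mean and running_inv_std from self.input_shapes[0][1], the InputLayer must declare a fixed sequence length, and unroll_scan=True is the configuration reported to work in the discussion below:

import lasagne
import theano

# Hypothetical module name for the file above
from batch_normalized_lstm_layer import BatchNormalizedLSTMLayer

BATCH_SIZE, SEQ_LEN, NUM_FEATURES, NUM_UNITS = 32, 20, 50, 100

# The sequence length must be a fixed integer (not None), because the layer
# allocates per-time-step BN statistics of shape (SEQ_LEN, 4*NUM_UNITS)
l_in = lasagne.layers.InputLayer((BATCH_SIZE, SEQ_LEN, NUM_FEATURES))
l_mask = lasagne.layers.InputLayer((BATCH_SIZE, SEQ_LEN))
l_lstm = BatchNormalizedLSTMLayer(l_in, NUM_UNITS, mask_input=l_mask,
                                  unroll_scan=True, only_return_final=True)

# Training graph: per-time-step batch statistics are used and the running
# averages are updated via their default_update when a function is compiled
train_out = lasagne.layers.get_output(l_lstm, deterministic=False)
# Inference graph: the stored running statistics are used instead
test_out = lasagne.layers.get_output(l_lstm, deterministic=True)

params = lasagne.layers.get_all_params(l_lstm, trainable=True)
predict_fn = theano.function([l_in.input_var, l_mask.input_var], test_out)

A training function would be built from train_out and params in the usual Lasagne way (a loss expression plus an update rule such as lasagne.updates.adam).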
Thanks for implementing it.
I'm getting an error when using this implementation:

Traceback (most recent call last):
  File "qann.py", line 660, in <module>
    qann.build_model()
  File "qann.py", line 334, in build_model
    network_output = lasagne.layers.get_output(network_start)
  File "/Applications/MAMP/htdocs/qann/env/lib/python2.7/site-packages/lasagne/layers/helper.py", line 191, in get_output
    all_outputs[layer] = layer.get_output_for(layer_inputs, **kwargs)
  File "/Applications/MAMP/htdocs/qann/models/batch_normalized_lstm_layer.py", line 236, in get_output_for
    strict=True)[0]
  File "/Applications/MAMP/htdocs/qann/env/lib/python2.7/site-packages/theano/scan_module/scan.py", line 557, in scan
    scan_seqs = [seq[:actual_n_steps] for seq in scan_seqs]
IndexError: failed to coerce slice entry of type TensorVariable to integer

Do you have any idea why it occurs?
@webeng, in the discussion at Lasagne/Lasagne#577 the author mentions that the code works when unroll_scan=True. We got that exact error when attempting to run the above code without unrolling the scan.
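One possible explanation for this error (an unverified reading of the code and the traceback, not from the original thread): the traceback shows theano.scan slicing every sequence to a common, symbolic length. In the code above, time_steps is a plain NumPy array, and a NumPy array cannot be sliced with a symbolic TensorVariable, which would produce exactly this IndexError. The unroll_scan path in lasagne.utils indexes the sequences with ordinary Python integers, so it never hits that slice, which is consistent with unroll_scan=True working. A hypothetical, untested change inside get_output_for would be to make the index symbolic:

    # Hypothetical, untested change: a symbolic index vector instead of a NumPy array
    time_steps = T.arange(self.input_shapes[0][1], dtype='int32')

Even then the scan path may need further attention, since scan is called with strict=True, which requires every shared variable used in step (for example self.alpha and the running-statistics clones) to be passed via non_sequences. Building the layer with unroll_scan=True, as suggested above, sidesteps all of this.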