@skaae
Created June 8, 2015 11:05
def get_output_for(self, input, mask=None, **kwargs):
    """
    Compute this layer's output function given a symbolic input variable.

    Parameters
    ----------
    input : theano.TensorType
        Symbolic input variable.
    mask : theano.TensorType
        Theano variable denoting whether each time step in each
        sequence in the batch is part of the sequence or not. If ``None``,
        it is assumed that all sequences are of the same length. If
        not all sequences are of the same length, it must be
        supplied as a matrix of shape ``(n_batch, n_time_steps)`` where
        ``mask[i, j] = 1`` when ``j <= (length of sequence i)`` and
        ``mask[i, j] = 0`` when ``j > (length of sequence i)``.
        (A small NumPy sketch of this mask format follows the function.)

    Returns
    -------
    layer_output : theano.TensorType
        Symbolic output variable.
    """
    # Treat all dimensions after the second as flattened feature dimensions
    if input.ndim > 3:
        input = input.reshape((input.shape[0], input.shape[1],
                               T.prod(input.shape[2:])))

    # Because scan iterates over the first dimension, we dimshuffle to
    # (n_time_steps, n_batch, n_features)
    input = input.dimshuffle(1, 0, 2)
    seq_len, num_batch, _ = input.shape

    # Because the input is given for all time steps, we can precompute
    # the input dotted with the stacked weight matrices before scanning.
    # W_in_stacked is (n_features, 4*num_units), so input becomes
    # (n_time_steps, n_batch, 4*num_units).
    if self.precompute_input:
        input = T.dot(input, self.W_in_stacked) + self.b_stacked

    # With precompute_input, input is now (n_time_steps, n_batch, 4*num_units).
    # We define a slicing function that extracts the input to each LSTM gate.
    def slice_w(x, n):
        return x[:, n*self.num_units:(n+1)*self.num_units]

    # Create the single recurrent computation step function.
    # input_n is the nth timestep of the input (already dotted with W when
    # precompute_input is enabled). The step function calculates the following:
    #
    # i_t = \sigma(W_{xi} x_t + W_{hi} h_{t-1} + W_{ci} c_{t-1} + b_i)
    # f_t = \sigma(W_{xf} x_t + W_{hf} h_{t-1} + W_{cf} c_{t-1} + b_f)
    # c_t = f_t c_{t-1} + i_t \tanh(W_{xc} x_t + W_{hc} h_{t-1} + b_c)
    # o_t = \sigma(W_{xo} x_t + W_{ho} h_{t-1} + W_{co} c_t + b_o)
    # h_t = o_t \tanh(c_t)
    # (A plain-NumPy sketch of this step follows the function for reference.)
    def step(input_n, cell_previous, hid_previous, W_hid_stacked,
             *args):
        if self.peepholes:
            [W_cell_to_ingate,
             W_cell_to_forgetgate,
             W_cell_to_outgate] = args[:3]

        if not self.precompute_input:
            W_in_stacked, b_stacked = args[-2:]
            input_n = T.dot(input_n, W_in_stacked) + b_stacked

        # Calculate gates pre-activations and slice
        gates = input_n + T.dot(hid_previous, W_hid_stacked)

        # clip gradients
        if self.grad_clipping is not False:
            gates = theano.gradient.grad_clip(
                gates, -self.grad_clipping, self.grad_clipping)

        # Extract the pre-activation gate values
        ingate = slice_w(gates, 0)
        forgetgate = slice_w(gates, 1)
        cell_input = slice_w(gates, 2)
        outgate = slice_w(gates, 3)

        if self.peepholes:
            # Compute peephole connections
            ingate += cell_previous*W_cell_to_ingate
            forgetgate += cell_previous*W_cell_to_forgetgate

        # Apply nonlinearities
        ingate = self.nonlinearity_ingate(ingate)
        forgetgate = self.nonlinearity_forgetgate(forgetgate)
        cell_input = self.nonlinearity_cell(cell_input)
        outgate = self.nonlinearity_outgate(outgate)

        # Compute new cell value
        cell = forgetgate*cell_previous + ingate*cell_input

        if self.peepholes:
            outgate += cell*W_cell_to_outgate

        # Compute new hidden unit activation
        hid = outgate*self.nonlinearity_out(cell)
        return [cell, hid]

    def step_masked(input_dot_W_n, mask_n, cell_previous, hid_previous,
                    W_hid_stacked, *args):
        cell, hid = step(input_dot_W_n, cell_previous, hid_previous,
                         W_hid_stacked, *args)

        # If mask is 0, use previous state until mask = 1 is found.
        # This propagates the layer initial state when moving backwards
        # until the end of the sequence is found.
        not_mask = 1 - mask_n
        cell = cell*mask_n + cell_previous*not_mask
        hid = hid*mask_n + hid_previous*not_mask

        return [cell, hid]
    if mask is not None:
        # mask is given as (batch_size, seq_len). Because scan iterates
        # over the first dimension, we dimshuffle to (seq_len, batch_size)
        # and add a broadcastable dimension
        mask = mask.dimshuffle(1, 0, 'x')
        sequences = [input, mask]
        step_fun = step_masked
    else:
        sequences = input
        step_fun = step

    ones = T.ones((num_batch, 1))
    if isinstance(self.cell_init, T.TensorVariable):
        cell_init = self.cell_init
    else:
        cell_init = T.dot(ones, self.cell_init)  # repeat num_batch times

    if isinstance(self.hid_init, T.TensorVariable):
        hid_init = self.hid_init
    else:
        hid_init = T.dot(ones, self.hid_init)  # repeat num_batch times

    non_seqs = [self.W_hid_stacked]
    if self.peepholes:
        non_seqs += [self.W_cell_to_ingate,
                     self.W_cell_to_forgetgate,
                     self.W_cell_to_outgate]
    if not self.precompute_input:
        non_seqs += [self.W_in_stacked, self.b_stacked]
    if self.unroll_scan:
        # Use a for loop to unroll the recursion.
        cell_out, hid_out = unroll_scan(
            fn=step_fun,
            sequences=sequences,
            outputs_info=[cell_init, hid_init],
            go_backwards=self.backwards,
            non_sequences=non_seqs,
            n_steps=self.input_shape[1])
    else:
        # The scan op iterates over the first dimension of the input and
        # repeatedly applies the step function
        cell_out, hid_out = theano.scan(
            fn=step_fun,
            sequences=sequences,
            outputs_info=[cell_init, hid_init],
            go_backwards=self.backwards,
            truncate_gradient=self.gradient_steps,
            non_sequences=non_seqs,
            strict=True)[0]

    # dimshuffle back to (n_batch, n_time_steps, n_features)
    hid_out = hid_out.dimshuffle(1, 0, 2)
    cell_out = cell_out.dimshuffle(1, 0, 2)

    # If the scan was run backwards, reverse the output along time
    if self.backwards:
        hid_out = hid_out[:, ::-1]
        cell_out = cell_out[:, ::-1]

    self.hid_out = hid_out
    self.cell_out = cell_out
    return hid_out
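
For illustration only (not part of the gist): a minimal NumPy sketch of how a mask in the ``(n_batch, n_time_steps)`` format described in the docstring could be built for a batch of padded, variable-length sequences. The sequence lengths below are made-up example values.

import numpy as np

# Hypothetical batch: three sequences of lengths 5, 3 and 2, padded to 5 steps.
lengths = np.array([5, 3, 2])
n_batch, n_time_steps = len(lengths), lengths.max()

# mask[i, j] = 1 while step j lies inside sequence i, 0 in the padded tail.
mask = (np.arange(n_time_steps)[None, :] < lengths[:, None]).astype('float32')
# array([[1., 1., 1., 1., 1.],
#        [1., 1., 1., 0., 0.],
#        [1., 1., 0., 0., 0.]], dtype=float32)

Such a matrix would then be supplied (as a symbolic variable) through the ``mask`` argument, so that padded steps simply carry the previous cell and hidden state forward.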
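
As a further reference (again an illustrative sketch, not the author's code), the step equations in the comment block can be written in plain NumPy as below. The function name, argument names, and the stacked-weight layout ``[ingate, forgetgate, cell, outgate]`` are assumptions chosen to mirror ``slice_w``, and sigmoid/tanh stand in for the configurable nonlinearities.

import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_step_reference(x_t, h_prev, c_prev, W_x, W_h, b, w_ci, w_cf, w_co):
    # Assumed shapes: W_x is (n_features, 4*num_units), W_h is
    # (num_units, 4*num_units), b is (4*num_units,), stacked in the
    # order [ingate, forgetgate, cell, outgate], matching slice_w above.
    num_units = h_prev.shape[-1]
    gates = x_t.dot(W_x) + h_prev.dot(W_h) + b
    s = lambda n: gates[..., n*num_units:(n+1)*num_units]

    i_t = sigmoid(s(0) + c_prev * w_ci)       # input gate (with peephole)
    f_t = sigmoid(s(1) + c_prev * w_cf)       # forget gate (with peephole)
    c_t = f_t * c_prev + i_t * np.tanh(s(2))  # new cell state
    o_t = sigmoid(s(3) + c_t * w_co)          # output gate peeks at new cell
    h_t = o_t * np.tanh(c_t)                  # new hidden state
    return c_t, h_t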