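# This method assumes the enclosing module imports `theano` and
# `theano.tensor as T`, and provides an `unroll_scan` helper
# (e.g. lasagne.utils.unroll_scan).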
def get_output_for(self, input, mask=None, **kwargs):
    """
    Compute this layer's output function given a symbolic input variable.

    Parameters
    ----------
    input : theano.TensorType
        Symbolic input variable.
    mask : theano.TensorType
        Theano variable denoting whether each time step in each
        sequence in the batch is part of the sequence or not. If ``None``,
        it is assumed that all sequences are of the same length. If
        not all sequences are of the same length, it must be
        supplied as a matrix of shape ``(n_batch, n_time_steps)`` where
        ``mask[i, j] = 1`` when ``j <= (length of sequence i)`` and
        ``mask[i, j] = 0`` when ``j > (length of sequence i)``.

    Returns
    -------
    layer_output : theano.TensorType
        Symbolic output variable.
    """
    # Treat all dimensions after the second as flattened feature dimensions
    if input.ndim > 3:
        input = input.reshape((input.shape[0], input.shape[1],
                               T.prod(input.shape[2:])))

    # Because scan iterates over the first dimension we dimshuffle to
    # (n_time_steps, n_batch, n_features)
    input = input.dimshuffle(1, 0, 2)
    seq_len, num_batch, _ = input.shape

    # Because the input is given for all time steps, we can precompute
    # the dot product of the inputs with the weight matrices before
    # scanning. W_in_stacked is (n_features, 4*num_units), so input
    # becomes (n_time_steps, n_batch, 4*num_units).
    if self.precompute_input:
        input = T.dot(input, self.W_in_stacked) + self.b_stacked

    # Define a slicing function that extracts the pre-activation input
    # to each of the four LSTM gates from a (n_batch, 4*num_units) array.
    def slice_w(x, n):
        return x[:, n*self.num_units:(n+1)*self.num_units]
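    # The stacked layout orders the gates as (ingate, forgetgate, cell,
    # outgate), so e.g. slice_w(gates, 1) returns columns
    # num_units:2*num_units, the forget gate pre-activation.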
    # Create single recurrent computation step function.
    # input_n is the n'th timestep of the input (already dotted with
    # W_in_stacked when precompute_input is enabled).
    # The step function calculates the following:
    #
    # i_t = \sigma(W_{xi}x_t + W_{hi}h_{t-1} + W_{ci}c_{t-1} + b_i)
    # f_t = \sigma(W_{xf}x_t + W_{hf}h_{t-1} + W_{cf}c_{t-1} + b_f)
    # c_t = f_t c_{t-1} + i_t \tanh(W_{xc}x_t + W_{hc}h_{t-1} + b_c)
    # o_t = \sigma(W_{xo}x_t + W_{ho}h_{t-1} + W_{co}c_t + b_o)
    # h_t = o_t \tanh(c_t)
    def step(input_n, cell_previous, hid_previous, W_hid_stacked,
             *args):
        if self.peepholes:
            [W_cell_to_ingate,
             W_cell_to_forgetgate,
             W_cell_to_outgate] = args[:3]

        if not self.precompute_input:
            W_in_stacked, b_stacked = args[-2:]
            input_n = T.dot(input_n, W_in_stacked) + b_stacked

        # Calculate the gate pre-activations, then slice them apart
        gates = input_n + T.dot(hid_previous, W_hid_stacked)

        # Clip gradients
        if self.grad_clipping is not False:
            gates = theano.gradient.grad_clip(
                gates, -self.grad_clipping, self.grad_clipping)

        # Extract the pre-activation gate values
        ingate = slice_w(gates, 0)
        forgetgate = slice_w(gates, 1)
        cell_input = slice_w(gates, 2)
        outgate = slice_w(gates, 3)

        if self.peepholes:
            # Compute peephole connections
            ingate += cell_previous*W_cell_to_ingate
            forgetgate += cell_previous*W_cell_to_forgetgate

        # Apply nonlinearities
        ingate = self.nonlinearity_ingate(ingate)
        forgetgate = self.nonlinearity_forgetgate(forgetgate)
        cell_input = self.nonlinearity_cell(cell_input)
        outgate = self.nonlinearity_outgate(outgate)

        # Compute new cell value
        cell = forgetgate*cell_previous + ingate*cell_input

        if self.peepholes:
            outgate += cell*W_cell_to_outgate

        # Compute new hidden unit activation
        hid = outgate*self.nonlinearity_out(cell)
        return [cell, hid]
    def step_masked(input_n, mask_n, cell_previous, hid_previous,
                    W_hid_stacked, *args):
        cell, hid = step(input_n, cell_previous, hid_previous,
                         W_hid_stacked, *args)

        # If mask is 0, use the previous state until mask = 1 is found.
        # This propagates the layer's initial state when scanning backwards
        # until the end of the sequence is reached.
        not_mask = 1 - mask_n
        cell = cell*mask_n + cell_previous*not_mask
        hid = hid*mask_n + hid_previous*not_mask
        return [cell, hid]
    if mask is not None:
        # mask is given as (n_batch, n_time_steps). Because scan iterates
        # over the first dimension, we dimshuffle to (n_time_steps, n_batch)
        # and add a broadcastable dimension.
        mask = mask.dimshuffle(1, 0, 'x')
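        # After the dimshuffle, mask is (n_time_steps, n_batch, 1); the
        # trailing broadcastable axis lets each mask_n slice multiply the
        # (n_batch, num_units) cell and hidden states in step_masked.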
        sequences = [input, mask]
        step_fun = step_masked
    else:
        sequences = input
        step_fun = step
    ones = T.ones((num_batch, 1))
    if isinstance(self.cell_init, T.TensorVariable):
        cell_init = self.cell_init
    else:
        cell_init = T.dot(ones, self.cell_init)  # repeat num_batch times

    if isinstance(self.hid_init, T.TensorVariable):
        hid_init = self.hid_init
    else:
        hid_init = T.dot(ones, self.hid_init)  # repeat num_batch times
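    # The dot with `ones` tiles the learned (1, num_units) initial state
    # into a (num_batch, num_units) matrix, one copy per batch element.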
    non_seqs = [self.W_hid_stacked]
    if self.peepholes:
        non_seqs += [self.W_cell_to_ingate,
                     self.W_cell_to_forgetgate,
                     self.W_cell_to_outgate]
    if not self.precompute_input:
        non_seqs += [self.W_in_stacked, self.b_stacked]
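    # Passing these as explicit non-sequences is required because the
    # scan call below uses strict=True, which forbids implicit captures
    # of shared variables inside the step function.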
    if self.unroll_scan:
        # Use an explicit for-loop to unroll the recurrence instead of scan
        cell_out, hid_out = unroll_scan(
            fn=step_fun,
            sequences=sequences,
            outputs_info=[cell_init, hid_init],
            go_backwards=self.backwards,
            non_sequences=non_seqs,
            n_steps=self.input_shape[1])
    else:
        # Scan op iterates over the first dimension of input and repeatedly
        # applies the step function
        cell_out, hid_out = theano.scan(
            fn=step_fun,
            sequences=sequences,
            outputs_info=[cell_init, hid_init],
            go_backwards=self.backwards,
            truncate_gradient=self.gradient_steps,
            non_sequences=non_seqs,
            strict=True)[0]
    # dimshuffle back to (n_batch, n_time_steps, n_features)
    hid_out = hid_out.dimshuffle(1, 0, 2)
    cell_out = cell_out.dimshuffle(1, 0, 2)

    # If the scan ran backwards, reverse the output along the time axis
    if self.backwards:
        hid_out = hid_out[:, ::-1]
        cell_out = cell_out[:, ::-1]

    self.hid_out = hid_out
    self.cell_out = cell_out
    return hid_out
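
# ---------------------------------------------------------------------
# Usage sketch: a minimal illustration of how this method might be
# driven, assuming it is defined on a Lasagne-style LSTMLayer instance
# `l_lstm` (the variable names here are placeholders, not part of the
# gist):
#
#   import theano
#   import theano.tensor as T
#
#   x = T.tensor3('x')    # (n_batch, n_time_steps, n_features)
#   m = T.matrix('mask')  # (n_batch, n_time_steps), 1 marks real steps
#
#   h = l_lstm.get_output_for(x, mask=m)  # (n_batch, n_time_steps, num_units)
#   f = theano.function([x, m], h)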