@skaae
Created September 21, 2015 20:28
import numpy as np
import theano
import theano.tensor as T
import lasagne.nonlinearities as nonlinearities
import lasagne.init as init
from lasagne.utils import unroll_scan
from lasagne.layers import *
import lasagne.layers.helper as helper


class ListIndexLayer(Layer):
    def __init__(self, incoming, index, **kwargs):
        super(ListIndexLayer, self).__init__(incoming, **kwargs)
        self.index = index

    def get_output_for(self, input, **kwargs):
        return input[self.index]
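

# Usage sketch: a step layer whose get_output_for returns a list of outputs
# (e.g. [hid, cell]) can be split into one Lasagne layer per output with
# ListIndexLayer, as done in the __main__ demo below:
#   lstm_hid = ListIndexLayer(lstm_step, index=0)
#   lstm_cell = ListIndexLayer(lstm_step, index=1)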


class LSTMLayerStep(MergeLayer):
    def __init__(self, incoming, recurrent_hid_in, recurrent_cell_in,
                 num_units,
                 ingate=Gate(),
                 forgetgate=Gate(),
                 cell=Gate(W_cell=None, nonlinearity=nonlinearities.tanh),
                 outgate=Gate(),
                 nonlinearity=nonlinearities.tanh,
                 cell_init=init.Constant(0.),
                 hid_init=init.Constant(0.),
                 learn_init=False,
                 peepholes=True,
                 **kwargs):
        # This layer inherits from a MergeLayer because it takes three
        # inputs: the layer input, the previous hidden state and the
        # previous cell state.
        incomings = [incoming, recurrent_hid_in, recurrent_cell_in]

        # Initialize parent layer
        super(LSTMLayerStep, self).__init__(incomings, **kwargs)
        # If the provided nonlinearity is None, make it linear
        if nonlinearity is None:
            self.nonlinearity = nonlinearities.identity
        else:
            self.nonlinearity = nonlinearity

        self.learn_init = learn_init
        self.num_units = num_units
        self.peepholes = peepholes

        # Retrieve the dimensionality of the incoming layer
        input_shape = self.input_shapes[0]
        num_inputs = np.prod(input_shape[1:])

        def add_gate_params(gate, gate_name):
            """ Convenience function for adding layer parameters from a Gate
            instance. """
            return (self.add_param(gate.W_in, (num_inputs, num_units),
                                   name="W_in_to_{}".format(gate_name)),
                    self.add_param(gate.W_hid, (num_units, num_units),
                                   name="W_hid_to_{}".format(gate_name)),
                    self.add_param(gate.b, (num_units,),
                                   name="b_{}".format(gate_name),
                                   regularizable=False),
                    gate.nonlinearity)

        # Add in parameters from the supplied Gate instances
        (self.W_in_to_ingate, self.W_hid_to_ingate, self.b_ingate,
         self.nonlinearity_ingate) = add_gate_params(ingate, 'ingate')

        (self.W_in_to_forgetgate, self.W_hid_to_forgetgate, self.b_forgetgate,
         self.nonlinearity_forgetgate) = add_gate_params(forgetgate,
                                                         'forgetgate')

        (self.W_in_to_cell, self.W_hid_to_cell, self.b_cell,
         self.nonlinearity_cell) = add_gate_params(cell, 'cell')

        (self.W_in_to_outgate, self.W_hid_to_outgate, self.b_outgate,
         self.nonlinearity_outgate) = add_gate_params(outgate, 'outgate')

        # If peephole (cell to gate) connections were enabled, initialize
        # peephole connections. These are elementwise products with the cell
        # state, so they are represented as vectors.
        if self.peepholes:
            self.W_cell_to_ingate = self.add_param(
                ingate.W_cell, (num_units, ), name="W_cell_to_ingate")

            self.W_cell_to_forgetgate = self.add_param(
                forgetgate.W_cell, (num_units, ), name="W_cell_to_forgetgate")

            self.W_cell_to_outgate = self.add_param(
                outgate.W_cell, (num_units, ), name="W_cell_to_outgate")

        # Setup initial values for the cell and the hidden units
        if isinstance(cell_init, T.TensorVariable):
            if cell_init.ndim != 2:
                raise ValueError(
                    "When cell_init is provided as a TensorVariable, it should"
                    " have 2 dimensions and have shape (num_batch, num_units)")
            self.cell_init = cell_init
        else:
            self.cell_init = self.add_param(
                cell_init, (1, num_units), name="cell_init",
                trainable=learn_init, regularizable=False)

        if isinstance(hid_init, T.TensorVariable):
            if hid_init.ndim != 2:
                raise ValueError(
                    "When hid_init is provided as a TensorVariable, it should "
                    "have 2 dimensions and have shape (num_batch, num_units)")
            self.hid_init = hid_init
        else:
            self.hid_init = self.add_param(
                hid_init, (1, self.num_units), name="hid_init",
                trainable=learn_init, regularizable=False)

        # stack matrices
        # can this be moved inside get_output_for?
        self.W_in_stacked = T.concatenate(
            [self.W_in_to_ingate, self.W_in_to_forgetgate,
             self.W_in_to_cell, self.W_in_to_outgate], axis=1)

        # Same for hidden weight matrices
        self.W_hid_stacked = T.concatenate(
            [self.W_hid_to_ingate, self.W_hid_to_forgetgate,
             self.W_hid_to_cell, self.W_hid_to_outgate], axis=1)

        # Stack biases into a (4*num_units) vector
        self.b_stacked = T.concatenate(
            [self.b_ingate, self.b_forgetgate,
             self.b_cell, self.b_outgate], axis=0)
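
        # Note: the stacked matrices and bias order the gate blocks as
        # [ingate, forgetgate, cell, outgate]; slice_w below indexes into
        # these blocks, so the two orderings must stay in sync.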

    def get_output_shape_for(self, input_shapes):
        # One step produces the new hidden state: (num_batch, num_units)
        return input_shapes[0][0], self.num_units

    def slice_w(self, x, n):
        return x[:, n*self.num_units:(n+1)*self.num_units]

    def get_recurrent_inits(self, num_batch):
        ones = T.ones((num_batch, 1))
        if isinstance(self.cell_init, T.TensorVariable):
            cell_init = self.cell_init
        else:
            # Dot against a 1s vector to repeat to shape (num_batch, num_units)
            cell_init = T.dot(ones, self.cell_init)

        if isinstance(self.hid_init, T.TensorVariable):
            hid_init = self.hid_init
        else:
            # Dot against a 1s vector to repeat to shape (num_batch, num_units)
            hid_init = T.dot(ones, self.hid_init)

        hid_init.name = 'LSTM_HID_INIT'
        cell_init.name = 'LSTM_CELL_INIT'
        return hid_init, cell_init

    def get_output_for(self, inputs, **kwargs):
        x_previous, hid_previous, cell_previous = inputs
        num_batch, _ = x_previous.shape
        x_mix = T.dot(x_previous, self.W_in_stacked) + self.b_stacked

        # Calculate gates pre-activations and slice
        gates = x_mix + T.dot(hid_previous, self.W_hid_stacked)

        # Extract the pre-activation gate values
        ingate = self.slice_w(gates, 0)
        forgetgate = self.slice_w(gates, 1)
        cell_input = self.slice_w(gates, 2)
        outgate = self.slice_w(gates, 3)

        if self.peepholes:
            # Compute peephole connections
            ingate += cell_previous*self.W_cell_to_ingate
            forgetgate += cell_previous*self.W_cell_to_forgetgate

        # Apply nonlinearities
        ingate = self.nonlinearity_ingate(ingate)
        forgetgate = self.nonlinearity_forgetgate(forgetgate)
        cell_input = self.nonlinearity_cell(cell_input)
        outgate = self.nonlinearity_outgate(outgate)

        # Compute new cell value
        cell = forgetgate*cell_previous + ingate*cell_input

        if self.peepholes:
            outgate += cell*self.W_cell_to_outgate

        # Compute new hidden unit activation
        hid = outgate*self.nonlinearity(cell)
        # Return the hidden state first, then the cell state, matching the
        # hid-first ordering used by get_recurrent_inits and by the
        # ListIndexLayer indices in the demo below.
        return [hid, cell]
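

# For reference, one call to LSTMLayerStep.get_output_for computes roughly the
# standard peephole LSTM update:
#   i_t = sigma(x_t W_xi + h_{t-1} W_hi + w_ci * c_{t-1} + b_i)
#   f_t = sigma(x_t W_xf + h_{t-1} W_hf + w_cf * c_{t-1} + b_f)
#   c_t = f_t * c_{t-1} + i_t * tanh(x_t W_xc + h_{t-1} W_hc + b_c)
#   o_t = sigma(x_t W_xo + h_{t-1} W_ho + w_co * c_t + b_o)
#   h_t = o_t * tanh(c_t)
# where the actual nonlinearities come from the supplied Gate instances.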


class RecurrentContainerLayer(MergeLayer):
    def __init__(self, incoming,
                 step,
                 step_input,
                 recurrent_connections,
                 backwards=False,
                 gradient_steps=-1,
                 unroll_scan=False,
                 mask_input=None,
                 **kwargs):
        # This layer inherits from a MergeLayer, because it can have two
        # inputs - the layer input, and the mask. We will just provide the
        # layer input as incomings, unless a mask input was provided.
        incomings = [incoming]
        if mask_input is not None:
            incomings.append(mask_input)

        # Initialize parent layer
        super(RecurrentContainerLayer, self).__init__(incomings, **kwargs)
        self.step = step
        self.step_input = step_input
        self.backwards = backwards
        self.gradient_steps = gradient_steps
        self.unroll_scan = unroll_scan
        self.recurrent_connections = recurrent_connections

        if unroll_scan and gradient_steps != -1:
            raise ValueError(
                "Gradient steps must be -1 when unroll_scan is true.")

        # Retrieve the dimensionality of the incoming layer
        input_shape = self.input_shapes[0]

        if unroll_scan and input_shape[1] is None:
            raise ValueError("Input sequence length cannot be specified as "
                             "None when unroll_scan is True")

    def get_output_shape_for(self, input_shapes):
        input_shape = input_shapes[0]
        return input_shape[0], input_shape[1], self.step.output_shape[-1]

    def get_params(self, **tags):
        return helper.get_all_params(self.step, **tags)

    def get_output_for(self, inputs, **kwargs):
        # Retrieve the layer input
        input = inputs[0]
        # Retrieve the mask when it is supplied
        mask = inputs[1] if len(inputs) > 1 else None

        # Treat all dimensions after the second as flattened feature dimensions
        if input.ndim > 3:
            input = T.flatten(input, 3)

        # Because scan iterates over the first dimension we dimshuffle to
        # (n_time_steps, n_batch, n_features)
        input = input.dimshuffle(1, 0, 2)
        seq_len, num_batch, _ = input.shape

        num_inits = len(self.recurrent_connections)

        # recurrent_connections maps each recurrent input layer of the step
        # graph to the layer that produces its next value, e.g.
        #   {recurrent_hid_in: lstm_hid,
        #    recurrent_cell_in: lstm_cell}
        mappings = self.recurrent_connections.items()
        input_layers, output_layers = map(list, zip(*mappings))

        # collect inits (WILL HAVE TO BE REWRITTEN)
        layers = helper.get_all_layers(self.step)
        inits = [l.get_recurrent_inits(num_batch)
                 for l in layers if hasattr(l, 'get_recurrent_inits')]
        inits = [item for sublist in inits for item in sublist]

        input_layers = [self.step_input] + input_layers

        def step(input_n, *args):
            # scan passes the previous recurrent values first, followed by
            # the non-sequence arguments (weights); split them apart.
            args = list(args)
            previous_values = args[:num_inits]
            weights = args[num_inits:]
            # Feed the current input and previous recurrent values into the
            # step graph: [input_n] + previous_values lines up with
            # [self.step_input] + the recurrent input layers.
            step_map = {input_layer: input for
                        input_layer, input in
                        zip(input_layers, [input_n] + previous_values)}
            outputs = helper.get_output([self.step] + output_layers, step_map)
            return outputs

        def step_masked(input_n, mask_n, *args):
            previous_outputs = args
            outputs = step(input_n, *args)
            not_mask = 1 - mask_n
            masked_output = []
            for output, previous in zip(outputs, previous_outputs):
                masked_output += [output*mask_n + previous*not_mask]
            return masked_output
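
        # Masking behaviour: where mask_n is 0 for a sample at a time step,
        # the previous recurrent values are passed through unchanged, so
        # padded positions do not update the state.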

        if mask is not None:
            # mask is given as (batch_size, seq_len). Because scan iterates
            # over first dimension, we dimshuffle to (seq_len, batch_size) and
            # add a broadcastable dimension
            mask = mask.dimshuffle(1, 0, 'x')
            sequences = [input, mask]
            step_fun = step_masked
        else:
            sequences = [input]
            step_fun = step

        # find a better way to filter inits
        non_sequences = [p for p in self.get_params() if 'init' not in p.name]

        if self.unroll_scan:
            # Retrieve the dimensionality of the incoming layer
            input_shape = self.input_shapes[0]
            # Explicitly unroll the recurrence instead of using scan
            hid_out = unroll_scan(
                fn=step_fun,
                sequences=sequences,
                outputs_info=inits,
                go_backwards=self.backwards,
                non_sequences=[],
                n_steps=input_shape[1])[0]
        else:
            # Scan op iterates over first dimension of input and repeatedly
            # applies the step function
            hid_out = theano.scan(
                fn=step_fun,
                sequences=sequences,
                go_backwards=self.backwards,
                outputs_info=inits,
                non_sequences=None,  # non_sequences
                truncate_gradient=self.gradient_steps,
                strict=False)[0][0]

        # dimshuffle back to (n_batch, n_time_steps, n_features)
        hid_out = hid_out.dimshuffle(1, 0, 2)

        # if scan is backwards, reverse the output
        if self.backwards:
            hid_out = hid_out[:, ::-1, :]

        return hid_out
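

# Shape walk-through for RecurrentContainerLayer.get_output_for: the input
# (num_batch, seq_len, num_features) is dimshuffled to
# (seq_len, num_batch, num_features), scan produces
# (seq_len, num_batch, num_units), and the result is dimshuffled back to
# (num_batch, seq_len, num_units).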


if __name__ == '__main__':
    theano.config.compute_test_value = 'raise'
    sym_x = T.matrix('x')
    sym_hid = T.matrix('hid')
    sym_cell = T.matrix('cell')
    batch_size, num_inputs, num_units = 10, 12, 15
    seq_len = 5
    x_test = np.ones((batch_size, num_inputs), dtype='float32')
    hid_test = cell_test = np.ones((batch_size, num_units), dtype='float32')
    sym_x.tag.test_value = x_test
    sym_hid.tag.test_value = hid_test
    sym_cell.tag.test_value = cell_test

    lstm_x_in = InputLayer((None, num_inputs))
    recurrent_hid_in = InputLayer((None, num_units))
    recurrent_cell_in = InputLayer((None, num_units))
    lstm_step = LSTMLayerStep(incoming=lstm_x_in,
                              recurrent_hid_in=recurrent_hid_in,
                              recurrent_cell_in=recurrent_cell_in,
                              num_units=num_units)
    lstm_hid = ListIndexLayer(lstm_step, index=0)
    lstm_cell = ListIndexLayer(lstm_step, index=1)

    # Run a single LSTM step on the test values
    hid, cell = helper.get_output(lstm_step, {lstm_x_in: sym_x,
                                              recurrent_hid_in: sym_hid,
                                              recurrent_cell_in: sym_cell})
    hid_out = hid.eval({sym_x: x_test,
                        sym_hid: hid_test,
                        sym_cell: cell_test})
    cell_out = cell.eval({sym_x: x_test,
                          sym_hid: hid_test,
                          sym_cell: cell_test})
    print "hid_out", hid_out.shape
    print "cell_out", cell_out.shape

    # Wrap the step graph in a container that scans it over a sequence
    l_in = InputLayer((None, seq_len, num_inputs))
    connections = {recurrent_hid_in: lstm_hid,
                   recurrent_cell_in: lstm_cell}
    sym_x = T.tensor3()
    sym_x.tag.test_value = np.ones(
        (batch_size, seq_len, num_inputs), dtype='float32')
    l_lstm = RecurrentContainerLayer(
        incoming=l_in,
        step=lstm_hid,
        step_input=lstm_x_in,
        recurrent_connections=connections, unroll_scan=False)
    l_lstm_output = helper.get_output(l_lstm, {l_in: sym_x})
    print l_lstm_output.eval({sym_x: sym_x.tag.test_value}).shape
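
    # With the settings above, the container output should have shape
    # (batch_size, seq_len, num_units) = (10, 5, 15).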