@jjallaire
Created October 26, 2018 13:15
Custom Multiplicative LSTM Layer for R Keras
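The gist has two pieces: a Python module defining the custom Keras layer (it must be saved as multiplicative_lstm.py, since that is the module name the R wrapper imports below), and an R function that exposes the layer to the keras R package.

multiplicative_lstm.py: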
from __future__ import absolute_import

import numpy as np

from keras import backend as K
from keras import activations
from keras import initializers
from keras import regularizers
from keras import constraints
from keras.engine import Layer
from keras.engine import InputSpec
from keras.legacy import interfaces
from keras.layers import Recurrent

__all__ = ['MultiplicativeLSTM']
class MultiplicativeLSTM(Recurrent):
    """Multiplicative Long Short-Term Memory unit - https://arxiv.org/pdf/1609.07959.pdf

    # Arguments
        units: Positive integer, dimensionality of the output space.
        activation: Activation function to use
            (see [activations](../activations.md)).
            If you pass None, no activation is applied
            (ie. "linear" activation: `a(x) = x`).
        recurrent_activation: Activation function to use
            for the recurrent step
            (see [activations](../activations.md)).
        use_bias: Boolean, whether the layer uses a bias vector.
        kernel_initializer: Initializer for the `kernel` weights matrix,
            used for the linear transformation of the inputs
            (see [initializers](../initializers.md)).
        recurrent_initializer: Initializer for the `recurrent_kernel`
            weights matrix, used for the linear transformation
            of the recurrent state
            (see [initializers](../initializers.md)).
        bias_initializer: Initializer for the bias vector
            (see [initializers](../initializers.md)).
        unit_forget_bias: Boolean.
            If True, add 1 to the bias of the forget gate at initialization.
            Setting it to true will also force `bias_initializer="zeros"`.
            This is recommended in [Jozefowicz et al.](http://www.jmlr.org/proceedings/papers/v37/jozefowicz15.pdf)
        kernel_regularizer: Regularizer function applied to
            the `kernel` weights matrix
            (see [regularizer](../regularizers.md)).
        recurrent_regularizer: Regularizer function applied to
            the `recurrent_kernel` weights matrix
            (see [regularizer](../regularizers.md)).
        bias_regularizer: Regularizer function applied to the bias vector
            (see [regularizer](../regularizers.md)).
        activity_regularizer: Regularizer function applied to
            the output of the layer (its "activation")
            (see [regularizer](../regularizers.md)).
        kernel_constraint: Constraint function applied to
            the `kernel` weights matrix
            (see [constraints](../constraints.md)).
        recurrent_constraint: Constraint function applied to
            the `recurrent_kernel` weights matrix
            (see [constraints](../constraints.md)).
        bias_constraint: Constraint function applied to the bias vector
            (see [constraints](../constraints.md)).
        dropout: Float between 0 and 1.
            Fraction of the units to drop for
            the linear transformation of the inputs.
        recurrent_dropout: Float between 0 and 1.
            Fraction of the units to drop for
            the linear transformation of the recurrent state.

    # References
        - [Long short-term memory](http://www.bioinf.jku.at/publications/older/2604.pdf) (original 1997 paper)
        - [Learning to forget: Continual prediction with LSTM](http://www.mitpressjournals.org/doi/pdf/10.1162/089976600300015015)
        - [Supervised sequence labeling with recurrent neural networks](http://www.cs.toronto.edu/~graves/preprint.pdf)
        - [A Theoretically Grounded Application of Dropout in Recurrent Neural Networks](http://arxiv.org/abs/1512.05287)
    """
    @interfaces.legacy_recurrent_support
    def __init__(self, units,
                 activation='tanh',
                 recurrent_activation='hard_sigmoid',
                 use_bias=True,
                 kernel_initializer='glorot_uniform',
                 recurrent_initializer='orthogonal',
                 bias_initializer='zeros',
                 unit_forget_bias=True,
                 kernel_regularizer=None,
                 recurrent_regularizer=None,
                 bias_regularizer=None,
                 activity_regularizer=None,
                 kernel_constraint=None,
                 recurrent_constraint=None,
                 bias_constraint=None,
                 dropout=0.,
                 recurrent_dropout=0.,
                 implementation=1,
                 **kwargs):
        super(MultiplicativeLSTM, self).__init__(**kwargs)
        self.units = units
        self.activation = activations.get(activation)
        self.recurrent_activation = activations.get(recurrent_activation)
        self.use_bias = use_bias
        self.kernel_initializer = initializers.get(kernel_initializer)
        self.recurrent_initializer = initializers.get(recurrent_initializer)
        self.bias_initializer = initializers.get(bias_initializer)
        self.unit_forget_bias = unit_forget_bias
        self.kernel_regularizer = regularizers.get(kernel_regularizer)
        self.recurrent_regularizer = regularizers.get(recurrent_regularizer)
        self.bias_regularizer = regularizers.get(bias_regularizer)
        self.activity_regularizer = regularizers.get(activity_regularizer)
        self.kernel_constraint = constraints.get(kernel_constraint)
        self.recurrent_constraint = constraints.get(recurrent_constraint)
        self.bias_constraint = constraints.get(bias_constraint)
        self.dropout = min(1., max(0., dropout))
        self.recurrent_dropout = min(1., max(0., recurrent_dropout))
        self.state_spec = [InputSpec(shape=(None, self.units)),
                           InputSpec(shape=(None, self.units))]
        self.state_size = (self.units, self.units)
        self.implementation = implementation
    def build(self, input_shape):
        if isinstance(input_shape, list):
            input_shape = input_shape[0]

        batch_size = input_shape[0] if self.stateful else None
        self.input_dim = input_shape[2]
        self.input_spec[0] = InputSpec(shape=(batch_size, None, self.input_dim))

        self.states = [None, None]
        if self.stateful:
            self.reset_states()

        # Five weight blocks: input gate (i), forget gate (f), cell candidate (c),
        # output gate (o), plus the multiplicative intermediate state (m).
        self.kernel = self.add_weight(shape=(self.input_dim, self.units * 5),
                                      name='kernel',
                                      initializer=self.kernel_initializer,
                                      regularizer=self.kernel_regularizer,
                                      constraint=self.kernel_constraint)
        self.recurrent_kernel = self.add_weight(
            shape=(self.units, self.units * 5),
            name='recurrent_kernel',
            initializer=self.recurrent_initializer,
            regularizer=self.recurrent_regularizer,
            constraint=self.recurrent_constraint)

        if self.use_bias:
            if self.unit_forget_bias:
                def bias_initializer(shape, *args, **kwargs):
                    return K.concatenate([
                        self.bias_initializer((self.units,), *args, **kwargs),
                        initializers.Ones()((self.units,), *args, **kwargs),
                        self.bias_initializer((self.units * 3,), *args, **kwargs),
                    ])
            else:
                bias_initializer = self.bias_initializer
            self.bias = self.add_weight(shape=(self.units * 5,),
                                        name='bias',
                                        initializer=bias_initializer,
                                        regularizer=self.bias_regularizer,
                                        constraint=self.bias_constraint)
        else:
            self.bias = None

        self.kernel_i = self.kernel[:, :self.units]
        self.kernel_f = self.kernel[:, self.units: self.units * 2]
        self.kernel_c = self.kernel[:, self.units * 2: self.units * 3]
        self.kernel_o = self.kernel[:, self.units * 3: self.units * 4]
        self.kernel_m = self.kernel[:, self.units * 4:]

        self.recurrent_kernel_i = self.recurrent_kernel[:, :self.units]
        self.recurrent_kernel_f = self.recurrent_kernel[:, self.units: self.units * 2]
        self.recurrent_kernel_c = self.recurrent_kernel[:, self.units * 2: self.units * 3]
        self.recurrent_kernel_o = self.recurrent_kernel[:, self.units * 3: self.units * 4]
        self.recurrent_kernel_m = self.recurrent_kernel[:, self.units * 4:]

        if self.use_bias:
            self.bias_i = self.bias[:self.units]
            self.bias_f = self.bias[self.units: self.units * 2]
            self.bias_c = self.bias[self.units * 2: self.units * 3]
            self.bias_o = self.bias[self.units * 3: self.units * 4]
            self.bias_m = self.bias[self.units * 4:]
        else:
            self.bias_i = None
            self.bias_f = None
            self.bias_c = None
            self.bias_o = None
            self.bias_m = None
        self.built = True
    def preprocess_input(self, inputs, training=None):
        return inputs
    def get_constants(self, inputs, training=None):
        constants = []
        if self.implementation != 0 and 0 < self.dropout < 1:
            input_shape = K.int_shape(inputs)
            input_dim = input_shape[-1]
            ones = K.ones_like(K.reshape(inputs[:, 0, 0], (-1, 1)))
            ones = K.tile(ones, (1, int(input_dim)))

            def dropped_inputs():
                return K.dropout(ones, self.dropout)

            dp_mask = [K.in_train_phase(dropped_inputs,
                                        ones,
                                        training=training) for _ in range(5)]
            constants.append(dp_mask)
        else:
            constants.append([K.cast_to_floatx(1.) for _ in range(5)])

        if 0 < self.recurrent_dropout < 1:
            ones = K.ones_like(K.reshape(inputs[:, 0, 0], (-1, 1)))
            ones = K.tile(ones, (1, self.units))

            def dropped_inputs():
                return K.dropout(ones, self.recurrent_dropout)

            rec_dp_mask = [K.in_train_phase(dropped_inputs,
                                            ones,
                                            training=training) for _ in range(5)]
            constants.append(rec_dp_mask)
        else:
            constants.append([K.cast_to_floatx(1.) for _ in range(5)])
        return constants
    def step(self, inputs, states):
        h_tm1 = states[0]
        c_tm1 = states[1]
        dp_mask = states[2]
        rec_dp_mask = states[3]

        # Multiplicative LSTM (Krause et al., 2016): an intermediate state
        #   m_t = (W_m x_t) * (U_m h_{t-1})   (elementwise product)
        # takes the place of h_{t-1} in the standard LSTM gate equations.
        if self.implementation == 2:
            z = K.dot(inputs * dp_mask[0], self.kernel)
            z += z * K.dot(h_tm1 * rec_dp_mask[0], self.recurrent_kernel)  # applies m instead of h_tm1 to z
            if self.use_bias:
                z = K.bias_add(z, self.bias)

            z0 = z[:, :self.units]
            z1 = z[:, self.units: 2 * self.units]
            z2 = z[:, 2 * self.units: 3 * self.units]
            z3 = z[:, 3 * self.units: 4 * self.units]
            z4 = z[:, 4 * self.units:]  # just elementwise multiplication, no activation functions

            i = self.recurrent_activation(z0)
            f = self.recurrent_activation(z1)
            c = f * c_tm1 + i * self.activation(z2)
            o = self.recurrent_activation(z3)
        else:
            if self.implementation == 1:
                x_i = K.dot(inputs * dp_mask[0], self.kernel_i)
                x_f = K.dot(inputs * dp_mask[1], self.kernel_f)
                x_c = K.dot(inputs * dp_mask[2], self.kernel_c)
                x_o = K.dot(inputs * dp_mask[3], self.kernel_o)
                x_m = K.dot(inputs * dp_mask[4], self.kernel_m)
                if self.use_bias:  # bias slices are None when use_bias=False
                    x_i = K.bias_add(x_i, self.bias_i)
                    x_f = K.bias_add(x_f, self.bias_f)
                    x_c = K.bias_add(x_c, self.bias_c)
                    x_o = K.bias_add(x_o, self.bias_o)
                    x_m = K.bias_add(x_m, self.bias_m)
            else:
                raise ValueError('Unknown `implementation` mode.')
            m = x_m * K.dot(h_tm1 * rec_dp_mask[4], self.recurrent_kernel_m)  # elementwise multiplication m
            i = self.recurrent_activation(x_i + K.dot(m * rec_dp_mask[0], self.recurrent_kernel_i))
            f = self.recurrent_activation(x_f + K.dot(m * rec_dp_mask[1], self.recurrent_kernel_f))
            c = f * c_tm1 + i * self.activation(x_c + K.dot(m * rec_dp_mask[2], self.recurrent_kernel_c))
            o = self.recurrent_activation(x_o + K.dot(m * rec_dp_mask[3], self.recurrent_kernel_o))
        h = o * self.activation(c)
        if 0 < self.dropout + self.recurrent_dropout:
            h._uses_learning_phase = True
        return h, [h, c]
    def get_config(self):
        config = {'units': self.units,
                  'activation': activations.serialize(self.activation),
                  'recurrent_activation': activations.serialize(self.recurrent_activation),
                  'use_bias': self.use_bias,
                  'kernel_initializer': initializers.serialize(self.kernel_initializer),
                  'recurrent_initializer': initializers.serialize(self.recurrent_initializer),
                  'bias_initializer': initializers.serialize(self.bias_initializer),
                  'unit_forget_bias': self.unit_forget_bias,
                  'kernel_regularizer': regularizers.serialize(self.kernel_regularizer),
                  'recurrent_regularizer': regularizers.serialize(self.recurrent_regularizer),
                  'bias_regularizer': regularizers.serialize(self.bias_regularizer),
                  'activity_regularizer': regularizers.serialize(self.activity_regularizer),
                  'kernel_constraint': constraints.serialize(self.kernel_constraint),
                  'recurrent_constraint': constraints.serialize(self.recurrent_constraint),
                  'bias_constraint': constraints.serialize(self.bias_constraint),
                  'dropout': self.dropout,
                  'recurrent_dropout': self.recurrent_dropout}
        base_config = super(MultiplicativeLSTM, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
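
With the Python module in place, the R wrapper below imports it via reticulate and hands the class to keras::create_layer(), largely mirroring the argument list of the package's built-in layer_lstm():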
library(keras)
library(reticulate)

layer_multiplicative_lstm <- function(
  object, units, activation = "tanh", recurrent_activation = "hard_sigmoid", use_bias = TRUE,
  return_sequences = FALSE, return_state = FALSE, go_backwards = FALSE, stateful = FALSE, unroll = FALSE,
  kernel_initializer = "glorot_uniform", recurrent_initializer = "orthogonal", bias_initializer = "zeros",
  unit_forget_bias = TRUE, kernel_regularizer = NULL, recurrent_regularizer = NULL, bias_regularizer = NULL,
  activity_regularizer = NULL, kernel_constraint = NULL, recurrent_constraint = NULL, bias_constraint = NULL,
  dropout = 0.0, recurrent_dropout = 0.0, input_shape = NULL, batch_input_shape = NULL, batch_size = NULL,
  dtype = NULL, name = NULL, trainable = NULL, weights = NULL) {

  # import multiplicative_lstm.py (from the working directory by default)
  # and wrap the Python class as a standard Keras layer function
  mlstm <- reticulate::import_from_path("multiplicative_lstm")
  create_layer(mlstm$MultiplicativeLSTM, object, list(
    units = as.integer(units),
    activation = activation,
    recurrent_activation = recurrent_activation,
    use_bias = use_bias,
    return_sequences = return_sequences,
    return_state = return_state,
    go_backwards = go_backwards,
    stateful = stateful,
    unroll = unroll,
    kernel_initializer = kernel_initializer,
    recurrent_initializer = recurrent_initializer,
    bias_initializer = bias_initializer,
    unit_forget_bias = unit_forget_bias,
    kernel_regularizer = kernel_regularizer,
    recurrent_regularizer = recurrent_regularizer,
    bias_regularizer = bias_regularizer,
    activity_regularizer = activity_regularizer,
    kernel_constraint = kernel_constraint,
    recurrent_constraint = recurrent_constraint,
    bias_constraint = bias_constraint,
    dropout = dropout,
    recurrent_dropout = recurrent_dropout,
    input_shape = keras:::normalize_shape(input_shape),
    batch_input_shape = keras:::normalize_shape(batch_input_shape),
    batch_size = keras:::as_nullable_integer(batch_size),
    dtype = dtype,
    name = name,
    trainable = trainable,
    weights = weights
  ))
}
# quick check that the layer can be instantiated
layer_multiplicative_lstm(units = 10)
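
For context, a minimal sketch of the wrapper inside a model. This isn't part of the gist: the sequence length (100), feature count (32), unit counts, and loss are illustrative assumptions.

# hypothetical binary classifier over sequences of 100 steps x 32 features
model <- keras_model_sequential() %>%
  layer_multiplicative_lstm(units = 64, input_shape = c(100, 32),
                            dropout = 0.2, recurrent_dropout = 0.2) %>%
  layer_dense(units = 1, activation = "sigmoid")

model %>% compile(
  optimizer = "adam",
  loss = "binary_crossentropy",
  metrics = "accuracy"
)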