Custom Multiplicative LSTM Layer for R Keras
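
The gist provides two files: a Python implementation of the multiplicative LSTM of Krause et al. (https://arxiv.org/abs/1609.07959) as a Keras Recurrent layer, and an R wrapper that exposes it to the keras R package as layer_multiplicative_lstm().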
from __future__ import absolute_import

import numpy as np

from keras import backend as K
from keras import activations
from keras import initializers
from keras import regularizers
from keras import constraints
from keras.engine import Layer
from keras.engine import InputSpec
from keras.legacy import interfaces
from keras.layers import Recurrent

__all__ = ['MultiplicativeLSTM']

class MultiplicativeLSTM(Recurrent):
    """Multiplicative Long Short-Term Memory unit - https://arxiv.org/pdf/1609.07959.pdf

    # Arguments
        units: Positive integer, dimensionality of the output space.
        activation: Activation function to use
            (see [activations](../activations.md)).
            If you pass None, no activation is applied
            (i.e. "linear" activation: `a(x) = x`).
        recurrent_activation: Activation function to use
            for the recurrent step
            (see [activations](../activations.md)).
        use_bias: Boolean, whether the layer uses a bias vector.
        kernel_initializer: Initializer for the `kernel` weights matrix,
            used for the linear transformation of the inputs
            (see [initializers](../initializers.md)).
        recurrent_initializer: Initializer for the `recurrent_kernel`
            weights matrix,
            used for the linear transformation of the recurrent state
            (see [initializers](../initializers.md)).
        bias_initializer: Initializer for the bias vector
            (see [initializers](../initializers.md)).
        unit_forget_bias: Boolean.
            If True, add 1 to the bias of the forget gate at initialization.
            Setting it to True will also force `bias_initializer="zeros"`.
            This is recommended in [Jozefowicz et al.](http://www.jmlr.org/proceedings/papers/v37/jozefowicz15.pdf)
        kernel_regularizer: Regularizer function applied to
            the `kernel` weights matrix
            (see [regularizer](../regularizers.md)).
        recurrent_regularizer: Regularizer function applied to
            the `recurrent_kernel` weights matrix
            (see [regularizer](../regularizers.md)).
        bias_regularizer: Regularizer function applied to the bias vector
            (see [regularizer](../regularizers.md)).
        activity_regularizer: Regularizer function applied to
            the output of the layer (its "activation")
            (see [regularizer](../regularizers.md)).
        kernel_constraint: Constraint function applied to
            the `kernel` weights matrix
            (see [constraints](../constraints.md)).
        recurrent_constraint: Constraint function applied to
            the `recurrent_kernel` weights matrix
            (see [constraints](../constraints.md)).
        bias_constraint: Constraint function applied to the bias vector
            (see [constraints](../constraints.md)).
        dropout: Float between 0 and 1.
            Fraction of the units to drop for
            the linear transformation of the inputs.
        recurrent_dropout: Float between 0 and 1.
            Fraction of the units to drop for
            the linear transformation of the recurrent state.
        implementation: one of {1, 2}.
            Mode 1 computes the gate projections as separate, smaller
            dot products; mode 2 batches them into fewer, larger operations.

    # References
        - [Long short-term memory](http://www.bioinf.jku.at/publications/older/2604.pdf) (original 1997 paper)
        - [Learning to forget: Continual prediction with LSTM](http://www.mitpressjournals.org/doi/pdf/10.1162/089976600300015015)
        - [Multiplicative LSTM for sequence modelling](https://arxiv.org/abs/1609.07959)
        - [Supervised sequence labeling with recurrent neural networks](http://www.cs.toronto.edu/~graves/preprint.pdf)
        - [A Theoretically Grounded Application of Dropout in Recurrent Neural Networks](http://arxiv.org/abs/1512.05287)
    """

    @interfaces.legacy_recurrent_support
    def __init__(self, units,
                 activation='tanh',
                 recurrent_activation='hard_sigmoid',
                 use_bias=True,
                 kernel_initializer='glorot_uniform',
                 recurrent_initializer='orthogonal',
                 bias_initializer='zeros',
                 unit_forget_bias=True,
                 kernel_regularizer=None,
                 recurrent_regularizer=None,
                 bias_regularizer=None,
                 activity_regularizer=None,
                 kernel_constraint=None,
                 recurrent_constraint=None,
                 bias_constraint=None,
                 dropout=0.,
                 recurrent_dropout=0.,
                 implementation=1,
                 **kwargs):
        super(MultiplicativeLSTM, self).__init__(**kwargs)
        self.units = units
        self.activation = activations.get(activation)
        self.recurrent_activation = activations.get(recurrent_activation)
        self.use_bias = use_bias

        self.kernel_initializer = initializers.get(kernel_initializer)
        self.recurrent_initializer = initializers.get(recurrent_initializer)
        self.bias_initializer = initializers.get(bias_initializer)
        self.unit_forget_bias = unit_forget_bias

        self.kernel_regularizer = regularizers.get(kernel_regularizer)
        self.recurrent_regularizer = regularizers.get(recurrent_regularizer)
        self.bias_regularizer = regularizers.get(bias_regularizer)
        self.activity_regularizer = regularizers.get(activity_regularizer)

        self.kernel_constraint = constraints.get(kernel_constraint)
        self.recurrent_constraint = constraints.get(recurrent_constraint)
        self.bias_constraint = constraints.get(bias_constraint)

        self.dropout = min(1., max(0., dropout))
        self.recurrent_dropout = min(1., max(0., recurrent_dropout))
        self.state_spec = [InputSpec(shape=(None, self.units)),
                           InputSpec(shape=(None, self.units))]
        self.state_size = (self.units, self.units)
        self.implementation = implementation

    def build(self, input_shape):
        if isinstance(input_shape, list):
            input_shape = input_shape[0]

        batch_size = input_shape[0] if self.stateful else None
        self.input_dim = input_shape[2]
        self.input_spec[0] = InputSpec(shape=(batch_size, None, self.input_dim))

        self.states = [None, None]
        if self.stateful:
            self.reset_states()

        # Five weight blocks instead of the standard LSTM's four: input (i),
        # forget (f), cell (c) and output (o) gates plus the multiplicative term (m).
        self.kernel = self.add_weight(shape=(self.input_dim, self.units * 5),
                                      name='kernel',
                                      initializer=self.kernel_initializer,
                                      regularizer=self.kernel_regularizer,
                                      constraint=self.kernel_constraint)
        self.recurrent_kernel = self.add_weight(
            shape=(self.units, self.units * 5),
            name='recurrent_kernel',
            initializer=self.recurrent_initializer,
            regularizer=self.recurrent_regularizer,
            constraint=self.recurrent_constraint)

        if self.use_bias:
            if self.unit_forget_bias:
                def bias_initializer(shape, *args, **kwargs):
                    return K.concatenate([
                        self.bias_initializer((self.units,), *args, **kwargs),
                        initializers.Ones()((self.units,), *args, **kwargs),
                        self.bias_initializer((self.units * 3,), *args, **kwargs),
                    ])
            else:
                bias_initializer = self.bias_initializer
            self.bias = self.add_weight(shape=(self.units * 5,),
                                        name='bias',
                                        initializer=bias_initializer,
                                        regularizer=self.bias_regularizer,
                                        constraint=self.bias_constraint)
        else:
            self.bias = None

        self.kernel_i = self.kernel[:, :self.units]
        self.kernel_f = self.kernel[:, self.units: self.units * 2]
        self.kernel_c = self.kernel[:, self.units * 2: self.units * 3]
        self.kernel_o = self.kernel[:, self.units * 3: self.units * 4]
        self.kernel_m = self.kernel[:, self.units * 4:]

        self.recurrent_kernel_i = self.recurrent_kernel[:, :self.units]
        self.recurrent_kernel_f = self.recurrent_kernel[:, self.units: self.units * 2]
        self.recurrent_kernel_c = self.recurrent_kernel[:, self.units * 2: self.units * 3]
        self.recurrent_kernel_o = self.recurrent_kernel[:, self.units * 3: self.units * 4]
        self.recurrent_kernel_m = self.recurrent_kernel[:, self.units * 4:]

        if self.use_bias:
            self.bias_i = self.bias[:self.units]
            self.bias_f = self.bias[self.units: self.units * 2]
            self.bias_c = self.bias[self.units * 2: self.units * 3]
            self.bias_o = self.bias[self.units * 3: self.units * 4]
            self.bias_m = self.bias[self.units * 4:]
        else:
            self.bias_i = None
            self.bias_f = None
            self.bias_c = None
            self.bias_o = None
            self.bias_m = None
        self.built = True

    def preprocess_input(self, inputs, training=None):
        return inputs

    def get_constants(self, inputs, training=None):
        # Build one dropout mask per weight block (i, f, c, o, m), for both the
        # input transformation and the recurrent transformation.
        constants = []
        if self.implementation != 0 and 0 < self.dropout < 1:
            input_shape = K.int_shape(inputs)
            input_dim = input_shape[-1]
            ones = K.ones_like(K.reshape(inputs[:, 0, 0], (-1, 1)))
            ones = K.tile(ones, (1, int(input_dim)))

            def dropped_inputs():
                return K.dropout(ones, self.dropout)

            dp_mask = [K.in_train_phase(dropped_inputs,
                                        ones,
                                        training=training) for _ in range(5)]
            constants.append(dp_mask)
        else:
            constants.append([K.cast_to_floatx(1.) for _ in range(5)])

        if 0 < self.recurrent_dropout < 1:
            ones = K.ones_like(K.reshape(inputs[:, 0, 0], (-1, 1)))
            ones = K.tile(ones, (1, self.units))

            def dropped_inputs():
                return K.dropout(ones, self.recurrent_dropout)

            rec_dp_mask = [K.in_train_phase(dropped_inputs,
                                            ones,
                                            training=training) for _ in range(5)]
            constants.append(rec_dp_mask)
        else:
            constants.append([K.cast_to_floatx(1.) for _ in range(5)])
        return constants
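
    # Recurrence implemented by step(), following the multiplicative LSTM of
    # Krause et al. (https://arxiv.org/abs/1609.07959): an intermediate state
    #     m_t = (W_mx x_t + b_m) * (U_m h_{t-1})        (elementwise product)
    # takes the place of h_{t-1} in the standard LSTM gate equations. The
    # implementation=1 branch follows this formulation gate by gate; the
    # implementation=2 branch fuses the projections into single matrix products.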
    def step(self, inputs, states):
        h_tm1 = states[0]
        c_tm1 = states[1]
        dp_mask = states[2]
        rec_dp_mask = states[3]

        if self.implementation == 2:
            z = K.dot(inputs * dp_mask[0], self.kernel)
            z += z * K.dot(h_tm1 * rec_dp_mask[0], self.recurrent_kernel)  # applies m instead of h_tm1 to z
            if self.use_bias:
                z = K.bias_add(z, self.bias)

            z0 = z[:, :self.units]
            z1 = z[:, self.units: 2 * self.units]
            z2 = z[:, 2 * self.units: 3 * self.units]
            z3 = z[:, 3 * self.units: 4 * self.units]
            z4 = z[:, 4 * self.units:]  # just elementwise multiplication, no activation functions

            i = self.recurrent_activation(z0)
            f = self.recurrent_activation(z1)
            c = f * c_tm1 + i * self.activation(z2)
            o = self.recurrent_activation(z3)
        else:
            if self.implementation == 1:
                x_i = K.dot(inputs * dp_mask[0], self.kernel_i)
                x_f = K.dot(inputs * dp_mask[1], self.kernel_f)
                x_c = K.dot(inputs * dp_mask[2], self.kernel_c)
                x_o = K.dot(inputs * dp_mask[3], self.kernel_o)
                x_m = K.dot(inputs * dp_mask[4], self.kernel_m)
                if self.use_bias:
                    # bias slices are only defined when use_bias=True
                    x_i = K.bias_add(x_i, self.bias_i)
                    x_f = K.bias_add(x_f, self.bias_f)
                    x_c = K.bias_add(x_c, self.bias_c)
                    x_o = K.bias_add(x_o, self.bias_o)
                    x_m = K.bias_add(x_m, self.bias_m)
            else:
                raise ValueError('Unknown `implementation` mode.')

            m = x_m * K.dot(h_tm1 * rec_dp_mask[4], self.recurrent_kernel_m)  # elementwise multiplicative term m
            i = self.recurrent_activation(x_i + K.dot(m * rec_dp_mask[0], self.recurrent_kernel_i))
            f = self.recurrent_activation(x_f + K.dot(m * rec_dp_mask[1], self.recurrent_kernel_f))
            c = f * c_tm1 + i * self.activation(x_c + K.dot(m * rec_dp_mask[2], self.recurrent_kernel_c))
            o = self.recurrent_activation(x_o + K.dot(m * rec_dp_mask[3], self.recurrent_kernel_o))

        h = o * self.activation(c)
        if 0 < self.dropout + self.recurrent_dropout:
            h._uses_learning_phase = True
        return h, [h, c]

    def get_config(self):
        config = {'units': self.units,
                  'activation': activations.serialize(self.activation),
                  'recurrent_activation': activations.serialize(self.recurrent_activation),
                  'use_bias': self.use_bias,
                  'kernel_initializer': initializers.serialize(self.kernel_initializer),
                  'recurrent_initializer': initializers.serialize(self.recurrent_initializer),
                  'bias_initializer': initializers.serialize(self.bias_initializer),
                  'unit_forget_bias': self.unit_forget_bias,
                  'kernel_regularizer': regularizers.serialize(self.kernel_regularizer),
                  'recurrent_regularizer': regularizers.serialize(self.recurrent_regularizer),
                  'bias_regularizer': regularizers.serialize(self.bias_regularizer),
                  'activity_regularizer': regularizers.serialize(self.activity_regularizer),
                  'kernel_constraint': constraints.serialize(self.kernel_constraint),
                  'recurrent_constraint': constraints.serialize(self.recurrent_constraint),
                  'bias_constraint': constraints.serialize(self.bias_constraint),
                  'dropout': self.dropout,
                  'recurrent_dropout': self.recurrent_dropout}
        base_config = super(MultiplicativeLSTM, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
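
A minimal usage sketch from Python, assuming the file above is saved as multiplicative_lstm.py (the module name the R wrapper below imports); the model and input sizes are illustrative:

import numpy as np
from keras.models import Sequential
from keras.layers import Dense
from multiplicative_lstm import MultiplicativeLSTM

model = Sequential()
# 20 timesteps of 8 features per sample; 32 mLSTM units (all sizes illustrative)
model.add(MultiplicativeLSTM(32, input_shape=(20, 8)))
model.add(Dense(1))
model.compile(optimizer='adam', loss='mse')

# one training step on random data to confirm the layer builds and trains
x = np.random.random((4, 20, 8))
y = np.random.random((4, 1))
model.train_on_batch(x, y)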

library(keras)
library(reticulate)

layer_multiplicative_lstm <- function(
  object, units, activation = "tanh", recurrent_activation = "hard_sigmoid", use_bias = TRUE,
  return_sequences = FALSE, return_state = FALSE, go_backwards = FALSE, stateful = FALSE, unroll = FALSE,
  kernel_initializer = "glorot_uniform", recurrent_initializer = "orthogonal", bias_initializer = "zeros",
  unit_forget_bias = TRUE, kernel_regularizer = NULL, recurrent_regularizer = NULL, bias_regularizer = NULL,
  activity_regularizer = NULL, kernel_constraint = NULL, recurrent_constraint = NULL, bias_constraint = NULL,
  dropout = 0.0, recurrent_dropout = 0.0, input_shape = NULL, batch_input_shape = NULL, batch_size = NULL,
  dtype = NULL, name = NULL, trainable = NULL, weights = NULL) {

  # import the Python module defined above (multiplicative_lstm.py)
  mlstm <- reticulate::import_from_path("multiplicative_lstm")

  create_layer(mlstm$MultiplicativeLSTM, object, list(
    units = as.integer(units),
    activation = activation,
    recurrent_activation = recurrent_activation,
    use_bias = use_bias,
    return_sequences = return_sequences,
    return_state = return_state,
    go_backwards = go_backwards,
    stateful = stateful,
    unroll = unroll,
    kernel_initializer = kernel_initializer,
    recurrent_initializer = recurrent_initializer,
    bias_initializer = bias_initializer,
    unit_forget_bias = unit_forget_bias,
    kernel_regularizer = kernel_regularizer,
    recurrent_regularizer = recurrent_regularizer,
    bias_regularizer = bias_regularizer,
    activity_regularizer = activity_regularizer,
    kernel_constraint = kernel_constraint,
    recurrent_constraint = recurrent_constraint,
    bias_constraint = bias_constraint,
    dropout = dropout,
    recurrent_dropout = recurrent_dropout,
    input_shape = keras:::normalize_shape(input_shape),
    batch_input_shape = keras:::normalize_shape(batch_input_shape),
    batch_size = keras:::as_nullable_integer(batch_size),
    dtype = dtype,
    name = name,
    trainable = trainable,
    weights = weights
  ))
}

layer_multiplicative_lstm(units = 10)
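
The final call, made without an object to compose onto, acts as a quick smoke test that the Python module can be imported and the layer constructed; in an actual model the layer would be composed onto a keras_model_sequential() pipeline with %>%, like any other layer_* function.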