An incorrect, partial implementation of the Simple Recurrent Unit (SRU) from the paper (https://arxiv.org/pdf/1709.02755.pdf).
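For reference, the recurrence this layer is meant to implement (the SRU formulation in the paper linked above) can be sketched in a few lines of NumPy. This is only an illustrative single-sample, single-step sketch under my reading of the paper; `sru_step`, `sigmoid` and the weight names are made up for the example and are not part of the gist.

import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def sru_step(x_t, c_prev, W, W_f, W_r, b_f, b_r):
    """One SRU step for a single sample. Assumes input_dim == units;
    otherwise the paper / PyTorch code adds an extra projection (the
    `u` kernel mentioned in build() below)."""
    x_tilde = W.T.dot(x_t)                    # candidate: x~_t = W x_t
    f = sigmoid(W_f.T.dot(x_t) + b_f)         # forget gate
    r = sigmoid(W_r.T.dot(x_t) + b_r)         # reset / highway gate
    c = f * c_prev + (1.0 - f) * x_tilde      # internal state c_t
    h = r * np.tanh(c) + (1.0 - r) * x_t      # output h_t (highway over the raw input)
    return h, c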
'''Trains an SRU model on the IMDB sentiment classification task.
The dataset is actually too small for LSTM to be of any advantage
compared to simpler, much faster methods such as TF-IDF + LogReg.
Notes:
- RNNs are tricky. Choice of batch size is important,
choice of loss and optimizer is critical, etc.
Some configurations won't converge.
- LSTM loss decrease patterns during training can be quite different
from what you see with CNNs/MLPs/etc.
'''
from __future__ import print_function
from keras.preprocessing import sequence
from keras.models import Model
from keras.layers import Dense, Embedding, Input
from keras.datasets import imdb
from sru import SRU
max_features = 20000
maxlen = 80  # cut texts after this number of words (among top max_features most common words)
batch_size = 128

print('Loading data...')
(x_train, y_train), (x_test, y_test) = imdb.load_data(num_words=max_features)
print(len(x_train), 'train sequences')
print(len(x_test), 'test sequences')

print('Pad sequences (samples x time)')
x_train = sequence.pad_sequences(x_train, maxlen=maxlen)
x_test = sequence.pad_sequences(x_test, maxlen=maxlen)
print('x_train shape:', x_train.shape)
print('x_test shape:', x_test.shape)
print('Build model...')
ip = Input(shape=(maxlen,))  # use Input(batch_shape=(batch_size, maxlen)) to see concrete shapes in the debug prints
embed = Embedding(max_features, 32)(ip)
outputs = SRU(32, dropout=0.2, recurrent_dropout=0.2, implementation=2, unroll=True)(embed)
out = Dense(1, activation='sigmoid')(outputs)

model = Model(ip, out)
model.summary()
# try using different optimizers and different optimizer configs
model.compile(loss='binary_crossentropy',
              optimizer='adam',
              metrics=['accuracy'])

print('Train...')
model.fit(x_train, y_train,
          batch_size=batch_size,
          epochs=100,
          validation_data=(x_test, y_test))
score, acc = model.evaluate(x_test, y_test,
                            batch_size=batch_size)
print('Test score:', score)
print('Test accuracy:', acc)
from __future__ import absolute_import
import numpy as np
from keras import backend as K
from keras import activations
from keras import initializers
from keras import regularizers
from keras import constraints
from keras.engine import Layer
from keras.engine import InputSpec
from keras.legacy import interfaces
from keras.layers import Recurrent
from keras.layers.recurrent import _time_distributed_dense


class SRU(Recurrent):
"""Simple Recurrent Unit - https://arxiv.org/pdf/1709.02755.pdf. | |
# Arguments | |
units: Positive integer, dimensionality of the output space. | |
activation: Activation function to use | |
(see [activations](../activations.md)). | |
If you pass None, no activation is applied | |
(ie. "linear" activation: `a(x) = x`). | |
recurrent_activation: Activation function to use | |
for the recurrent step | |
(see [activations](../activations.md)). | |
use_bias: Boolean, whether the layer uses a bias vector. | |
kernel_initializer: Initializer for the `kernel` weights matrix, | |
used for the linear transformation of the inputs. | |
(see [initializers](../initializers.md)). | |
recurrent_initializer: Initializer for the `recurrent_kernel` | |
weights matrix, | |
used for the linear transformation of the recurrent state. | |
(see [initializers](../initializers.md)). | |
bias_initializer: Initializer for the bias vector | |
(see [initializers](../initializers.md)). | |
unit_forget_bias: Boolean. | |
If True, add 1 to the bias of the forget gate at initialization. | |
Setting it to true will also force `bias_initializer="zeros"`. | |
This is recommended in [Jozefowicz et al.](http://www.jmlr.org/proceedings/papers/v37/jozefowicz15.pdf) | |
kernel_regularizer: Regularizer function applied to | |
the `kernel` weights matrix | |
(see [regularizer](../regularizers.md)). | |
recurrent_regularizer: Regularizer function applied to | |
the `recurrent_kernel` weights matrix | |
(see [regularizer](../regularizers.md)). | |
bias_regularizer: Regularizer function applied to the bias vector | |
(see [regularizer](../regularizers.md)). | |
activity_regularizer: Regularizer function applied to | |
the output of the layer (its "activation"). | |
(see [regularizer](../regularizers.md)). | |
kernel_constraint: Constraint function applied to | |
the `kernel` weights matrix | |
(see [constraints](../constraints.md)). | |
recurrent_constraint: Constraint function applied to | |
the `recurrent_kernel` weights matrix | |
(see [constraints](../constraints.md)). | |
bias_constraint: Constraint function applied to the bias vector | |
(see [constraints](../constraints.md)). | |
dropout: Float between 0 and 1. | |
Fraction of the units to drop for | |
the linear transformation of the inputs. | |
recurrent_dropout: Float between 0 and 1. | |
Fraction of the units to drop for | |
the linear transformation of the recurrent state. | |
# References | |
- [Long short-term memory](http://www.bioinf.jku.at/publications/older/2604.pdf) (original 1997 paper) | |
- [Learning to forget: Continual prediction with LSTM](http://www.mitpressjournals.org/doi/pdf/10.1162/089976600300015015) | |
- [Supervised sequence labeling with recurrent neural networks](http://www.cs.toronto.edu/~graves/preprint.pdf) | |
- [A Theoretically Grounded Application of Dropout in Recurrent Neural Networks](http://arxiv.org/abs/1512.05287) | |
""" | |
    @interfaces.legacy_recurrent_support
    def __init__(self, units,
                 activation='tanh',
                 recurrent_activation='sigmoid',
                 use_bias=True,
                 kernel_initializer='glorot_uniform',
                 recurrent_initializer='orthogonal',
                 bias_initializer='zeros',
                 unit_forget_bias=True,
                 kernel_regularizer=None,
                 recurrent_regularizer=None,
                 bias_regularizer=None,
                 activity_regularizer=None,
                 kernel_constraint=None,
                 recurrent_constraint=None,
                 bias_constraint=None,
                 dropout=0.,
                 recurrent_dropout=0.,
                 **kwargs):
        super(SRU, self).__init__(**kwargs)
        self.units = units
        self.activation = activations.get(activation)
        self.recurrent_activation = activations.get(recurrent_activation)
        self.use_bias = use_bias

        self.kernel_initializer = initializers.get(kernel_initializer)
        self.recurrent_initializer = initializers.get(recurrent_initializer)
        self.bias_initializer = initializers.get(bias_initializer)
        self.unit_forget_bias = unit_forget_bias

        self.kernel_regularizer = regularizers.get(kernel_regularizer)
        self.recurrent_regularizer = regularizers.get(recurrent_regularizer)
        self.bias_regularizer = regularizers.get(bias_regularizer)
        self.activity_regularizer = regularizers.get(activity_regularizer)

        self.kernel_constraint = constraints.get(kernel_constraint)
        self.recurrent_constraint = constraints.get(recurrent_constraint)
        self.bias_constraint = constraints.get(bias_constraint)

        self.dropout = min(1., max(0., dropout))
        self.recurrent_dropout = min(1., max(0., recurrent_dropout))
        self.state_spec = [InputSpec(shape=(None, self.units)),
                           InputSpec(shape=(None, self.units))]
    def build(self, input_shape):
        if isinstance(input_shape, list):
            input_shape = input_shape[0]

        batch_size = input_shape[0] if self.stateful else None
        self.input_dim = input_shape[2]
        self.input_spec[0] = InputSpec(shape=(batch_size, None, self.input_dim))  # (batch_size, timesteps, input_dim)
        self.states = [None, None]
        if self.stateful:
            self.reset_states()
        # There may be cases where the input dim does not match the output units.
        # In such a case, the PyTorch implementation adds another set of weights
        # to bring the intermediate shape to the correct dimensions.
        # Here, I call it the `u` kernel, though it doesn't have any specific
        # implementation yet.
        self.kernel_dim = 3 if self.input_dim == self.units else 4

        self.kernel = self.add_weight(shape=(self.input_dim, self.units * self.kernel_dim),
                                      name='kernel',
                                      initializer=self.kernel_initializer,
                                      regularizer=self.kernel_regularizer,
                                      constraint=self.kernel_constraint)
        if self.use_bias:
            if self.unit_forget_bias:
                def bias_initializer(shape, *args, **kwargs):
                    # Forget-gate bias is initialized to ones, the rest to the
                    # configured initializer. When kernel_dim == 4, an extra
                    # segment for the `u` kernel is appended so the shape matches.
                    pieces = [
                        self.bias_initializer((self.units,), *args, **kwargs),
                        initializers.Ones()((self.units,), *args, **kwargs),
                        self.bias_initializer((self.units,), *args, **kwargs),
                    ]
                    if self.kernel_dim == 4:
                        pieces.append(self.bias_initializer((self.units,), *args, **kwargs))
                    return K.concatenate(pieces)
            else:
                bias_initializer = self.bias_initializer
            self.bias = self.add_weight(shape=(self.units * self.kernel_dim,),
                                        name='bias',
                                        initializer=bias_initializer,
                                        regularizer=self.bias_regularizer,
                                        constraint=self.bias_constraint)
        else:
            self.bias = None
        self.kernel_w = self.kernel[:, :self.units]
        self.kernel_f = self.kernel[:, self.units: self.units * 2]
        self.kernel_r = self.kernel[:, self.units * 2: self.units * 3]

        if self.kernel_dim == 4:
            self.kernel_u = self.kernel[:, self.units * 3: self.units * 4]
        else:
            self.kernel_u = None

        if self.use_bias:
            self.bias_w = self.bias[:self.units]
            self.bias_f = self.bias[self.units: self.units * 2]
            self.bias_r = self.bias[self.units * 2: self.units * 3]

            if self.kernel_dim == 4:
                self.bias_u = self.bias[self.units * 3: self.units * 4]
        else:
            self.bias_w = None
            self.bias_f = None
            self.bias_r = None
            self.bias_u = None

        self.built = True
    def preprocess_input(self, inputs, training=None):
        if self.implementation == 0:
            input_shape = K.int_shape(inputs)
            input_dim = input_shape[2]
            timesteps = input_shape[1]

            x_w = _time_distributed_dense(inputs, self.kernel_w, self.bias_w,
                                          self.dropout, input_dim, self.units,
                                          timesteps, training=training)
            x_f = _time_distributed_dense(inputs, self.kernel_f, self.bias_f,
                                          self.dropout, input_dim, self.units,
                                          timesteps, training=training)
            x_r = _time_distributed_dense(inputs, self.kernel_r, self.bias_r,
                                          self.dropout, input_dim, self.units,
                                          timesteps, training=training)

            if self.kernel_dim == 4:
                x_u = _time_distributed_dense(inputs, self.kernel_u, self.bias_u,
                                              self.dropout, input_dim, self.units,
                                              timesteps, training=training)
                return K.concatenate([x_w, x_f, x_r, x_u], axis=2)
            else:
                return K.concatenate([x_w, x_f, x_r], axis=2)
        else:
            return inputs
    def get_constants(self, inputs, training=None):
        constants = []
        if self.implementation != 0 and 0 < self.dropout < 1:
            input_shape = K.int_shape(inputs)  # (batch_size, timesteps, input_dim)
            input_dim = input_shape[-1]
            ones = K.ones_like(K.reshape(inputs[:, 0, 0], (-1, 1)))
            ones = K.tile(ones, (1, int(input_dim)))

            def dropped_inputs():
                return K.dropout(ones, self.dropout)

            dp_mask = [K.in_train_phase(dropped_inputs,
                                        ones,
                                        training=training) for _ in range(4)]
            constants.append(dp_mask)
        else:
            constants.append([K.cast_to_floatx(1.) for _ in range(4)])

        constants.append(inputs)  # keep the raw inputs so step() can use x_t in the highway term
        self.time_step = 0
        return constants
    def step(self, inputs, states):
        h_tm1 = states[0]
        c_tm1 = states[1]
        dp_mask = states[2]
        x_inputs = states[3]

        # To see correct batch shapes, set batch_input_shape to some value,
        # otherwise the None can be confusing to interpret.
        print("X inputs shape : ", K.int_shape(x_inputs))
        print('h_tm1 shape: ', K.int_shape(h_tm1))
        print('c_tm1 shape: ', K.int_shape(c_tm1))

        if self.implementation == 2:
            z = K.dot(inputs * dp_mask[0], self.kernel)
            if self.use_bias:
                z = K.bias_add(z, self.bias)

            z0 = z[:, :self.units]
            z1 = z[:, self.units: 2 * self.units]
            z2 = z[:, 2 * self.units: 3 * self.units]

            f = self.recurrent_activation(z1)
            r = self.recurrent_activation(z2)

            # print("W shape : ", K.int_shape(z0))
            # print("F shape : ", K.int_shape(f))
            # print("R shape : ", K.int_shape(r))

            c = f * c_tm1 + (1 - f) * z0
            h = r * self.activation(c) + (1 - r) * x_inputs[:, self.time_step, :]  # x_inputs should not have 0 index
        else:
            if self.implementation == 0:
                x_w = inputs[:, :self.units]
                x_f = inputs[:, self.units: 2 * self.units]
                x_r = inputs[:, 2 * self.units: 3 * self.units]
            elif self.implementation == 1:
                x_w = K.dot(inputs * dp_mask[0], self.kernel_w)
                x_f = K.dot(inputs * dp_mask[1], self.kernel_f)
                x_r = K.dot(inputs * dp_mask[2], self.kernel_r)
                if self.use_bias:
                    x_w = K.bias_add(x_w, self.bias_w)
                    x_f = K.bias_add(x_f, self.bias_f)
                    x_r = K.bias_add(x_r, self.bias_r)
            else:
                raise ValueError('Unknown `implementation` mode.')

            w = x_w
            f = self.recurrent_activation(x_f)
            r = self.recurrent_activation(x_r)

            print("W shape : ", K.int_shape(w))
            print("F shape : ", K.int_shape(f))
            print("R shape : ", K.int_shape(r))

            c = f * c_tm1 + (1 - f) * w
            h = r * self.activation(c) + (1 - r) * x_inputs[:, self.time_step, :]  # x_inputs should not have 0 index

        self.time_step += 1
        print('timestep : ', self.time_step)

        if 0 < self.dropout + self.recurrent_dropout:
            h._uses_learning_phase = True
        return h, [h, c]
    def get_config(self):
        config = {'units': self.units,
                  'activation': activations.serialize(self.activation),
                  'recurrent_activation': activations.serialize(self.recurrent_activation),
                  'use_bias': self.use_bias,
                  'kernel_initializer': initializers.serialize(self.kernel_initializer),
                  'recurrent_initializer': initializers.serialize(self.recurrent_initializer),
                  'bias_initializer': initializers.serialize(self.bias_initializer),
                  'unit_forget_bias': self.unit_forget_bias,
                  'kernel_regularizer': regularizers.serialize(self.kernel_regularizer),
                  'recurrent_regularizer': regularizers.serialize(self.recurrent_regularizer),
                  'bias_regularizer': regularizers.serialize(self.bias_regularizer),
                  'activity_regularizer': regularizers.serialize(self.activity_regularizer),
                  'kernel_constraint': constraints.serialize(self.kernel_constraint),
                  'recurrent_constraint': constraints.serialize(self.recurrent_constraint),
                  'bias_constraint': constraints.serialize(self.bias_constraint),
                  'dropout': self.dropout,
                  'recurrent_dropout': self.recurrent_dropout}
        base_config = super(SRU, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
Speed test for the SRU model. It currently only works correctly when unrolled. Speed is about 7 seconds per epoch on IMDB (but it overfits too rapidly), comparable to the 7-8 seconds per epoch of the IMDB CNN script in the Keras examples.
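A likely reason it only works when unrolled (my reading of the code above, not a statement by the author): step() reads the raw input at the current position through a Python-side counter, self.time_step, which only advances once per timestep when Keras unrolls the recurrent graph; with the symbolic loop the step body is built just once, so the counter never tracks the real timestep. A minimal usage sketch under that assumption, with illustrative sizes and a fixed sequence length as unroll=True requires:

from keras.layers import Dense, Embedding, Input
from keras.models import Model
from sru import SRU

maxlen = 80  # timesteps must be static for unroll=True
inp = Input(shape=(maxlen,))
x = Embedding(20000, 32)(inp)
x = SRU(32, implementation=2, unroll=True)(x)  # unroll=True is the only mode reported to work
out = Dense(1, activation='sigmoid')(x)
model = Model(inp, out)
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])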