Keras LSTM that inputs/outputs its internal states, e.g. for hidden state transfer (forked from mbollmann/hidden_state_lstm.py).
from keras import backend as K
from keras.layers.recurrent import LSTM


class HiddenStateLSTM(LSTM):
    """LSTM with input/output capabilities for its hidden state.

    This layer behaves just like an LSTM, except that it accepts further
    inputs to be used as its initial states, and returns additional outputs,
    representing the layer's final states.

    See Also:
        https://github.com/fchollet/keras/issues/2995
    """

    def build(self, input_shape):
        if isinstance(input_shape, list) and len(input_shape) > 1:
            input_shape, *hidden_shapes = input_shape
            for shape in hidden_shapes:
                # initial states must match the batch size and output dim
                assert shape[0] == input_shape[0]
                assert shape[-1] == self.output_dim
        super().build(input_shape)

    def call(self, x, mask=None):
        # input shape: (nb_samples, time (padded with zeros), input_dim)
        input_shape = self.input_spec[0].shape
        if isinstance(x, (tuple, list)):
            x, *custom_initial = x
        else:
            custom_initial = None
        if K._BACKEND == 'tensorflow':
            if not input_shape[1]:
                raise Exception('When using TensorFlow, you should define '
                                'explicitly the number of timesteps of '
                                'your sequences.\n'
                                'If your first layer is an Embedding, '
                                'make sure to pass it an "input_length" '
                                'argument. Otherwise, make sure '
                                'the first layer has '
                                'an "input_shape" or "batch_input_shape" '
                                'argument, including the time axis. '
                                'Found input shape at layer ' + self.name +
                                ': ' + str(input_shape))
        if self.stateful and custom_initial:
            raise Exception('Initial states should not be specified '
                            'for stateful LSTMs, since they would overwrite '
                            'the memorized states.')
        elif custom_initial:
            initial_states = custom_initial
        elif self.stateful:
            initial_states = self.states
        else:
            initial_states = self.get_initial_states(x)
        constants = self.get_constants(x)
        preprocessed_input = self.preprocess_input(x)
        # only use the main input mask
        if isinstance(mask, list):
            mask = mask[0]
        last_output, outputs, states = K.rnn(self.step, preprocessed_input,
                                             initial_states,
                                             go_backwards=self.go_backwards,
                                             mask=mask,
                                             constants=constants,
                                             unroll=self.unroll,
                                             input_length=input_shape[1])
        if self.stateful:
            self.updates = []
            for i in range(len(states)):
                self.updates.append((self.states[i], states[i]))
        # return the sequence (or last output) plus the final states
        if self.return_sequences:
            return [outputs] + states
        else:
            return [last_output] + states

    def get_output_shape_for(self, input_shape):
        if isinstance(input_shape, list) and len(input_shape) > 1:
            input_shape = input_shape[0]
        if self.return_sequences:
            output_shape = (input_shape[0], input_shape[1], self.output_dim)
        else:
            output_shape = (input_shape[0], self.output_dim)
        # two extra outputs: final hidden state and final cell state
        state_output = (input_shape[0], self.output_dim)
        return [output_shape, state_output, state_output]

    def compute_mask(self, input, mask):
        if isinstance(mask, list) and len(mask) > 1:
            return mask
        elif self.return_sequences:
            # only the main (sequence) output carries a mask
            return [mask, None, None]
        else:
            return [None] * 3
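
# Calling convention (a sketch; the shapes below are illustrative
# assumptions, not part of the original gist):
#
#     out, h, c = HiddenStateLSTM(64)([x, h0, c0])
#
# `h0`/`c0` are optional initial hidden/cell state tensors of shape
# (nb_samples, 64); `h`/`c` are the layer's final states. Passing a plain
# tensor instead of a list gives the usual zero initial states. See the
# full encoder-decoder example below.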

#####################
### Usage example ###
#####################

if __name__ == '__main__':
    from keras.layers import Input, Embedding, Dense, TimeDistributed, Activation
    from keras.models import Model

    ### build encoder
    enc_input = Input(shape=(24,), dtype='int32', name='encoder_input')
    enc_layer = Embedding(128, 64, mask_zero=True)(enc_input)
    enc_layer, *hidden = HiddenStateLSTM(64, dropout_W=0.5, return_sequences=False)(enc_layer)

    ### build decoder, feeding the encoder's final states as initial states
    dec_input = Input(shape=(24,), dtype='int32', name='decoder_input')
    dec_layer = Embedding(128, 64, mask_zero=True)(dec_input)
    dec_layer, _, _ = HiddenStateLSTM(64, dropout_W=0.5, return_sequences=True)([dec_layer] + hidden)
    dec_layer = TimeDistributed(Dense(128))(dec_layer)
    dec_output = Activation('softmax', name='decoder_output')(dec_layer)

    ### build model
    model = Model(input=[enc_input, dec_input], output=dec_output)
    model.compile(optimizer='adam', loss='categorical_crossentropy')
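
    ### quick smoke test (a minimal sketch: the dummy data and
    ### hyper-parameters below are illustrative assumptions, not part of
    ### the original gist; `nb_epoch` is the Keras 1.x spelling)
    import numpy as np
    enc_x = np.random.randint(1, 128, size=(32, 24))  # encoder token ids; 0 is reserved for masking
    dec_x = np.random.randint(1, 128, size=(32, 24))  # decoder token ids
    # one-hot targets over the 128-symbol vocabulary at each timestep
    dec_y = np.eye(128)[np.random.randint(0, 128, size=(32, 24))]
    model.fit([enc_x, dec_x], dec_y, batch_size=8, nb_epoch=1)

    ### state extraction at inference time (also a sketch, not part of the
    ### original gist): a second Model sharing the encoder layers exposes
    ### the encoder's final hidden/cell states directly
    encoder = Model(input=enc_input, output=hidden)
    h, c = encoder.predict(enc_x)  # each of shape (32, 64)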