Last active
December 19, 2017 22:12
-
-
Save mpariente/24f1e3b43e5a999acaa97a42df9a4ed9 to your computer and use it in GitHub Desktop.
Creating a child of Bidirectional to handle return_states
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# In answer to this issue
# https://github.com/keras-team/keras/issues/8823
from keras import backend as K | |
from keras.layers import Bidirectional | |
from keras.utils.generic_utils import has_arg | |
from keras.layers import Input, LSTM | |
class MyBidirectional(Bidirectional):
    """``Bidirectional`` subclass that also returns the wrapped RNN's states.

    Works around keras-team/keras#8823: the stock ``Bidirectional`` wrapper
    (at the time) dropped the hidden/cell states of a layer built with
    ``return_state=True``. With an LSTM inside, calling this layer yields
    ``[output, h_fwd, h_bwd, c_fwd, c_bwd]``.
    """

    def __init__(self, layer, merge_mode='concat', weights=None, **kwargs):
        # BUG FIX: the original ignored ``merge_mode`` and ``weights`` by not
        # forwarding them to the parent constructor, so e.g.
        # ``merge_mode='mul'`` silently fell back to the default 'concat'.
        super(MyBidirectional, self).__init__(
            layer, merge_mode=merge_mode, weights=weights, **kwargs)
        # Mirror the inner layer's flag so call()/compute_output_shape() can
        # branch on it.
        self.return_state = layer.return_state

    def compute_output_shape(self, input_shape):
        """Return the merged output shape, plus 4 state shapes if requested."""
        # LSTM cells expose state_size as a tuple (h, c); simpler cells expose
        # a single int.
        if hasattr(self.layer.cell.state_size, '__len__'):
            output_dim = self.layer.cell.state_size[0]
        else:
            output_dim = self.layer.cell.state_size
        if self.merge_mode in ['sum', 'ave', 'mul']:
            output_shape = self.forward_layer.compute_output_shape(input_shape)
        elif self.merge_mode == 'concat':
            shape = list(self.forward_layer.compute_output_shape(input_shape))
            shape[-1] *= 2  # forward and backward features are concatenated
            output_shape = tuple(shape)
        elif self.merge_mode is None:
            output_shape = [self.forward_layer.compute_output_shape(input_shape)] * 2
        if self.return_state:
            # Four states: forward/backward hidden and cell, each (batch, dim).
            state_shape = [(input_shape[0], output_dim) for _ in range(4)]
            if self.merge_mode is None:
                # BUG FIX: ``output_shape`` is already a list of two shapes
                # here; the original nested it inside another list, producing
                # a malformed shape list.
                return output_shape + state_shape
            return [output_shape] + state_shape
        return output_shape

    def call(self, inputs, training=None, mask=None):
        """Run both directions, merge outputs, and append states if requested."""
        kwargs = {}
        # Only pass through arguments the wrapped layer actually accepts.
        if has_arg(self.layer.call, 'training'):
            kwargs['training'] = training
        if has_arg(self.layer.call, 'mask'):
            kwargs['mask'] = mask
        y = self.forward_layer.call(inputs, **kwargs)
        y_rev = self.backward_layer.call(inputs, **kwargs)
        if self.return_state:
            # With return_state=True each direction yields [output, h, c].
            y, states_h, states_c = y
            y_rev, states_h_rev, states_c_rev = y_rev
        if self.return_sequences:
            # Flip the backward sequence back to forward time order.
            y_rev = K.reverse(y_rev, 1)
        if self.merge_mode == 'concat':
            output = K.concatenate([y, y_rev])
        elif self.merge_mode == 'sum':
            output = y + y_rev
        elif self.merge_mode == 'ave':
            output = (y + y_rev) / 2
        elif self.merge_mode == 'mul':
            output = y * y_rev
        elif self.merge_mode is None:
            output = [y, y_rev]
        # Properly set learning phase so downstream layers know dropout etc.
        # was involved.
        if (getattr(y, '_uses_learning_phase', False) or
                getattr(y_rev, '_uses_learning_phase', False)):
            if self.merge_mode is None:
                for out in output:
                    out._uses_learning_phase = True
            else:
                output._uses_learning_phase = True
        if self.return_state:
            # Hidden states first, then cell states (forward before backward).
            states = [states_h, states_h_rev, states_c, states_c_rev]
            if self.merge_mode is None:
                # BUG FIX: ``output`` is already a list here; the original
                # returned ``[output] + states``, nesting the two output
                # tensors inside a single element.
                return output + states
            return [output] + states
        return output
# Toy sizes, only so the snippet builds end to end.
num_encoder_tokens = 10
latent_dim = 30

# Plain unidirectional encoder: with return_state=True an LSTM call yields
# [outputs, state_h, state_c].
encoder_inputs = Input(shape=(None, num_encoder_tokens))
encoder = LSTM(latent_dim, return_state=True)
encoder_outputs, state_h, state_c = encoder(encoder_inputs)

# Same thing wrapped in the custom bidirectional layer: the merged output
# comes first, followed by the returned states.
encoder_inputs = Input(shape=(None, num_encoder_tokens))
encoder = MyBidirectional(LSTM(latent_dim, return_state=True),merge_mode="mul")
outputs_and_states = encoder(encoder_inputs)
outputs, *states = outputs_and_states
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment