from keras.layers import Recurrent, Lambda
from keras.models import Sequential
from keras import backend as K


def _isRNN(layer):
    return issubclass(layer.__class__, Recurrent)


def _zeros(shape):
    # Replace unknown (None) dimensions with a dummy size so K.zeros can build a tensor.
    shape = [i if i else 2 for i in shape]
    return K.zeros(shape)


class DepthFirstRecurrentContainer(Recurrent, Sequential):

    def __init__(self, *args, **kwargs):
        Recurrent.__init__(self, *args, **kwargs)
        Sequential.__init__(self)

    @property
    def input_shape(self):
        # Delegate to Sequential's property; call its getter explicitly,
        # since a property object is not callable as Sequential.input_shape(self).
        return Sequential.input_shape.fget(self)

    @property
    def output_shape(self):
        shape = Sequential.output_shape.fget(self)
        if self.return_sequences:
            # Re-insert the time dimension dropped by the inner stack.
            shape = (shape[0], self.input_shape[1]) + shape[1:]
        return shape

    def get_output_shape_for(self, input_shape):
        shape = Sequential.get_output_shape_for(self, input_shape)
        if self.return_sequences:
            shape = (shape[0], self.input_shape[1]) + shape[1:]
        return shape

    def add(self, layer):
        if _isRNN(layer):
            # Inner RNNs are stepped manually, one timestep at a time.
            layer.return_sequences = False
            layer.consume_less = 'mem'
            if len(self.layers) > 0 and not _isRNN(self.layers[-1]):
                input_length = self.input_shape[0]
                if not input_length:
                    input_length = 1
                # Dummy layer that tiles the preceding layer's output along a new time
                # axis, so the inner RNN sees a 3D input during shape inference. It is
                # skipped at run time (see the `dummy` check in step()).
                dummy_layer = Lambda(
                    lambda x: K.tile(K.expand_dims(x, 1),
                                     [1, input_length] + [1] * (K.ndim(x) - 1)),
                    output_shape=lambda s: (s[0], input_length) + s[1:])
                dummy_layer.dummy = True
                Sequential.add(self, dummy_layer)
        Sequential.add(self, layer)

    def step(self, x, states):
        # `states` holds the states of every inner RNN, followed by their constants.
        states = list(states)
        nb_states = []
        nb_constants = []
        for layer in self.layers:
            if _isRNN(layer):
                nb_states += [len(layer.states)]
                if not hasattr(layer, 'nb_constants'):
                    layer.nb_constants = len(layer.get_constants(_zeros(layer.input_shape)))
                nb_constants += [layer.nb_constants]
        rnn_index = 0
        for layer in self.layers:
            if hasattr(layer, 'dummy'):
                # Dummy tiling layers exist only for shape inference.
                continue
            if _isRNN(layer):
                states_idx = sum(nb_states[:rnn_index])
                consts_idx = sum(nb_states) + sum(nb_constants[:rnn_index])
                required_states = (states[states_idx : states_idx + nb_states[rnn_index]] +
                                   states[consts_idx : consts_idx + nb_constants[rnn_index]])
                x, new_states = layer.step(x, required_states)
                states[states_idx : states_idx + nb_states[rnn_index]] = new_states
                rnn_index += 1
            else:
                x = layer.call(x)
        # Return only the updated states; constants are re-supplied by get_constants().
        return x, states[:sum(nb_states)]

    def get_initial_states(self, x):
        initial_states = []
        for layer in self.layers:
            if _isRNN(layer):
                initial_states += layer.get_initial_states(_zeros(layer.input_shape))
        return initial_states

    def get_constants(self, x):
        constants = []
        for layer in self.layers:
            if _isRNN(layer):
                consts = layer.get_constants(_zeros(layer.input_shape))
                layer.nb_constants = len(consts)
                constants += consts
        return constants
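

if __name__ == '__main__':
    # Hedged usage sketch (not part of the original gist, untested): the container is
    # meant to act as a single Recurrent layer whose step() runs the whole inner stack
    # once per timestep, so the Dense readout is recomputed from the inner RNN state
    # at every step. Layer choices and shapes (GRU, Dense, 10 timesteps, 8 features)
    # are illustrative assumptions, not taken from the gist.
    from keras.layers import GRU, Dense

    container = DepthFirstRecurrentContainer(return_sequences=True, input_shape=(10, 8))
    container.add(GRU(32, input_shape=(10, 8)))  # inner RNN, stepped once per outer timestep
    container.add(Dense(4))                      # applied to the RNN output at each step

    model = Sequential()
    model.add(container)
    model.compile(optimizer='rmsprop', loss='mse')
    model.summary()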