A wrapper for the Keras GRU that skips a timestep if the inputs for that timestep are all zeros.
"""ZeroGRU module.""" | |
import keras.backend as K | |
import keras.layers as L | |
class ZeroGRUCell(L.GRUCell):
    """GRU cell that skips the timestep when the inputs are all zero."""

    def call(self, inputs, states, training=None):
        """Step function of the cell."""
        h_tm1 = states[0]  # previous output
        # Check if all inputs are zero for this timestep
        cond = K.all(K.equal(inputs, 0), axis=-1)
        new_output, new_states = super().call(inputs, states, training=training)
        # Skip the timestep where the condition holds: keep the previous
        # output and states for those samples instead of the new ones.
        curr_output = K.switch(cond, h_tm1, new_output)
        curr_states = [K.switch(cond, states[i], new_states[i])
                       for i in range(len(states))]
        return curr_output, curr_states
class ZeroGRU(L.GRU):
    """Layer wrapper for the ZeroGRUCell."""

    # Just swap the GRUCell with ZeroGRUCell
    def __init__(self, units,
                 activation='tanh',
                 recurrent_activation='hard_sigmoid',
                 use_bias=True,
                 kernel_initializer='glorot_uniform',
                 recurrent_initializer='orthogonal',
                 bias_initializer='zeros',
                 kernel_regularizer=None,
                 recurrent_regularizer=None,
                 bias_regularizer=None,
                 activity_regularizer=None,
                 kernel_constraint=None,
                 recurrent_constraint=None,
                 bias_constraint=None,
                 dropout=0.,
                 recurrent_dropout=0.,
                 implementation=1,
                 return_sequences=False,
                 return_state=False,
                 go_backwards=False,
                 stateful=False,
                 unroll=False,
                 reset_after=False,
                 **kwargs):
        cell = ZeroGRUCell(units,
                           activation=activation,
                           recurrent_activation=recurrent_activation,
                           use_bias=use_bias,
                           kernel_initializer=kernel_initializer,
                           recurrent_initializer=recurrent_initializer,
                           bias_initializer=bias_initializer,
                           kernel_regularizer=kernel_regularizer,
                           recurrent_regularizer=recurrent_regularizer,
                           bias_regularizer=bias_regularizer,
                           kernel_constraint=kernel_constraint,
                           recurrent_constraint=recurrent_constraint,
                           bias_constraint=bias_constraint,
                           dropout=dropout,
                           recurrent_dropout=recurrent_dropout,
                           implementation=implementation,
                           reset_after=reset_after)
        # Bypass GRU.__init__ and initialise the underlying RNN base class
        # directly with the custom cell.
        super(L.GRU, self).__init__(cell,
                                    return_sequences=return_sequences,
                                    return_state=return_state,
                                    go_backwards=go_backwards,
                                    stateful=stateful,
                                    unroll=unroll,
                                    **kwargs)
        self.activity_regularizer = L.regularizers.get(activity_regularizer)
if __name__ == '__main__':
    import numpy as np
    from keras.models import Sequential

    model = Sequential()
    # Contextual embedding of symbols; index 0 embeds to the zero vector
    onehot_weights = np.eye(4)
    onehot_weights[0, 0] = 0  # Clear zero index
    model.add(L.Embedding(4, 4,
                          trainable=False,
                          weights=[onehot_weights],
                          name='onehot'))
    model.add(ZeroGRU(2, return_sequences=True))
    x = np.array([[0, 2, 1, 0, 1, 1, 0, 0]])
    y = model.predict(x)
    print(y.shape)
    print(y)
    # Zero timesteps (indices 0, 3, 6 and 7) repeat the previous output:
    # (1, 8, 2)
    # [[[ 0.          0.        ]
    #   [ 0.01389048  0.2353647 ]
    #   [-0.35381496  0.37560514]
    #   [-0.35381496  0.37560514]
    #   [-0.452064    0.48499036]
    #   [-0.46228996  0.56209606]
    #   [-0.46228996  0.56209606]
    #   [-0.46228996  0.56209606]]]
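
A minimal follow-on sketch, not part of the original gist, continuing from the model built in the demo above: the layer also handles zero-padded batches produced by Keras' pad_sequences, and the sequences here are made up for illustration. Because index 0 embeds to the zero vector, the cell carries each sequence's last real state through the padding.

    from keras.preprocessing.sequence import pad_sequences

    # Hypothetical variable-length batch, padded with 0 at the end.
    batch = pad_sequences([[2, 1], [1, 2, 3, 1]], maxlen=8, padding='post')
    out = model.predict(batch)
    print(out.shape)  # (2, 8, 2)
    # For each sequence, the rows after its last non-zero symbol repeat
    # the final state, because ZeroGRUCell skips those timesteps.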