Fast-slow LSTM
import numpy as np
import tensorflow as tf  # TF 1.x graph-mode API (uses tf.contrib, tf.Session, etc.)


#class FSRNNCell(tf.contrib.rnn.RNNCell):
class FSRNNCell(tf.compat.v1.nn.rnn_cell.RNNCell):

    def __init__(self, fast_cells, slow_cell, input_keep_prob=1.0, keep_prob=1.0, training=True):
        """Initialize the basic Fast-Slow RNN.

        Args:
          fast_cells: A list of RNN cells that will be used for the fast RNN.
            The cells must be callable, implement zero_state() and all have the
            same hidden size, like for example tf.contrib.rnn.BasicLSTMCell.
          slow_cell: A single RNN cell for the slow RNN.
          input_keep_prob: Keep probability for the dropout applied to the inputs.
          keep_prob: Keep probability for the non-recurrent dropout. Any kind of
            recurrent dropout should be implemented in the RNN cells.
          training: If False, no dropout is applied.
        """
        self.fast_layers = len(fast_cells)
        assert self.fast_layers >= 2, 'At least two fast layers are needed'
        self.fast_cells = fast_cells
        self.slow_cell = slow_cell
        self.keep_prob = keep_prob
        self.input_keep_prob = input_keep_prob
        if not training:
            # Per the docstring, disable all non-recurrent dropout at inference time.
            self.keep_prob = 1.0
            self.input_keep_prob = 1.0
    def __call__(self, inputs, state, scope='FS-RNN'):
        F_state = state[0]
        S_state = state[1]

        with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
            inputs = tf.nn.dropout(inputs, self.input_keep_prob)

            with tf.variable_scope('Fast_0'):
                # First fast layer reads the (time-major) input sequence.
                #F_output, F_state = self.fast_cells[0](inputs, F_state)
                F_output, F_state = tf.nn.dynamic_rnn(cell=self.fast_cells[0], inputs=inputs,
                                                      initial_state=F_state, time_major=True)
                F_output_drop = tf.nn.dropout(F_output, self.keep_prob)

            with tf.variable_scope('Slow'):
                # Slow layer is driven by the first fast layer's output.
                #S_output, S_state = self.slow_cell(F_output_drop, S_state)
                S_output, S_state = tf.nn.dynamic_rnn(cell=self.slow_cell, inputs=F_output_drop,
                                                      initial_state=S_state, time_major=True)
                S_output_drop = tf.nn.dropout(S_output, self.keep_prob)

            with tf.variable_scope('Fast_1'):
                # Second fast layer is driven by the slow layer's output.
                #F_output, F_state = self.fast_cells[1](S_output_drop, F_state)
                F_output, F_state = tf.nn.dynamic_rnn(cell=self.fast_cells[1], inputs=S_output_drop,
                                                      initial_state=F_state, time_major=True)

            for i in range(2, self.fast_layers):
                with tf.variable_scope('Fast_' + str(i)):
                    # Input cannot be empty for many RNN cells.
                    #F_output, F_state = self.fast_cells[i](F_output[:, 0:1] * 0.0, F_state)
                    #F_output, F_state = tf.nn.dynamic_rnn(cell=self.fast_cells[i], inputs=F_output[:, 0:1] * 0.0, initial_state=F_state, time_major=True)
                    F_output, F_state = tf.nn.dynamic_rnn(cell=self.fast_cells[i], inputs=F_output,
                                                          initial_state=F_state, time_major=True)

            F_output_drop = tf.nn.dropout(F_output, self.keep_prob)
            return F_output_drop, (F_state, S_state)
    def zero_state(self, batch_size, dtype):
        F_state = self.fast_cells[0].zero_state(batch_size, dtype)
        S_state = self.slow_cell.zero_state(batch_size, dtype)
        return (F_state, S_state)
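
# Note: the demo below uses tf.contrib cells, which exist only in TF 1.x. As an
# alternative (an assumption, for TF >= 1.13 where tf.contrib is deprecated),
# equivalent cells are available from the compat namespace that the base class
# above already uses:
#
#   slow = tf.compat.v1.nn.rnn_cell.BasicLSTMCell(32)
#   fast = [tf.compat.v1.nn.rnn_cell.BasicLSTMCell(32) for _ in range(3)]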
# Create one slow and three fast cells.
slow = tf.contrib.rnn.BasicLSTMCell(32)  # cell size
fast = [tf.contrib.rnn.BasicLSTMCell(32),
        tf.contrib.rnn.BasicLSTMCell(32),
        tf.contrib.rnn.BasicLSTMCell(32)]

# Create a single FS-RNN using the cells.
fs_lstm = FSRNNCell(fast, slow)

# Get the initial state and create a tf op that runs the cell over a full
# sequence (dynamic_rnn unrolls all timesteps, not a single step).
init_state = fs_lstm.zero_state(1, tf.float32)  # batch_size = 1
# Inputs are time-major: (time_steps, batch_size, input_dim) = (12, 1, 11).
output, final_state = fs_lstm(np.zeros((12, 1, 11), np.float32), init_state)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(output).shape)  # (12, 1, 32) = (time_steps, batch_size, cell_size)
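
# A minimal usage sketch (assumption: TF 1.x graph mode). Feeding a placeholder
# instead of a constant lets the same graph run on different sequences; because
# __call__ opens its scope with reuse=tf.AUTO_REUSE, this second call shares
# weights with the op built above. `seq` is a hypothetical random input batch.
x = tf.placeholder(tf.float32, shape=(12, 1, 11))  # (time_steps, batch_size, input_dim)
out, _ = fs_lstm(x, fs_lstm.zero_state(1, tf.float32))

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    seq = np.random.randn(12, 1, 11).astype(np.float32)
    print(sess.run(out, feed_dict={x: seq}).shape)  # (12, 1, 32)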