Created July 24, 2019 01:26
# Custom recurrent RLlib model (TFModelV2): a dense layer feeds an LSTM whose
# padding timesteps are masked via tf.sequence_mask, followed by separate
# policy-logits and value heads.
# Note: the import paths for TFModelV2 and add_time_dimension vary across Ray
# versions; adjust them to match the installed release.
import numpy as np
import tensorflow as tf

from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.policy.rnn_sequencing import add_time_dimension


class MaskingLayerRNNmodel(TFModelV2):
    def __init__(self, obs_space, action_space, num_outputs, model_config, name, **kw):
        super(MaskingLayerRNNmodel, self).__init__(
            obs_space, action_space, num_outputs, model_config, name, **kw)
        self.initialize_lstm_with_prev_state = \
            model_config['custom_options']['initialize_lstm_with_prev_state']

        # Observations with a time axis: (batch, time, obs_dim).
        self.input_layer = tf.keras.layers.Input(
            shape=(None, obs_space.shape[0]),
            name='inputLayer')
        # Previous LSTM state, fed back in by RLlib on every forward pass.
        self.state_in_c = tf.keras.layers.Input(
            shape=(model_config['lstm_cell_size'], ),
            name='c')
        self.state_in_h = tf.keras.layers.Input(
            shape=(model_config['lstm_cell_size'], ),
            name='h')
        # Per-sample sequence lengths, used to build the LSTM mask.
        self.seq_in = tf.keras.layers.Input(
            shape=(),
            name='seqLens',
            dtype=tf.int32)

        dense_layer_1 = tf.keras.layers.Dense(
            model_config['fcnet_hiddens'][0],
            activation=tf.nn.relu,
            name='denseLayer1')(self.input_layer)
        # masking_layer = tf.keras.layers.Masking(
        #     mask_value=0.0)(dense_layer_1)
        lstm_out, state_h, state_c = tf.keras.layers.LSTM(
            model_config['lstm_cell_size'],
            return_sequences=True,
            return_state=True,
            name='lstmLayer')(
                inputs=dense_layer_1,
                mask=tf.sequence_mask(self.seq_in) if self.model_config['max_seq_len'] > 1 else None,
                initial_state=[self.state_in_c, self.state_in_h])
        # The previous state is wired in through the Keras Input placeholders
        # above, so initial_state does not have to stay None. If we had access
        # to the batch shape, we could instead set stateful=True on the LSTM
        # and call reset_states() between sequences.
        # A Lambda reshape does not accept the mask that is propagated through
        # the model when Masking() is used upstream, so that variant fails;
        # the dense head below is therefore applied to lstm_out directly.
        # reshape_layer = tf.keras.layers.Lambda(
        #     lambda x: tf.reshape(x, [-1, model_config['lstm_cell_size']]))(lstm_out)
        dense_layer_2 = tf.keras.layers.Dense(
            model_config['fcnet_hiddens'][1],
            activation=tf.nn.relu,
            name='denseLayer2')(lstm_out)
        # Policy-logits and value heads share dense_layer_2.
        logits_layer = tf.keras.layers.Dense(
            self.num_outputs,
            activation=tf.keras.activations.linear,
            name='logitsLayer')(dense_layer_2)
        value_layer = tf.keras.layers.Dense(
            1,
            activation=None,
            name='valueLayer')(dense_layer_2)

        self.base_model = tf.keras.Model(
            inputs=[self.input_layer, self.state_in_c, self.state_in_h, self.seq_in],
            outputs=[logits_layer, value_layer, state_h, state_c])
        self.register_variables(self.base_model.variables)
        self.base_model.summary()

    # Implement the core forward method.
    def forward(self, input_dict, state, seq_lens):
        x = input_dict['obs']
        # Fold the flat batch back into (batch, time, obs_dim) if needed.
        if len(x.shape) < 3:
            x = add_time_dimension(x, seq_lens)
        logits, self._value_out, h, c = self.base_model((x, state[0], state[1], seq_lens))
        # Flatten the time axis back out of the logits.
        return tf.reshape(logits, [-1, self.num_outputs]), [h, c]

    def get_initial_state(self):
        # Zero-initialized hidden and cell states.
        return [np.zeros(self.model_config['lstm_cell_size'], np.float32),
                np.zeros(self.model_config['lstm_cell_size'], np.float32)]

    def value_function(self):
        return tf.reshape(self._value_out, [-1])
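

# A minimal, hedged usage sketch (not part of the original gist): registering
# the model with RLlib and pointing a trainer config at it. The model name
# 'masking_rnn', the CartPole env, and the PPO settings are illustrative
# assumptions; the exact config keys (e.g. 'custom_options') depend on the
# Ray version in use.
if __name__ == '__main__':
    import ray
    from ray import tune
    from ray.rllib.models import ModelCatalog

    ray.init()
    ModelCatalog.register_custom_model('masking_rnn', MaskingLayerRNNmodel)
    tune.run(
        'PPO',
        stop={'training_iteration': 1},
        config={
            'env': 'CartPole-v0',
            'num_workers': 0,
            'model': {
                'custom_model': 'masking_rnn',
                'custom_options': {'initialize_lstm_with_prev_state': False},
                'fcnet_hiddens': [64, 64],
                'lstm_cell_size': 256,
                'max_seq_len': 20,
            },
        })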