# Assumed context: this snippet is a modified BasicLSTMCell.__call__ from
# TensorFlow 1.x, so it relies on internal helpers. The exact module paths
# (and the _linear signature) shifted between 1.x releases; the imports below
# match the 1.2-era tensorflow/python/ops/rnn_cell_impl.py and may need
# adjusting for your version.
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops.math_ops import sigmoid, tanh
from tensorflow.python.ops.rnn_cell_impl import LSTMStateTuple, _linear


def __call__(self, inputs, ctx, state, scope=None):
    """Long short-term memory cell (LSTM) with soft attention over `ctx`."""
    with vs.variable_scope(scope or "basic_lstm_cell"):
        # Parameters of gates are concatenated into one multiply for efficiency.
        if self._state_is_tuple:
            c, h = state  # each [batch_size, hidden_dim]
        else:
            c, h = array_ops.split(value=state, num_or_size_splits=2, axis=1)
        # Attention over ctx, whose incoming shape is
        # [batch_size, num_feats, feat_dim]; num_feats = 49 and
        # feat_dim = ctx_dim = 100 are hard-coded below.
        ctx_dim = 100
        num_feats = 49
        Wd_att = vs.get_variable("Wd_att", [h.get_shape()[1], ctx_dim], dtype=c.dtype)
        U_att = vs.get_variable("U_att", [ctx_dim, 1], dtype=c.dtype)
        c_att = vs.get_variable("c_att", [1], dtype=c.dtype)

        # Flatten ctx so it can be added elementwise to the tiled state below.
        ctx = array_ops.reshape(ctx, [-1, ctx_dim])  # [batch_size * num_feats, ctx_dim]

        pstate = math_ops.matmul(h, Wd_att)  # [batch_size, ctx_dim]
        # Tile pstate so every feature location sees a copy of the projected
        # state, matching ctx's flattened shape [batch_size * num_feats, ctx_dim].
        pstate = array_ops.tile(pstate, [1, num_feats])
        pstate = array_ops.reshape(pstate, [-1, ctx_dim])
        pstate = tanh(pstate + ctx)

        # One unnormalized score per feature location, then softmax over locations.
        e_ti = math_ops.matmul(pstate, U_att) + c_att  # [batch_size * num_feats, 1]
        e_ti = array_ops.reshape(e_ti, [-1, num_feats])
        alpha = nn_ops.softmax(logits=e_ti)  # [batch_size, num_feats]

        # Weighted sum of the feature vectors: a trailing singleton axis on
        # alpha lets multiply broadcast each weight across all ctx_dim channels
        # of its feature location.
        alpha = array_ops.reshape(alpha, [-1, num_feats, 1])
        ctx = array_ops.reshape(ctx, [-1, num_feats, ctx_dim])
        z = math_ops.reduce_sum(math_ops.multiply(alpha, ctx), 1)  # [batch_size, ctx_dim]
        print("z: ", z.get_shape())
        # Feed the attended context z into the gate pre-activations alongside
        # the input and previous hidden state.
        concat = _linear([inputs, h, z], 4 * self._num_units, True, scope=scope)
        # i = input_gate, j = new_input, f = forget_gate, o = output_gate
        i, j, f, o = array_ops.split(value=concat, num_or_size_splits=4, axis=1)
        new_c = (c * sigmoid(f + self._forget_bias) +
                 sigmoid(i) * self._activation(j))
        new_h = self._activation(new_c) * sigmoid(o)

        if self._state_is_tuple:
            new_state = LSTMStateTuple(new_c, new_h)
        else:
            new_state = array_ops.concat([new_c, new_h], 1)
        return new_h, new_state
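# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original gist). Because __call__
# takes the extra `ctx` argument, the cell does not plug into
# tf.nn.dynamic_rnn as-is; one option is to unroll the time loop manually.
# AttnLSTMCell is a hypothetical BasicLSTMCell subclass carrying the
# __call__ above, and batch_size / num_steps / input_dim are illustrative
# values, not from the gist.
import tensorflow as tf

batch_size, num_steps, input_dim = 32, 10, 256
num_feats, ctx_dim = 49, 100  # must match the constants hard-coded in the cell

inputs = tf.placeholder(tf.float32, [batch_size, num_steps, input_dim])
ctx = tf.placeholder(tf.float32, [batch_size, num_feats, ctx_dim])

cell = AttnLSTMCell(num_units=512, state_is_tuple=True)  # hypothetical subclass
state = cell.zero_state(batch_size, tf.float32)

outputs = []
with tf.variable_scope("attn_rnn") as rnn_scope:
    for t in range(num_steps):
        if t > 0:
            rnn_scope.reuse_variables()  # share LSTM and attention weights across steps
        output, state = cell(inputs[:, t, :], ctx, state)
        outputs.append(output)
# Unrolling by hand keeps ctx visible at every step; an adapter exposing the
# standard (inputs, state) cell signature would let tf.nn.dynamic_rnn manage
# the loop instead.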