Skip to content

Instantly share code, notes, and snippets.

@ntakouris
Created July 1, 2020 12:31
Show Gist options
  • Save ntakouris/0da87552ba5a0bed4c2d064ceba4ef56 to your computer and use it in GitHub Desktop.
Save ntakouris/0da87552ba5a0bed4c2d064ceba4ef56 to your computer and use it in GitHub Desktop.
# --- Inference models -------------------------------------------------------
# Encoder model: maps encoder_input to encoder_states, which are later used as
# the decoder's initial state. Presumably built from the same layers as the
# training model so trained weights are shared -- confirm against the
# training-graph definitions (not shown here).
encoder_model = Model(encoder_input, encoder_states)
# Decoder model: its recurrent state is supplied from outside through two
# extra inputs (h and c, each of size latent_dim), so the caller chooses the
# initial state on every predict() call.
decoder_state_input_h = Input(shape=(latent_dim,))
decoder_state_input_c = Input(shape=(latent_dim,))
decoder_states_inputs = [decoder_state_input_h, decoder_state_input_c]
# Re-apply the decoder LSTM layer to the decoder embedding, starting from the
# externally supplied state. It yields the output plus final h and c
# (assumes the layer was created with return_state=True, and presumably
# return_sequences=True -- TODO confirm against the layer definition).
decoder_output, state_h, state_c = decoder_lstm_layer(
decoder_embedding,
initial_state=decoder_states_inputs)
decoder_states = [state_h, state_c]
# NOTE(review): this rebinds the module-level names decoder_output and
# decoder_states; if the training graph used the same names they now refer
# to the inference tensors.
decoder_output = decoder_dense_layer(decoder_output)
# reminder: decoder input -> decoder embedding (on the graph)
# Inputs: decoder token sequence + initial (h, c).
# Outputs: dense-layer output + final (h, c).
decoder_model = Model(
[decoder_input] + decoder_states_inputs,
[decoder_output] + decoder_states)
def decode_sequence(input_seq):
    """Greedily decode ``input_seq`` with the inference encoder/decoder models.

    The source text is tokenized with ``encoder`` (the text encoder, not the
    Keras ``encoder_model``), post-padded to ``MAX_TOKENS`` and run through
    ``encoder_model`` to obtain the decoder's initial states. The decoder is
    then re-run on the whole zero-padded target prefix at every step, and the
    arg-max of its output at the last filled position is taken as the next
    token.

    Args:
        input_seq: Text accepted by ``encoder.encode``.

    Returns:
        list: Decoded tokens, starting with ``BOS``. Ends with ``EOS`` if the
        model emits it; otherwise the list is capped at ``MAX_TOKENS``
        entries. A predicted index of 0 is reported as ``UNK``.
    """
    enc_input_seq = encoder.encode(input_seq)
    padded = pad_sequences([enc_input_seq], MAX_TOKENS, padding='post')
    # Initial decoder state. Because the full prefix is re-fed from position 0
    # on every step, this must stay the encoder state for every call; feeding
    # back the decoder's own final (h, c) would make the LSTM re-read the
    # prefix from a state that has already consumed it.
    states_value = encoder_model.predict(padded)

    # Batch of 1 holding the whole target prefix, zero-padded to MAX_TOKENS.
    target_seq = np.zeros((1, MAX_TOKENS))
    target_seq[0, 0] = encoder.encode(BOS)[0]
    decoded_sentence = [BOS]

    last = 0  # index of the last filled position in target_seq
    while True:
        output_tokens, _, _ = decoder_model.predict(
            [target_seq] + states_value)
        # The output at position `last` is the prediction made after reading
        # target_seq[0, last], i.e. the token for position last + 1 (assumes
        # the usual teacher-forcing alignment where decoder targets are the
        # decoder inputs shifted left by one -- confirm against training).
        # Position last + 1 itself still holds padding, so it must not be read.
        sampled_token_index = int(np.argmax(output_tokens[0, last, :]))
        sampled_word = UNK
        if sampled_token_index != 0:
            # NOTE(review): `[0]` keeps only the first element of what
            # encoder.decode returns; if that is a str this is a single
            # character -- confirm against the tokenizer in use.
            sampled_word = encoder.decode([sampled_token_index])[0]
        decoded_sentence.append(sampled_word)

        last += 1
        # Stop on EOS or once the prefix buffer is full; `>=` also prevents
        # writing past the buffer when MAX_TOKENS == 1.
        if sampled_word == EOS or len(decoded_sentence) >= MAX_TOKENS:
            break
        target_seq[0, last] = sampled_token_index
    return decoded_sentence
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment