sequence to sequence with attention in Keras
import tensorflow as tf
import os
import numpy as np

# settings for GPU
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
config.gpu_options.per_process_gpu_memory_fraction = 0.9
sess = tf.Session(config=config)
# register the session with Keras so the GPU options above actually take effect
tf.keras.backend.set_session(sess)
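# Note: this script uses the TF 1.x-style API (tf.ConfigProto, tf.Session,
# graph-mode tf.keras). Under TF 2.x, a roughly equivalent memory-growth setup
# (hypothetical sketch, not needed for TF 1.x) would be:
#   for gpu in tf.config.experimental.list_physical_devices("GPU"):
#       tf.config.experimental.set_memory_growth(gpu, True)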
# hyper-parameters
batch_size = 256
epochs = 30
latent_dim = 256
embedding_size = 128
num_samples = 10000
pad_token = "<PAD>"
data_path = "fra-eng/fra.txt"

input_texts = []
target_texts = []
input_words = set()
target_words = set()
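# fra.txt is assumed to be the tab-separated English-French pair file used by the
# standard Keras seq2seq example (one "english\tfrench" pair per line, e.g.
# "Go.\tVa !"); any file in that format should work.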
with open(data_path, "r", encoding="utf-8") as f:
    lines = f.read().split("\n")

for line in lines[: min(num_samples, len(lines) - 1)]:
    input_text, target_text = line.split("\t")
    # <GO> as the "start of sequence" token
    # <EOS> as the "end of sequence" token
    target_text = "<GO> " + target_text + " <EOS>"
    input_texts.append(input_text)
    target_texts.append(target_text)
    # build the word vocabulary for each language
    for word in input_text.split():
        if word not in input_words:
            input_words.add(word)
    for word in target_text.split():
        if word not in target_words:
            target_words.add(word)
input_words = sorted(list(input_words))
target_words = sorted(list(target_words))
num_encoder_tokens = len(input_words)
num_decoder_tokens = len(target_words)
# sequence lengths are measured in words (tokens), not characters
max_encoder_seq_length = max(len(txt.split()) for txt in input_texts)
max_decoder_seq_length = max(len(txt.split()) for txt in target_texts)

print('Number of samples:', len(input_texts))
print('Number of unique input tokens:', num_encoder_tokens)
print('Number of unique output tokens:', num_decoder_tokens)
print('Max sequence length for inputs:', max_encoder_seq_length)
print('Max sequence length for outputs:', max_decoder_seq_length)

# token2idx dictionaries; index 0 is reserved for the pad token on both sides,
# so that mask_zero=True in the embeddings treats 0 as padding
input_token_index = {word: i for i, word in enumerate(input_words, start=1)}
target_token_index = {word: i for i, word in enumerate(target_words, start=1)}
input_token_index[pad_token] = 0
target_token_index[pad_token] = 0
# zero-initialized numpy arrays (0 == pad index)
encoder_input_data = np.zeros((len(input_texts), max_encoder_seq_length),
                              dtype=np.float32)
decoder_input_data = np.zeros((len(target_texts), max_decoder_seq_length),
                              dtype=np.float32)
decoder_target_data = np.zeros((len(target_texts), max_decoder_seq_length, num_decoder_tokens + 1),
                               dtype=np.float32)
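# shapes after this step:
#   encoder_input_data : [num_samples, max_encoder_seq_length]  integer word ids
#   decoder_input_data : [num_samples, max_decoder_seq_length]  integer word ids
#   decoder_target_data: [num_samples, max_decoder_seq_length, num_decoder_tokens + 1]
#                        one-hot targets, shifted one step ahead of decoder_input_data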
# fill in the arrays with token indices / one-hot targets
for i, (input_text, target_text) in enumerate(zip(input_texts, target_texts)):
    for t, word in enumerate(input_text.split()):
        encoder_input_data[i, t] = input_token_index[word]
    for t, word in enumerate(target_text.split()):
        decoder_input_data[i, t] = target_token_index[word]
        # decoder target is one time step ahead of decoder_input
        if t > 0:
            decoder_target_data[i, t - 1, target_token_index[word]] = 1
# encoder
encoder_inputs = tf.keras.Input(shape=[None], name="encoder_inputs")
# 0 is the pad index, so the total input dim is V + 1
encoder_embedding = tf.keras.layers.Embedding(input_dim=num_encoder_tokens + 1,
                                              output_dim=embedding_size,
                                              mask_zero=True,
                                              name="encoder_embedding")
encoder_embedded = encoder_embedding(encoder_inputs)
encoder_lstm = tf.keras.layers.LSTM(latent_dim,
                                    return_state=True,
                                    return_sequences=True)
# bi-directional LSTM; with return_state=True, Bidirectional returns the
# outputs followed by the forward (h, c) and backward (h, c) states
encoder = tf.keras.layers.Bidirectional(encoder_lstm)
encoder_outputs, fw_state_h, fw_state_c, bw_state_h, bw_state_c = encoder(encoder_embedded)
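# note: encoder_outputs has dimension 2 * latent_dim because the forward and
# backward hidden states are concatenated. The final encoder states are not
# used to initialize the decoder here; the decoder starts from a zero state
# and sees the source sentence only through the attention mechanism below.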
# decoder
decoder_inputs = tf.keras.Input(shape=[None], name="decoder_inputs")
decoder_embedding = tf.keras.layers.Embedding(input_dim=num_decoder_tokens + 1,
                                              output_dim=embedding_size,
                                              mask_zero=True)
decoder_embedded = decoder_embedding(decoder_inputs)
decoder_lstm = tf.keras.layers.LSTM(latent_dim,
                                    return_sequences=True,
                                    return_state=True)
decoder_outputs, _, _ = decoder_lstm(decoder_embedded)
# attention score: e_{t,i} = v^T tanh(W [h^enc_i ; h^dec_t] + b)
# the two Dense layers are created once, outside the Lambda body, so that the
# training model and the inference (sampling) model below share the same weights
attention_w = tf.keras.layers.Dense(latent_dim, activation=tf.nn.tanh, use_bias=True)
attention_v = tf.keras.layers.Dense(1)


def attention(inputs):
    # inputs: [encoder_outputs, decoder_outputs]
    # encoder_outputs: [b, t, 2d], decoder_outputs: [b, k, d]
    encoder_outputs = inputs[0]
    decoder_outputs = inputs[1]
    encoder_length = tf.shape(encoder_outputs)[1]
    decoder_length = tf.shape(decoder_outputs)[1]
    # encoder_hiddens: [b, t, 2d] -> [b, k, t, 2d]
    # decoder_hiddens: [b, k, d]  -> [b, k, t, d]
    encoder_hiddens = tf.tile(tf.expand_dims(encoder_outputs, axis=1), [1, decoder_length, 1, 1])
    decoder_hiddens = tf.tile(tf.expand_dims(decoder_outputs, axis=2), [1, 1, encoder_length, 1])
    # hidden_input: [b, k, t, 3d]
    hidden_input = tf.concat([encoder_hiddens, decoder_hiddens], axis=-1)
    # tanh(W[h_enc; h_dec] + b), then project with v -> attention_score: [b, k, t]
    attention_hidden = attention_w(hidden_input)
    attention_score = attention_v(attention_hidden)
    attention_score = tf.squeeze(attention_score, axis=-1)
    # padding masks derived from the hidden states
    encoder_mask = tf.sign(tf.abs(tf.reduce_sum(encoder_outputs, axis=2)))  # [b, t]
    decoder_mask = tf.sign(tf.abs(tf.reduce_sum(decoder_outputs, axis=2)))  # [b, k]
    # mask padded encoder positions (keys) *before* the softmax so they get ~0 probability
    key_masks = tf.tile(tf.expand_dims(encoder_mask, 1), [1, decoder_length, 1])  # [b, k, t]
    paddings = tf.ones_like(attention_score) * (-2 ** 32 + 1)
    attention_score = tf.where(tf.equal(key_masks, 0), paddings, attention_score)
    attention_score = tf.nn.softmax(attention_score, axis=-1)
    # zero out rows belonging to padded decoder positions (queries) after the softmax
    query_masks = tf.expand_dims(decoder_mask, 2)  # [b, k, 1]
    attention_score *= query_masks
    # weighted sum of encoder states: [b, k, t] x [b, t, 2d] -> [b, k, 2d]
    context_vectors = tf.matmul(attention_score, encoder_outputs)
    # concatenate decoder state and context vector: [b, k, 3d]
    augmented_outputs = tf.concat([decoder_outputs, context_vectors], axis=-1)
    return augmented_outputs
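# e.g. with batch b=2, encoder length t=5, decoder length k=7 and latent_dim d=256,
# the Lambda below maps ([2, 5, 512], [2, 7, 256]) -> [2, 7, 768]: each decoder
# state is concatenated with its 512-dim context vector.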
attention_layer = tf.keras.layers.Lambda(attention)
decoder_outputs = attention_layer([encoder_outputs, decoder_outputs])
decoder_dense = tf.keras.layers.Dense(num_decoder_tokens + 1, activation=tf.nn.softmax)
decoder_outputs = decoder_dense(decoder_outputs)

model = tf.keras.Model(inputs=[encoder_inputs, decoder_inputs],
                       outputs=decoder_outputs)
model.summary()
# use the Keras RMSprop optimizer (rather than tf.train.RMSPropOptimizer) so the
# optimizer state can be saved along with the model below
model.compile(optimizer=tf.keras.optimizers.RMSprop(lr=1e-3),
              loss=tf.keras.losses.categorical_crossentropy)
model.fit([encoder_input_data, decoder_input_data],
          decoder_target_data,
          batch_size=batch_size,
          epochs=epochs,
          validation_split=0.2)
model.save("s2s.h5")
# Next: inference mode (sampling).
# Here's the drill:
# 1) encode the input and retrieve the encoder hidden states
# 2) run one step of the decoder with a zero initial state
#    and the "<GO>" token as the first target; its output is the next target token
# 3) repeat, feeding back the sampled token and the updated decoder states

# define the sampling models
encoder_model = tf.keras.Model(inputs=encoder_inputs,
                               outputs=encoder_outputs)

# to compute context vectors, we need encoder_outputs at every decoder time step
encoder_hidden_outputs = tf.keras.Input(shape=(None, 2 * latent_dim))
decoder_state_input_h = tf.keras.Input(shape=(latent_dim,))
decoder_state_input_c = tf.keras.Input(shape=(latent_dim,))
decoder_states_inputs = [decoder_state_input_h, decoder_state_input_c]

decoder_embedded = decoder_embedding(decoder_inputs)
decoder_outputs, state_h, state_c = decoder_lstm(decoder_embedded,
                                                 initial_state=decoder_states_inputs)
decoder_outputs = attention_layer([encoder_hidden_outputs, decoder_outputs])
decoder_states = [state_h, state_c]
decoder_outputs = decoder_dense(decoder_outputs)
decoder_model = tf.keras.Model(inputs=[decoder_inputs, encoder_hidden_outputs] + decoder_states_inputs,
                               outputs=[decoder_outputs] + decoder_states)

reverse_input_word_index = {i: word for word, i in input_token_index.items()}
reverse_target_word_index = {i: word for word, i in target_token_index.items()}
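# the sampling models reuse the already-trained layers (decoder_embedding,
# decoder_lstm, attention_layer, decoder_dense), so no extra training is needed;
# encoder_model maps a source sentence to its bidirectional hidden states, and
# decoder_model advances the decoder one step at a time given those states.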
def decode_sequence(input_seq):
    # encode the source sentence once; its hidden states are reused at every step
    encoder_outputs = encoder_model.predict(input_seq)
    # for the first decoder step, the initial state is a zero vector (batch of 1)
    zeros = np.zeros((1, latent_dim))
    states_value = [zeros, zeros]
    # the decoder input is a length-1 sequence holding the current token id
    target_seq = np.zeros((1, 1))
    target_seq[0, 0] = target_token_index["<GO>"]

    decoded_sentence = ""
    stop_condition = False
    while not stop_condition:
        output_tokens, h, c = decoder_model.predict([target_seq, encoder_outputs] + states_value)
        # greedy decoding: pick the most probable word at the last (only) time step
        sampled_token_index = np.argmax(output_tokens[0, -1, :])
        sampled_word = reverse_target_word_index[sampled_token_index]
        decoded_sentence += " " + sampled_word
        if sampled_word == "<EOS>" or len(decoded_sentence.split()) > max_decoder_seq_length:
            stop_condition = True
        # feed the sampled word back in as the next decoder input
        target_seq = np.zeros((1, 1))
        target_seq[0, 0] = target_token_index[sampled_word]
        states_value = [h, c]
    return decoded_sentence
with open("example.txt", "w", encoding="utf-8") as f: | |
for seq_index in range(100): | |
input_seq = encoder_input_data[seq_index: seq_index + 1] | |
decoded_sentence = decode_sequence(input_seq) | |
print("-") | |
input_text = input_texts[seq_index] | |
f.write(input_text + "\t" + decoded_sentence + "\n") | |
print("input sentence:", input_texts[seq_index]) | |
print("decoded sentence:", decoded_sentence) |