word-level seq2seq with attention
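A single-file PyTorch script: a bidirectional GRU encoder and a GRU decoder with Bahdanau attention, trained on word-level pairs from the tab-separated fra-eng/fra.txt parallel corpus, with greedy (non-teacher-forcing) decoding.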
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
import random
import numpy as np
import re, unicodedata

random.seed(1024)

# gpu configuration
USE_CUDA = torch.cuda.is_available()
gpus = [0]
if USE_CUDA:
    torch.cuda.set_device(gpus[0])

FloatTensor = torch.cuda.FloatTensor if USE_CUDA else torch.FloatTensor
LongTensor = torch.cuda.LongTensor if USE_CUDA else torch.LongTensor
ByteTensor = torch.cuda.ByteTensor if USE_CUDA else torch.ByteTensor

flatten = lambda lst: [item for sublist in lst for item in sublist]
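
# get_batch shuffles the training pairs and yields mini-batches of size
# batch_size; the trailing partial batch (if any) is yielded as well.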
def get_batch(batch_size, train_data):
    random.shuffle(train_data)
    start_idx = 0
    end_idx = batch_size
    while end_idx < len(train_data):
        batch = train_data[start_idx: end_idx]
        temp = end_idx
        end_idx = end_idx + batch_size
        start_idx = temp
        yield batch

    if end_idx >= len(train_data):
        yield train_data[start_idx:]
def pad_to_batch(batch, source2idx, target2idx):
    # sort the batch in descending order of source sequence length
    sorted_batch = sorted(batch, key=lambda x: x[0].size(1), reverse=True)
    x, y = list(zip(*sorted_batch))
    max_source_length = max([seq.size(1) for seq in x])
    max_target_length = max([seq.size(1) for seq in y])
    padded_x, padded_y = [], []
    for i in range(len(batch)):
        if x[i].size(1) < max_source_length:
            paddings = LongTensor([source2idx["<PAD>"]] * (max_source_length - x[i].size(1)))
            paddings = paddings.view(1, -1)
            padded = torch.cat([x[i], paddings], dim=1)
            padded_x.append(padded)
        else:
            padded_x.append(x[i])
        if y[i].size(1) < max_target_length:
            paddings = LongTensor([target2idx["<PAD>"]] * (max_target_length - y[i].size(1)))
            paddings = paddings.view(1, -1)
            padded = torch.cat([y[i], paddings], dim=1)
            padded_y.append(padded)
        else:
            padded_y.append(y[i])

    input_x = torch.cat(padded_x, dim=0)
    target_y = torch.cat(padded_y, dim=0)
    # actual (unpadded) lengths, counted as the number of non-<PAD> (index 0) tokens
    input_len = [list(map(lambda s: s == 0, t.data)).count(False) for t in input_x]
    target_len = [list(map(lambda s: s == 0, t.data)).count(False) for t in target_y]
    return input_x, target_y, input_len, target_len
# convert tokens to indices
def prepare_sequence(seq, to_idx):
    indices = list(map(lambda w: to_idx[w] if w in to_idx else to_idx["<UNK>"], seq))
    return LongTensor(indices)
def unicode_to_ascii(s):
    return "".join(
        c for c in unicodedata.normalize("NFD", s)
        if unicodedata.category(c) != "Mn"
    )


def normalize_string(s):
    s = unicode_to_ascii(s.lower().strip())
    s = re.sub(r"([,.!?])", r" \1 ", s)
    s = re.sub(r"[^a-zA-Z,.!?]+", r" ", s)
    s = re.sub(r"\s+", r" ", s).strip()
    return s
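
# Illustrative example: normalize_string("Ça alors !") -> "ca alors !"
# (accents are stripped via NFD normalization, punctuation is space-separated,
# and everything except letters and ,.!? is dropped)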
corpus = open("fra-eng/fra.txt", "r", encoding="utf-8").readlines()[:-1]
x_raw, y_raw = [], []
for parallel in corpus:
    source, target = parallel[:-1].split("\t")
    if source.strip() == "" or target.strip() == "":
        continue
    normalized_source = normalize_string(source).split()
    normalized_target = normalize_string(target).split()
    x_raw.append(normalized_source)
    y_raw.append(normalized_target)

# construct vocab for source and target language
source_vocab = list(set(flatten(x_raw)))
target_vocab = list(set(flatten(y_raw)))

source2idx = {"<PAD>": 0, "<UNK>": 1, "<s>": 2, "</s>": 3}
for vocab in source_vocab:
    if vocab not in source2idx:
        source2idx[vocab] = len(source2idx)
idx2source = {idx: vocab for vocab, idx in source2idx.items()}

target2idx = {"<PAD>": 0, "<UNK>": 1, "<s>": 2, "</s>": 3}
for vocab in target_vocab:
    if vocab not in target2idx:
        target2idx[vocab] = len(target2idx)
idx2target = {idx: vocab for vocab, idx in target2idx.items()}
padded_x, padded_y = [], []
for source, target in zip(x_raw, y_raw):
    padded_x.append(prepare_sequence(source + ["</s>"], source2idx).view(1, -1))
    padded_y.append(prepare_sequence(target + ["</s>"], target2idx).view(1, -1))

train_data = list(zip(padded_x, padded_y))
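
# Each training pair is a ([1, T_src], [1, T_tgt]) LongTensor tuple whose last
# token is "</s>". A quick, illustrative sanity check:
# ex_src, ex_tgt = train_data[0]
# print([idx2source[i] for i in ex_src.tolist()[0]])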
class Encoder(nn.Module):
    def __init__(self, input_size, embedding_size, hidden_size, n_layers=1, bidirectional=False):
        super(Encoder, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.n_layers = n_layers
        self.embedding = nn.Embedding(input_size, embedding_size)
        if bidirectional:
            self.n_direction = 2
            self.gru = nn.GRU(embedding_size, hidden_size, n_layers, batch_first=True, bidirectional=True)
        else:
            self.n_direction = 1
            self.gru = nn.GRU(embedding_size, hidden_size, n_layers, batch_first=True, bidirectional=False)

    def init_hidden(self, inputs):
        hidden = torch.zeros((self.n_layers * self.n_direction, inputs.size(0), self.hidden_size))
        return hidden.cuda() if USE_CUDA else hidden

    def init_weight(self):
        self.embedding.weight = nn.init.xavier_uniform_(self.embedding.weight)
        self.gru.weight_hh_l0 = nn.init.xavier_uniform_(self.gru.weight_hh_l0)
        self.gru.weight_ih_l0 = nn.init.xavier_uniform_(self.gru.weight_ih_l0)

    def forward(self, inputs, input_lengths):
        # inputs: [B, T] LongTensor
        # input_lengths: list of actual (unpadded) lengths of the input batch
        hidden = self.init_hidden(inputs)
        embedded = self.embedding(inputs)
        packed = pack_padded_sequence(embedded, input_lengths, batch_first=True)
        # outputs: [B, T, D * n_direction], hidden: [n_layers * n_direction, B, D]
        outputs, hidden = self.gru(packed, hidden)
        outputs, output_lengths = pad_packed_sequence(outputs, batch_first=True)
        # keep only the top layer's state(s)
        if self.n_layers > 1:
            if self.n_direction == 2:
                hidden = hidden[-2:]
            else:
                hidden = hidden[-1:]  # keep a 3-D tensor: [1, B, D]
        hidden = torch.cat([h for h in hidden], dim=1)  # [B, D * n_direction]
        hidden = hidden.unsqueeze(dim=1)  # [B, 1, D']
        return outputs, hidden
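
# With the configuration used below (3 layers, bidirectional, hidden_size=512),
# the encoder returns outputs of shape [B, T, 1024] and a context/initial state
# of shape [B, 1, 1024] (the concatenated last-layer forward/backward states).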
class Decoder(nn.Module):
    def __init__(self, input_size, embedding_size,
                 hidden_size, n_layers=1, dropout_prob=0.1):
        super(Decoder, self).__init__()
        self.hidden_size = hidden_size
        self.n_layers = n_layers
        self.embedding = nn.Embedding(input_size, embedding_size)
        self.dropout = nn.Dropout(dropout_prob)
        # the GRU input is [embedded; context_vector], so its size is embedding_size + hidden_size
        self.gru = nn.GRU(embedding_size + hidden_size, hidden_size, n_layers, batch_first=True)
        self.linear = nn.Linear(hidden_size * 2, input_size)  # final projection to the target vocabulary
        self.attn_linear = nn.Linear(self.hidden_size, self.hidden_size)  # for luong attention
        # for bahdanau attention
        self.encoder_linear = nn.Linear(hidden_size, hidden_size, bias=False)
        self.decoder_linear = nn.Linear(hidden_size, hidden_size)
        self.attention_output = nn.Linear(hidden_size, 1)

    def init_hidden(self, inputs):
        hidden = torch.zeros((self.n_layers, inputs.size(0), self.hidden_size))
        return hidden.cuda() if USE_CUDA else hidden

    def init_weight(self):
        self.embedding.weight = nn.init.xavier_uniform_(self.embedding.weight)
        self.gru.weight_hh_l0 = nn.init.xavier_uniform_(self.gru.weight_hh_l0)
        self.gru.weight_ih_l0 = nn.init.xavier_uniform_(self.gru.weight_ih_l0)
        self.attn_linear.weight = nn.init.xavier_uniform_(self.attn_linear.weight)
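
    # Two attention variants are implemented. For decoder state s and encoder
    # output h_t:
    #   Bahdanau (additive):      e_t = v^T tanh(W_enc h_t + W_dec s)
    #   Luong (general/bilinear): e_t = h_t^T W s
    # Attention weights are softmax(e) over t, masked at <PAD> positions.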
    def bahdanau_attention(self, decoder_hidden, encoder_outputs, encoder_mask):
        # decoder_hidden: [1, b, d]
        # encoder_outputs: [b, t, d]
        # encoder_mask: [b, t] ByteTensor, 1 at <PAD> positions
        decoder_hidden = decoder_hidden[0].unsqueeze(dim=1)  # [b, 1, d]
        decoder_feature = self.decoder_linear(decoder_hidden)
        encoder_features = self.encoder_linear(encoder_outputs)
        attn_features = encoder_features + decoder_feature
        attn_features = torch.tanh(attn_features)
        score = self.attention_output(attn_features).squeeze(dim=2)  # [b, t, 1] -> [b, t]
        if encoder_mask is not None:
            score = score.masked_fill(encoder_mask, value=-1e12)
        score = F.softmax(score, dim=1)
        score = score.unsqueeze(dim=1)  # [b, 1, t]
        context_vector = torch.matmul(score, encoder_outputs)  # [b, 1, d]
        return context_vector, score
    def luong_attention(self, decoder_hidden, encoder_outputs, encoder_mask):
        # decoder_hidden: [1, b, d]
        # encoder_outputs: [b, t, d]
        # encoder_mask: [b, t] ByteTensor, 1 at <PAD> positions
        decoder_hidden = decoder_hidden[0].unsqueeze(2)  # [1, b, d] -> [b, d, 1]
        energies = self.attn_linear(encoder_outputs)  # [b, t, d]
        attention_energies = torch.matmul(energies, decoder_hidden).squeeze(2)  # [b, t]
        # mask out <PAD> positions before the softmax
        if encoder_mask is not None:
            attention_energies = attention_energies.masked_fill(encoder_mask, value=-1e12)
        alpha = F.softmax(attention_energies, dim=1)
        alpha = alpha.unsqueeze(1)  # [b, 1, t]
        context_vector = torch.matmul(alpha, encoder_outputs)  # [b, 1, d]
        return context_vector, alpha
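
    # forward() unrolls the GRU one step at a time during training. The next
    # input token is the decoder's own greedy prediction (no teacher forcing),
    # and the Bahdanau context vector is recomputed at every step.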
    def forward(self, inputs, context, max_length, encoder_outputs, encoder_maskings, is_training=False):
        """
        :param inputs: [b, 1] LongTensor, start symbol <s>
        :param context: [b, 1, d]
        :param max_length: maximum length to decode
        :param encoder_outputs: [b, t, d]
        :param encoder_maskings: [b, t] ByteTensor
        :param is_training: boolean
        :return: logits of shape [b * max_length, vocab_size]
        """
        embedded = self.embedding(inputs)
        hidden = self.init_hidden(inputs)
        if is_training:
            embedded = self.dropout(embedded)

        decode = []
        # unroll the gru
        for t in range(max_length):
            concat_inputs = torch.cat((embedded, context), dim=2)
            _, hidden = self.gru(concat_inputs, hidden)
            concat_outputs = torch.cat((hidden, context.transpose(0, 1)), dim=2)
            logit = self.linear(concat_outputs.squeeze(0))
            score = F.log_softmax(logit, 1)
            decode.append(logit)
            decoded = torch.argmax(score, dim=1)
            # input for the next time step, because this is not teacher-forcing training
            embedded = self.embedding(decoded).unsqueeze(1)
            if is_training:
                embedded = self.dropout(embedded)
            context, alpha = self.bahdanau_attention(hidden, encoder_outputs, encoder_maskings)

        scores = torch.cat(decode, dim=1)
        batch_size = inputs.size(0)
        return scores.view(batch_size * max_length, -1)
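
    # decode() is the greedy inference loop for a single example (batch size 1):
    # it keeps emitting the arg-max token until "</s>" is produced or
    # max_decode_length steps are reached, and also collects attention weights.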
    def decode(self, context, encoder_outputs, max_decode_length=100):
        start_decode = LongTensor([[target2idx["<s>"]] * 1])
        embedded = self.embedding(start_decode)
        hidden = self.init_hidden(start_decode)

        decodes = []
        attentions = []
        decoded = LongTensor([target2idx["<s>"]])
        while decoded.tolist()[0] != target2idx["</s>"] and len(decodes) < max_decode_length:
            concat_input = torch.cat((embedded, context), dim=2)
            _, hidden = self.gru(concat_input, hidden)
            concat_output = torch.cat((hidden, context.transpose(0, 1)), dim=2)
            score = self.linear(concat_output.squeeze(0))  # [1, vocab]
            score = F.log_softmax(score, dim=1)
            decodes.append(score)
            decoded = torch.argmax(score, dim=1)  # [1]
            embedded = self.embedding(decoded).unsqueeze(1)  # [1, d] -> [1, 1, d]
            context, alpha = self.bahdanau_attention(hidden, encoder_outputs, None)
            attentions.append(alpha.squeeze(1))

        indices = torch.cat(decodes, dim=0).max(1)[1]
        attentions = torch.cat(attentions, dim=0)
        return indices, attentions
num_epochs = 50
batch_size = 64
embedding_size = 300
hidden_size = 512
lr = 1e-3
decoder_lr_ratio = 5.0
rescheduled = False

encoder = Encoder(len(source2idx), embedding_size, hidden_size, 3, True)
decoder = Decoder(len(target2idx), embedding_size, hidden_size * 2)
encoder.init_weight()
decoder.init_weight()
if USE_CUDA:
    encoder = encoder.cuda()
    decoder = decoder.cuda()

loss_function = nn.CrossEntropyLoss(ignore_index=0)
enc_optimizer = optim.Adam(encoder.parameters(), lr=lr)
dec_optimizer = optim.Adam(decoder.parameters(), lr=lr * decoder_lr_ratio)
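
# Training notes: the loss ignores <PAD> (index 0), the decoder uses a 5x larger
# learning rate than the encoder, and both learning rates are cut by 100x
# halfway through training (see the rescheduling block below).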
for epoch in range(num_epochs):
    losses = []
    for i, batch in enumerate(get_batch(batch_size, train_data)):
        # prepare inputs
        inputs, targets, input_lengths, target_lengths = pad_to_batch(batch, source2idx, target2idx)
        zeros = torch.zeros_like(inputs)
        input_masks = ByteTensor(inputs == zeros)
        start_decode = LongTensor([[target2idx["<s>"]]] * targets.size(0))

        encoder.zero_grad()
        decoder.zero_grad()

        output, hidden_c = encoder(inputs, input_lengths)
        preds = decoder(start_decode, hidden_c, targets.size(1), output, input_masks, True)
        loss = loss_function(preds, targets.view(-1))
        losses.append(loss.tolist())
        loss.backward()
        nn.utils.clip_grad_norm_(encoder.parameters(), 5.0)
        nn.utils.clip_grad_norm_(decoder.parameters(), 5.0)
        enc_optimizer.step()
        dec_optimizer.step()

        if i % 200 == 0:
            print("[%02d/%d] [%03d/%d] mean_loss : %0.2f" % (
                epoch, num_epochs, i, len(train_data) // batch_size, np.mean(losses)))

    if rescheduled is False and epoch == num_epochs // 2:
        lr *= 0.01
        enc_optimizer = optim.Adam(encoder.parameters(), lr=lr)
        dec_optimizer = optim.Adam(decoder.parameters(), lr=lr * decoder_lr_ratio)
        rescheduled = True
# quick qualitative check on a training example
test = train_data[0]
source = test[0]
target = test[1]
input_text = [idx2source[idx] for idx in source.tolist()[0]]
target_text = [idx2target[idx] for idx in target.tolist()[0]]

output, hidden = encoder(source, [source.size(1)])
pred, attn = decoder.decode(hidden, output)
pred_text = [idx2target[idx] for idx in pred.tolist()]

print("source :", " ".join(input_text))
print("target :", " ".join(target_text))
print("pred   :", " ".join(pred_text))