seq2seq_test_example
""" | |
# https://github.com/bentrevett/pytorch-seq2seq/issues/129 | |
# https://github.com/bentrevett/pytorch-seq2seq/blob/master/6%20-%20Attention%20is%20All%20You%20Need.ipynb | |
""" | |
import torch
import torch.nn as nn

# The imports below come from the original notebook (training, data loading,
# tokenisation, plotting, debugging); they are not needed to run the standalone
# shape test at the bottom of this file, so they are left commented out.
# import torch.optim as optim
# import torchtext
# from torchtext.datasets import Multi30k
# from torchtext.data import Field, BucketIterator
# import matplotlib.pyplot as plt
# import matplotlib.ticker as ticker
# import spacy
# import numpy as np
# import random
# import math
# import time
# from IPython.core.debugger import set_trace  # set_trace()
######## ENCODER PART #################################
class Encoder(nn.Module):
    def __init__(self,
                 input_dim,
                 hid_dim,
                 n_layers,
                 n_heads,
                 pf_dim,
                 dropout,
                 device,
                 max_length = 100):
        super().__init__()

        self.device = device

        self.tok_embedding = nn.Embedding(input_dim, hid_dim)
        self.pos_embedding = nn.Embedding(max_length, hid_dim)

        self.layers = nn.ModuleList([EncoderLayer(hid_dim,
                                                  n_heads,
                                                  pf_dim,
                                                  dropout,
                                                  device)
                                     for _ in range(n_layers)])

        self.dropout = nn.Dropout(dropout)

        self.scale = torch.sqrt(torch.FloatTensor([hid_dim])).to(device)

    def forward(self, src, src_mask):
        #src = [batch size, src len]
        #src_mask = [batch size, src len]

        batch_size = src.shape[0]
        src_len = src.shape[1]

        pos = torch.arange(0, src_len).unsqueeze(0).repeat(batch_size, 1).to(self.device)
        #pos = [batch size, src len]

        src = self.dropout((self.tok_embedding(src) * self.scale) + self.pos_embedding(pos))
        #src = [batch size, src len, hid dim]

        for layer in self.layers:
            src = layer(src, src_mask)
        #src = [batch size, src len, hid dim]

        return src
class EncoderLayer(nn.Module):
    def __init__(self,
                 hid_dim,
                 n_heads,
                 pf_dim,
                 dropout,
                 device):
        super().__init__()

        self.self_attn_layer_norm = nn.LayerNorm(hid_dim)
        self.ff_layer_norm = nn.LayerNorm(hid_dim)
        self.self_attention = MultiHeadAttentionLayer(hid_dim, n_heads, dropout, device)
        self.positionwise_feedforward = PositionwiseFeedforwardLayer(hid_dim,
                                                                     pf_dim,
                                                                     dropout)
        self.dropout = nn.Dropout(dropout)

    def forward(self, src, src_mask):
        #src = [batch size, src len, hid dim]
        #src_mask = [batch size, src len]

        #self attention
        _src, _ = self.self_attention(src, src, src, src_mask)

        #dropout, residual connection and layer norm
        src = self.self_attn_layer_norm(src + self.dropout(_src))
        #src = [batch size, src len, hid dim]

        #positionwise feedforward
        _src = self.positionwise_feedforward(src)

        #dropout, residual and layer norm
        src = self.ff_layer_norm(src + self.dropout(_src))
        #src = [batch size, src len, hid dim]

        return src
#### ATTENTION LAYER #################################################################
class MultiHeadAttentionLayer(nn.Module):
    def __init__(self, hid_dim, n_heads, dropout, device):
        super().__init__()

        assert hid_dim % n_heads == 0

        self.hid_dim = hid_dim
        self.n_heads = n_heads
        self.head_dim = hid_dim // n_heads

        self.fc_q = nn.Linear(hid_dim, hid_dim)
        self.fc_k = nn.Linear(hid_dim, hid_dim)
        self.fc_v = nn.Linear(hid_dim, hid_dim)
        self.fc_o = nn.Linear(hid_dim, hid_dim)

        self.dropout = nn.Dropout(dropout)

        self.scale = torch.sqrt(torch.FloatTensor([self.head_dim])).to(device)

    def forward(self, query, key, value, mask = None):
        batch_size = query.shape[0]
        #query = [batch size, query len, hid dim]
        #key = [batch size, key len, hid dim]
        #value = [batch size, value len, hid dim]

        Q = self.fc_q(query)
        K = self.fc_k(key)
        V = self.fc_v(value)
        #Q = [batch size, query len, hid dim]
        #K = [batch size, key len, hid dim]
        #V = [batch size, value len, hid dim]

        Q = Q.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        K = K.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        V = V.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        #Q = [batch size, n heads, query len, head dim]
        #K = [batch size, n heads, key len, head dim]
        #V = [batch size, n heads, value len, head dim]

        energy = torch.matmul(Q, K.permute(0, 1, 3, 2)) / self.scale
        #energy = [batch size, n heads, query len, key len]

        if mask is not None:
            energy = energy.masked_fill(mask == 0, -1e10)

        attention = torch.softmax(energy, dim = -1)
        #attention = [batch size, n heads, query len, key len]

        x = torch.matmul(self.dropout(attention), V)
        #x = [batch size, n heads, query len, head dim]

        x = x.permute(0, 2, 1, 3).contiguous()
        #x = [batch size, query len, n heads, head dim]

        x = x.view(batch_size, -1, self.hid_dim)
        #x = [batch size, query len, hid dim]

        x = self.fc_o(x)
        #x = [batch size, query len, hid dim]

        return x, attention
class PositionwiseFeedforwardLayer(nn.Module):
    def __init__(self, hid_dim, pf_dim, dropout):
        super().__init__()

        self.fc_1 = nn.Linear(hid_dim, pf_dim)
        self.fc_2 = nn.Linear(pf_dim, hid_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        #x = [batch size, seq len, hid dim]

        x = self.dropout(torch.relu(self.fc_1(x)))
        #x = [batch size, seq len, pf dim]

        x = self.fc_2(x)
        #x = [batch size, seq len, hid dim]

        return x
###### decoder part ###############
class Decoder(nn.Module):
    def __init__(self,
                 output_dim,
                 hid_dim,
                 n_layers,
                 n_heads,
                 pf_dim,
                 dropout,
                 device,
                 max_length = 100):
        super().__init__()

        self.device = device

        self.tok_embedding = nn.Embedding(output_dim, hid_dim)
        self.pos_embedding = nn.Embedding(max_length, hid_dim)

        self.layers = nn.ModuleList([DecoderLayer(hid_dim,
                                                  n_heads,
                                                  pf_dim,
                                                  dropout,
                                                  device)
                                     for _ in range(n_layers)])

        self.fc_out = nn.Linear(hid_dim, output_dim)
        self.dropout = nn.Dropout(dropout)

        self.scale = torch.sqrt(torch.FloatTensor([hid_dim])).to(device)

    def forward(self, trg, enc_src, trg_mask, src_mask):
        #trg = [batch size, trg len]
        #enc_src = [batch size, src len, hid dim]
        #trg_mask = [batch size, trg len]
        #src_mask = [batch size, src len]

        batch_size = trg.shape[0]
        trg_len = trg.shape[1]

        pos = torch.arange(0, trg_len).unsqueeze(0).repeat(batch_size, 1).to(self.device)
        #pos = [batch size, trg len]

        trg = self.dropout((self.tok_embedding(trg) * self.scale) + self.pos_embedding(pos))
        #trg = [batch size, trg len, hid dim]

        for layer in self.layers:
            trg, attention = layer(trg, enc_src, trg_mask, src_mask)
        #trg = [batch size, trg len, hid dim]
        #attention = [batch size, n heads, trg len, src len]

        output = self.fc_out(trg)
        #output = [batch size, trg len, output dim]

        return output, attention
############# decoder Layer #################################
class DecoderLayer(nn.Module):
    def __init__(self,
                 hid_dim,
                 n_heads,
                 pf_dim,
                 dropout,
                 device):
        super().__init__()

        self.self_attn_layer_norm = nn.LayerNorm(hid_dim)
        self.enc_attn_layer_norm = nn.LayerNorm(hid_dim)
        self.ff_layer_norm = nn.LayerNorm(hid_dim)
        self.self_attention = MultiHeadAttentionLayer(hid_dim, n_heads, dropout, device)
        self.encoder_attention = MultiHeadAttentionLayer(hid_dim, n_heads, dropout, device)
        self.positionwise_feedforward = PositionwiseFeedforwardLayer(hid_dim,
                                                                     pf_dim,
                                                                     dropout)
        self.dropout = nn.Dropout(dropout)

    def forward(self, trg, enc_src, trg_mask, src_mask):
        #trg = [batch size, trg len, hid dim]
        #enc_src = [batch size, src len, hid dim]
        #trg_mask = [batch size, trg len]
        #src_mask = [batch size, src len]

        #self attention
        _trg, _ = self.self_attention(trg, trg, trg, trg_mask)

        #dropout, residual connection and layer norm
        trg = self.self_attn_layer_norm(trg + self.dropout(_trg))
        #trg = [batch size, trg len, hid dim]

        #encoder attention
        _trg, attention = self.encoder_attention(trg, enc_src, enc_src, src_mask)

        #dropout, residual connection and layer norm
        trg = self.enc_attn_layer_norm(trg + self.dropout(_trg))
        #trg = [batch size, trg len, hid dim]

        #positionwise feedforward
        _trg = self.positionwise_feedforward(trg)

        #dropout, residual and layer norm
        trg = self.ff_layer_norm(trg + self.dropout(_trg))
        #trg = [batch size, trg len, hid dim]
        #attention = [batch size, n heads, trg len, src len]

        return trg, attention
class Seq2Seq(nn.Module):
    def __init__(self,
                 encoder,
                 decoder,
                 src_pad_idx,
                 trg_pad_idx,
                 device):
        super().__init__()

        self.encoder = encoder
        self.decoder = decoder
        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = trg_pad_idx
        self.device = device

    def make_src_mask(self, src):
        #src = [batch size, src len]

        src_mask = (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)
        #src_mask = [batch size, 1, 1, src len]

        return src_mask

    def make_trg_mask(self, trg):
        #trg = [batch size, trg len]

        trg_pad_mask = (trg != self.trg_pad_idx).unsqueeze(1).unsqueeze(2)
        #trg_pad_mask = [batch size, 1, 1, trg len]

        trg_len = trg.shape[1]

        trg_sub_mask = torch.tril(torch.ones((trg_len, trg_len), device = self.device)).bool()
        #trg_sub_mask = [trg len, trg len]

        trg_mask = trg_pad_mask & trg_sub_mask
        #trg_mask = [batch size, 1, trg len, trg len]

        return trg_mask

    def forward(self, src, trg):
        #src = [batch size, src len]
        #trg = [batch size, trg len]

        src_mask = self.make_src_mask(src)
        trg_mask = self.make_trg_mask(trg)
        #src_mask = [batch size, 1, 1, src len]
        #trg_mask = [batch size, 1, trg len, trg len]

        enc_src = self.encoder(src, src_mask)
        #enc_src = [batch size, src len, hid dim]

        output, attention = self.decoder(trg, enc_src, trg_mask, src_mask)
        #output = [batch size, trg len, output dim]
        #attention = [batch size, n heads, trg len, src len]

        return output, attention
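
# A minimal greedy-decoding sketch, not part of the original gist: it shows how the
# Seq2Seq model above could be used at inference time by feeding the growing target
# sequence back into the decoder one token at a time. The sos_idx / eos_idx arguments
# and the max_len default are assumptions; in the toy test below, id 1 plays the role
# of <sos> and id 2 of <eos>.
def greedy_decode(model, src, sos_idx, eos_idx, max_len = 50):
    model.eval()
    src_mask = model.make_src_mask(src)
    with torch.no_grad():
        enc_src = model.encoder(src, src_mask)
    #start every sequence in the batch with <sos>
    trg = torch.full((src.shape[0], 1), sos_idx, dtype = torch.long, device = src.device)
    for _ in range(max_len):
        trg_mask = model.make_trg_mask(trg)
        with torch.no_grad():
            output, _ = model.decoder(trg, enc_src, trg_mask, src_mask)
        #take the most likely next token for each sequence and append it
        next_token = output[:, -1].argmax(dim = -1, keepdim = True)
        trg = torch.cat([trg, next_token], dim = 1)
        #stop early once every sequence has produced <eos> at least once
        if (trg == eos_idx).any(dim = 1).all():
            break
    return trg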
if __name__ == "__main__":
    # set_trace()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    INPUT_DIM = 10
    OUTPUT_DIM = 10
    HID_DIM = 256
    ENC_LAYERS = 3
    DEC_LAYERS = 3
    ENC_HEADS = 8
    DEC_HEADS = 8
    ENC_PF_DIM = 512
    DEC_PF_DIM = 512
    ENC_DROPOUT = 0.1
    DEC_DROPOUT = 0.1
    SRC_PAD_IDX = 0
    TRG_PAD_IDX = 0

    enc = Encoder(INPUT_DIM,
                  HID_DIM,
                  ENC_LAYERS,
                  ENC_HEADS,
                  ENC_PF_DIM,
                  ENC_DROPOUT,
                  device)

    dec = Decoder(OUTPUT_DIM,
                  HID_DIM,
                  DEC_LAYERS,
                  DEC_HEADS,
                  DEC_PF_DIM,
                  DEC_DROPOUT,
                  device)

    #toy batches of token ids, padded with 0 (= SRC_PAD_IDX / TRG_PAD_IDX)
    x = torch.tensor([[1, 5, 6, 4, 3, 9, 5, 2, 0, 0],
                      [1, 8, 7, 3, 4, 5, 6, 7, 2, 0]]).to(device, dtype = torch.int64)
    trg = torch.tensor([[1, 7, 4, 3, 5, 9, 2, 0, 0, 0],
                        [1, 5, 6, 2, 4, 7, 6, 2, 0, 0]]).to(device, dtype = torch.int64)

    model = Seq2Seq(enc, dec, SRC_PAD_IDX, TRG_PAD_IDX, device).to(device)

    out, attention = model(x, trg[:, :-1])
    print(out.shape)
    print(attention.shape)
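    # With the toy inputs above: out is [batch size, trg len, output dim] = [2, 9, 10]
    # and attention is [batch size, n heads, trg len, src len] = [2, 8, 9, 10].

    # Hypothetical usage of the greedy_decode sketch defined above, assuming token id 1
    # is <sos> and 2 is <eos> in this toy vocabulary (the untrained model will of course
    # emit arbitrary tokens; this only checks that decoding runs and returns id tensors).
    preds = greedy_decode(model, x, sos_idx = 1, eos_idx = 2, max_len = 10)
    print(preds.shape)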