import os
import subprocess
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optimizers
from torch.autograd import Variable
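
# The classes below implement the Transformer of Vaswani et al.,
# "Attention Is All You Need": sinusoidal positional encoding, scaled
# dot-product attention, multi-head attention, encoder/decoder stacks and a
# position-wise feed-forward network. The script at the bottom trains the
# model on the 50k-sentence small_parallel_enja En/Ja corpus.
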
class PositionalEncoding(nn.Module):
    '''
    Positional encoding layer with sinusoid
    '''
    def __init__(self, output_dim,
                 max_len=6000,
                 device='cpu'):
        super().__init__()
        self.output_dim = output_dim
        self.max_len = max_len
        pe = self.initializer()
        self.register_buffer('pe', pe)

    def forward(self, x, mask=None):
        '''
        # Argument
            x: (batch, sequence, output_dim)
        '''
        pe = self.pe[:x.size(1), :].unsqueeze(0)
        return x + Variable(pe, requires_grad=False)

    def initializer(self):
        pe = \
            np.array([[pos / np.power(10000, 2 * (i // 2) / self.output_dim)
                       for i in range(self.output_dim)]
                      for pos in range(self.max_len)])
        pe[:, 0::2] = np.sin(pe[:, 0::2])
        pe[:, 1::2] = np.cos(pe[:, 1::2])

        return torch.from_numpy(pe).float()
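
# Scaled dot-product attention: Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V,
# computed with a numerically stable softmax (max subtraction) and an optional
# mask that zeroes out attention to padded key positions.
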
class ScaledDotProductAttention(nn.Module):
    def __init__(self,
                 d_k,
                 device='cpu'):
        super().__init__()
        self.device = device
        self.scaler = np.sqrt(d_k)

    def forward(self, q, k, v, mask=None):
        '''
        # Argument
            q, k, v: (batch, sequence, out_features)
            mask: (batch, sequence (, sequence))
        '''
        score = torch.einsum('ijk,ilk->ijl', (q, k)) / self.scaler
        score = score - torch.max(score,
                                  dim=-1,
                                  keepdim=True)[0]  # softmax max trick
        score = torch.exp(score)
        if mask is not None:
            # suppose `mask` is a mask of source
            # in source-target-attention, source is `k` and `v`
            if len(mask.size()) == 2:
                mask = mask.unsqueeze(1).repeat(1, score.size(1), 1)
            # score = score * mask.float().to(self.device)
            # out-of-place masking keeps `score` in the autograd graph
            score = score.masked_fill(mask, 0)
        a = score / torch.sum(score, dim=-1, keepdim=True)
        c = torch.einsum('ijk,ikl->ijl', (a, v))

        return c
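
# Multi-head attention: each of the h heads has its own (d_model, d_k)
# query/key and (d_model, d_v) value projections, stored as single
# (h, d_model, d_*) parameters. Heads are folded into the batch dimension via
# einsum/view, attended over in parallel, then concatenated and projected
# back to d_model.
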
class MultiHeadAttention(nn.Module):
    def __init__(self,
                 h,
                 d_model,
                 device='cpu'):
        super().__init__()
        self.h = h
        self.d_model = d_model
        self.d_k = d_k = d_model // h
        self.d_v = d_v = d_model // h
        self.device = device

        self.W_q = nn.Parameter(torch.Tensor(h,
                                             d_model,
                                             d_k))
        self.W_k = nn.Parameter(torch.Tensor(h,
                                             d_model,
                                             d_k))
        self.W_v = nn.Parameter(torch.Tensor(h,
                                             d_model,
                                             d_v))

        nn.init.xavier_normal_(self.W_q)
        nn.init.xavier_normal_(self.W_k)
        nn.init.xavier_normal_(self.W_v)

        self.attn = ScaledDotProductAttention(d_k)
        self.linear = nn.Linear((h * d_v), d_model)
        nn.init.xavier_normal_(self.linear.weight)

    def forward(self, q, k, v, mask=None):
        '''
        # Argument
            q, k, v: (batch, sequence, out_features)
            mask: (batch, sequence (, sequence))
        '''
        batch_size = q.size(0)

        q = torch.einsum('hijk,hkl->hijl',
                         (q.unsqueeze(0).repeat(self.h, 1, 1, 1),
                          self.W_q))
        k = torch.einsum('hijk,hkl->hijl',
                         (k.unsqueeze(0).repeat(self.h, 1, 1, 1),
                          self.W_k))
        v = torch.einsum('hijk,hkl->hijl',
                         (v.unsqueeze(0).repeat(self.h, 1, 1, 1),
                          self.W_v))

        q = q.view(-1, q.size(-2), q.size(-1))
        k = k.view(-1, k.size(-2), k.size(-1))
        v = v.view(-1, v.size(-2), v.size(-1))

        if mask is not None:
            multiples = [self.h] + [1] * (len(mask.size()) - 1)
            mask = mask.repeat(multiples)

        c = self.attn(q, k, v, mask=mask)
        c = torch.split(c, batch_size, dim=0)
        c = torch.cat(c, dim=-1)

        out = self.linear(c)

        return out
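
# Encoder: token embedding plus sinusoidal positional encoding, followed by a
# stack of N identical EncoderLayer blocks; `mask` hides padded source
# positions from the attention.
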
class Encoder(nn.Module):
    def __init__(self,
                 depth_source,
                 N=6,
                 h=8,
                 d_model=512,
                 d_ff=2048,
                 p_dropout=0.1,
                 max_len=128,
                 device='cpu'):
        super().__init__()
        self.device = device
        self.embedding = nn.Embedding(depth_source,
                                      d_model, padding_idx=0)
        self.pe = PositionalEncoding(d_model, max_len=max_len)
        self.encs = nn.ModuleList([
            EncoderLayer(h=h,
                         d_model=d_model,
                         d_ff=d_ff,
                         p_dropout=p_dropout,
                         max_len=max_len,
                         device=device) for _ in range(N)])

    def forward(self, x, mask=None):
        x = self.embedding(x)
        y = self.pe(x)
        for enc in self.encs:
            y = enc(y, mask=mask)

        return y
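
# EncoderLayer: self-attention and a position-wise feed-forward network, each
# followed by dropout, a residual connection and LayerNorm (post-norm, as in
# the original paper).
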
class EncoderLayer(nn.Module):
    def __init__(self,
                 h=8,
                 d_model=512,
                 d_ff=2048,
                 p_dropout=0.1,
                 max_len=128,
                 device='cpu'):
        super().__init__()
        self.attn = MultiHeadAttention(h, d_model)
        self.dropout1 = nn.Dropout(p_dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.ff = FFN(d_model, d_ff)
        self.dropout2 = nn.Dropout(p_dropout)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x, mask=None):
        h = self.attn(x, x, x, mask=mask)
        h = self.dropout1(h)
        h = self.norm1(x + h)

        y = self.ff(h)
        y = self.dropout2(y)
        y = self.norm2(h + y)

        return y
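
# Decoder: target embedding plus positional encoding, followed by N
# DecoderLayer blocks. `mask` combines the target padding and look-ahead
# masks; `source_mask` hides padded source positions in the encoder output
# `hs`.
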
class Decoder(nn.Module):
    def __init__(self,
                 depth_target,
                 N=6,
                 h=8,
                 d_model=512,
                 d_ff=2048,
                 p_dropout=0.1,
                 max_len=128,
                 device='cpu'):
        super().__init__()
        self.device = device
        self.embedding = nn.Embedding(depth_target,
                                      d_model, padding_idx=0)
        self.pe = PositionalEncoding(d_model, max_len=max_len)
        self.decs = nn.ModuleList([
            DecoderLayer(h=h,
                         d_model=d_model,
                         d_ff=d_ff,
                         p_dropout=p_dropout,
                         max_len=max_len,
                         device=device) for _ in range(N)])

    def forward(self, x, hs,
                mask=None,
                source_mask=None):
        x = self.embedding(x)
        y = self.pe(x)
        for dec in self.decs:
            y = dec(y, hs,
                    mask=mask,
                    source_mask=source_mask)

        return y
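
# DecoderLayer: masked self-attention over the target, source-target attention
# over the encoder output, and a feed-forward network, each with dropout, a
# residual connection and LayerNorm.
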
class DecoderLayer(nn.Module):
    def __init__(self,
                 h=8,
                 d_model=512,
                 d_ff=2048,
                 p_dropout=0.1,
                 max_len=128,
                 device='cpu'):
        super().__init__()
        self.self_attn = MultiHeadAttention(h, d_model)
        self.dropout1 = nn.Dropout(p_dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.src_tgt_attn = MultiHeadAttention(h, d_model)
        self.dropout2 = nn.Dropout(p_dropout)
        self.norm2 = nn.LayerNorm(d_model)
        self.ff = FFN(d_model, d_ff)
        self.dropout3 = nn.Dropout(p_dropout)
        self.norm3 = nn.LayerNorm(d_model)

    def forward(self, x, hs,
                mask=None,
                source_mask=None):
        h = self.self_attn(x, x, x, mask=mask)
        h = self.dropout1(h)
        h = self.norm1(x + h)

        z = self.src_tgt_attn(h, hs, hs,
                              mask=source_mask)
        z = self.dropout2(z)
        z = self.norm2(h + z)

        y = self.ff(z)
        y = self.dropout3(y)
        y = self.norm3(z + y)

        return y
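
# Position-wise feed-forward network: FFN(x) = max(0, x W1 + b1) W2 + b2,
# applied independently at every position.
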
class FFN(nn.Module):
    '''
    Position-wise Feed-Forward Networks
    '''
    def __init__(self, d_model, d_ff,
                 device='cpu'):
        super().__init__()
        self.l1 = nn.Linear(d_model, d_ff)
        self.l2 = nn.Linear(d_ff, d_model)
        # self.l1 = nn.Conv1d(d_model, d_ff, 1)
        # self.l2 = nn.Conv1d(d_ff, d_model, 1)

    def forward(self, x):
        x = self.l1(x)
        x = torch.relu(x)
        y = self.l2(x)

        return y
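
# Full encoder-decoder model. Given a `target`, it runs teacher forcing and
# returns logits of shape (batch, target_len, depth_target); without one, it
# decodes greedily from the BOS token for up to `max_len` steps and returns
# the generated token ids. `sequence_mask` marks padding (id 0) and
# `subsequence_mask` blocks attention to future positions.
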
class Transformer(nn.Module):
    def __init__(self,
                 depth_source,
                 depth_target,
                 N=6,
                 h=8,
                 d_model=512,
                 d_ff=2048,
                 p_dropout=0.1,
                 max_len=20,
                 bos_value=1,
                 device='cpu'):
        super().__init__()
        self.device = device
        self.encoder = Encoder(depth_source,
                               N=N,
                               h=h,
                               d_model=d_model,
                               d_ff=d_ff,
                               p_dropout=p_dropout,
                               max_len=max_len,
                               device=device)
        self.decoder = Decoder(depth_target,
                               N=N,
                               h=h,
                               d_model=d_model,
                               d_ff=d_ff,
                               p_dropout=p_dropout,
                               max_len=max_len,
                               device=device)
        self.out = nn.Linear(d_model, depth_target)
        nn.init.xavier_normal_(self.out.weight)

        self._BOS = bos_value
        self._max_len = max_len

    def forward(self, source, target=None):
        source_mask = self.sequence_mask(source)
        hs = self.encoder(source, mask=source_mask)

        if target is not None:
            len_target_sequences = target.size(1)
            target_mask = self.sequence_mask(target).unsqueeze(1)
            subsequent_mask = self.subsequence_mask(target)
            # combine the padding mask and the look-ahead mask
            target_mask = target_mask | subsequent_mask

            y = self.decoder(target, hs,
                             mask=target_mask,
                             source_mask=source_mask)
            output = self.out(y)
        else:
            batch_size = source.size(0)
            len_target_sequences = self._max_len

            output = torch.ones((batch_size, 1),
                                dtype=torch.long,
                                device=self.device) * self._BOS

            for t in range(len_target_sequences - 1):
                target_mask = self.subsequence_mask(output)
                out = self.decoder(output, hs,
                                   mask=target_mask,
                                   source_mask=source_mask)
                out = self.out(out)[:, -1:, :]
                out = out.max(-1)[1]
                output = torch.cat((output, out), dim=1)

        return output

    def sequence_mask(self, x):
        return x.eq(0)

    def subsequence_mask(self, x):
        shape = (x.size(1), x.size(1))
        mask = torch.triu(torch.ones(shape, dtype=torch.bool),
                          diagonal=1)
        return mask.unsqueeze(0).repeat(x.size(0), 1, 1).to(self.device)
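
# A minimal usage sketch (the shapes and numbers here are illustrative, not
# part of the original gist):
#
#   model = Transformer(depth_source=100, depth_target=100, N=2, h=4,
#                       d_model=64, d_ff=128, max_len=10)
#   src = torch.randint(1, 100, (2, 7))   # (batch, sequence) of source ids
#   tgt = torch.randint(1, 100, (2, 9))   # (batch, sequence) of target ids
#   logits = model(src, tgt)              # -> (2, 9, 100) teacher-forced logits
#   decoded = model(src)                  # -> (2, 10) greedily decoded ids
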
def load_small_parallel_enja(path=None,
                             to_ja=True,
                             pad_value=0,
                             start_char=1,
                             end_char=2,
                             oov_char=3,
                             index_from=4,
                             pad='<PAD>',
                             bos='<BOS>',
                             eos='<EOS>',
                             oov='<UNK>',
                             add_bos=True,
                             add_eos=True):
    '''
    Download the 50k En/Ja parallel corpus
    from https://github.com/odashi/small_parallel_enja
    and transform words to IDs.

    Original source:
    https://github.com/yusugomori/tftf/blob/master/tftf/datasets/small_parallel_enja.py
    '''
    url_base = 'https://raw.githubusercontent.com/' \
               'odashi/small_parallel_enja/master/'

    path = path or 'small_parallel_enja'
    dir_path = os.path.join(os.path.expanduser('~'),
                            '.tftf', 'datasets', path)

    if not os.path.exists(dir_path):
        os.makedirs(dir_path)

    f_ja = ['train.ja', 'test.ja']
    f_en = ['train.en', 'test.en']

    for f in (f_ja + f_en):
        f_path = os.path.join(dir_path, f)
        if not os.path.exists(f_path):
            url = url_base + f
            print('Downloading {}'.format(f))
            cmd = ['curl', '-o', f_path, url]
            subprocess.call(cmd)

    f_ja_train = os.path.join(dir_path, f_ja[0])
    f_test_ja = os.path.join(dir_path, f_ja[1])
    f_en_train = os.path.join(dir_path, f_en[0])
    f_test_en = os.path.join(dir_path, f_en[1])

    (ja_train, test_ja), num_words_ja, (w2i_ja, i2w_ja) = \
        _build(f_ja_train, f_test_ja,
               pad_value, start_char, end_char, oov_char, index_from,
               pad, bos, eos, oov, add_bos, add_eos)
    (en_train, test_en), num_words_en, (w2i_en, i2w_en) = \
        _build(f_en_train, f_test_en,
               pad_value, start_char, end_char, oov_char, index_from,
               pad, bos, eos, oov, add_bos, add_eos)

    if to_ja:
        x_train, x_test, num_X, w2i_X, i2w_X = \
            en_train, test_en, num_words_en, w2i_en, i2w_en
        y_train, y_test, num_y, w2i_y, i2w_y = \
            ja_train, test_ja, num_words_ja, w2i_ja, i2w_ja
    else:
        x_train, x_test, num_X, w2i_X, i2w_X = \
            ja_train, test_ja, num_words_ja, w2i_ja, i2w_ja
        y_train, y_test, num_y, w2i_y, i2w_y = \
            en_train, test_en, num_words_en, w2i_en, i2w_en

    # sentences have different lengths, so build object arrays
    x_train, x_test = np.array(x_train, dtype=object), np.array(x_test, dtype=object)
    y_train, y_test = np.array(y_train, dtype=object), np.array(y_test, dtype=object)

    return (x_train, y_train), (x_test, y_test), \
        (num_X, num_y), (w2i_X, w2i_y), (i2w_X, i2w_y)

def _build(f_train, f_test,
           pad_value=0,
           start_char=1,
           end_char=2,
           oov_char=3,
           index_from=4,
           pad='<PAD>',
           bos='<BOS>',
           eos='<EOS>',
           oov='<UNK>',
           add_bos=True,
           add_eos=True):
    builder = _Builder(pad_value=pad_value,
                       start_char=start_char,
                       end_char=end_char,
                       oov_char=oov_char,
                       index_from=index_from,
                       pad=pad,
                       bos=bos,
                       eos=eos,
                       oov=oov,
                       add_bos=add_bos,
                       add_eos=add_eos)
    builder.fit(f_train)
    train = builder.transform(f_train)
    test = builder.transform(f_test)

    return (train, test), builder.num_words, (builder.w2i, builder.i2w)
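
# _Builder fits a vocabulary on the training file and maps words to integer
# ids, reserving ids 0-3 for <PAD>, <BOS>, <EOS> and <UNK>; regular words are
# numbered starting at `index_from`.
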
class _Builder(object):
    def __init__(self,
                 pad_value=0,
                 start_char=1,
                 end_char=2,
                 oov_char=3,
                 index_from=4,
                 pad='<PAD>',
                 bos='<BOS>',
                 eos='<EOS>',
                 oov='<UNK>',
                 add_bos=True,
                 add_eos=True):
        self._vocab = None
        self._w2i = None
        self._i2w = None

        self.pad_value = pad_value
        self.start_char = start_char
        self.end_char = end_char
        self.oov_char = oov_char
        self.index_from = index_from
        self.pad = pad
        self.bos = bos
        self.eos = eos
        self.oov = oov
        self.add_bos = add_bos
        self.add_eos = add_eos

    @property
    def num_words(self):
        return max(self._w2i.values()) + 1

    @property
    def w2i(self):
        '''
        Dict of word to index
        '''
        return self._w2i

    @property
    def i2w(self):
        '''
        Dict of index to word
        '''
        return self._i2w

    def fit(self, f_path):
        self._vocab = set()
        self._w2i = {}
        for line in open(f_path, encoding='utf-8'):
            _sentence = line.strip().split()
            self._vocab.update(_sentence)

        self._w2i = {w: (i + self.index_from)
                     for i, w in enumerate(self._vocab)}
        if self.pad_value >= 0:
            self._w2i[self.pad] = self.pad_value
        self._w2i[self.bos] = self.start_char
        self._w2i[self.eos] = self.end_char
        self._w2i[self.oov] = self.oov_char
        self._i2w = {i: w for w, i in self._w2i.items()}

    def transform(self, f_path):
        if self._vocab is None or self._w2i is None:
            raise AttributeError('`{}.fit` must be called before `transform`.'
                                 ''.format(self.__class__.__name__))

        sentences = []
        for line in open(f_path, encoding='utf-8'):
            _sentence = line.strip().split()
            # _sentence = [self.bos] + _sentence + [self.eos]
            if self.add_bos:
                _sentence = [self.bos] + _sentence
            if self.add_eos:
                _sentence = _sentence + [self.eos]
            sentences.append(self._encode(_sentence))

        return sentences

    def _encode(self, sentence):
        encoded = []
        for w in sentence:
            if w not in self._w2i:
                id = self.oov_char
            else:
                id = self._w2i[w]
            encoded.append(id)

        return encoded
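
# Sequence padding helper. For example (illustrative, not from the original
# gist): pad_sequences([[1, 2, 3], [4, 5]], padding='post') returns
# [[1, 2, 3], [4, 5, 0]] as a numpy array.
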
def pad_sequences(data,
                  padding='pre',
                  value=0):
    '''
    # Arguments
        data: list of lists / np.array of lists
    # Returns
        numpy.ndarray
    '''
    if type(data[0]) is not list:
        raise ValueError('`data` must be a list of lists')

    maxlen = len(max(data, key=len))

    if padding == 'pre':
        data = \
            [[value] * (maxlen - len(data[i])) + data[i]
             for i in range(len(data))]
    elif padding == 'post':
        data = \
            [data[i] + [value] * (maxlen - len(data[i]))
             for i in range(len(data))]
    else:
        raise ValueError('`padding` must be one of \'pre\' or \'post\'')

    return np.array(data)

def sort(data, target,
         order='ascend'):
    if order == 'ascend' or order == 'ascending':
        a = True
    elif order == 'descend' or order == 'descending':
        a = False
    else:
        raise ValueError('`order` must be one of \'ascend\' or \'descend\'.')

    lens = [len(i) for i in data]
    indices = sorted(range(len(lens)),
                     key=lambda x: (2 * a - 1) * lens[x])
    data = [data[i] for i in indices]
    target = [target[i] for i in indices]

    return (data, target)
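
# Training script: fixes the random seeds, loads the En->Ja corpus, builds a
# small Transformer (N=3, h=4, d_model=128), trains it with Adam and a
# padding-ignoring cross-entropy loss, and prints a few greedy translations
# of test sentences.
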
if __name__ == '__main__':
    np.random.seed(1234)
    torch.manual_seed(1234)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def compute_loss(label, pred):
        return criterion(pred, label)

    def train_step(x, t):
        model.train()
        # teacher forcing: the decoder reads t[:, :-1] and is trained to
        # predict the next tokens t[:, 1:]
        preds = model(x, t[:, :-1])
        loss = compute_loss(t[:, 1:].contiguous().view(-1),
                            preds.contiguous().view(-1, preds.size(-1)))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        return loss, preds

    def valid_step(x, t):
        model.eval()
        with torch.no_grad():
            preds = model(x, t[:, :-1])
            loss = compute_loss(t[:, 1:].contiguous().view(-1),
                                preds.contiguous().view(-1, preds.size(-1)))

        return loss, preds

    def test_step(x):
        model.eval()
        with torch.no_grad():
            preds = model(x)

        return preds

    def ids_to_sentence(ids, i2w):
        return [i2w[id] for id in ids]

    '''
    Load data
    '''
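
    # ParallelDataLoader batches (source, target) pairs: each batch is sorted
    # by source length (descending), post-padded with 0, and returned as a
    # pair of LongTensors of shape (batch, sequence).
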
    class ParallelDataLoader(object):
        def __init__(self, dataset,
                     batch_size=128,
                     shuffle=False,
                     random_state=None):
            if type(dataset) is not tuple:
                raise ValueError('argument `dataset` must be tuple,'
                                 ' not {}.'.format(type(dataset)))
            self.dataset = list(zip(dataset[0], dataset[1]))
            self.batch_size = batch_size
            self.shuffle = shuffle
            if random_state is None:
                random_state = np.random.RandomState(1234)
            self.random_state = random_state
            self._idx = 0

        def __len__(self):
            return len(self.dataset)

        def __iter__(self):
            return self

        def __next__(self):
            if self._idx >= len(self.dataset):
                self._reorder()
                raise StopIteration()

            x, y = zip(*self.dataset[self._idx:(self._idx + self.batch_size)])
            x, y = sort(x, y, order='descend')
            x = pad_sequences(x, padding='post')
            y = pad_sequences(y, padding='post')

            x = torch.LongTensor(x)  # not use .t()
            y = torch.LongTensor(y)  # not use .t()

            self._idx += self.batch_size

            return x, y

        def _reorder(self):
            if self.shuffle:
                # shuffle the (source, target) pairs in place between epochs
                self.random_state.shuffle(self.dataset)
            self._idx = 0
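
    # Load the corpus as id sequences (English as the source, Japanese as the
    # target) and wrap the splits in data loaders; the test loader yields one
    # sentence at a time for the sample translations below.
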
    (x_train, y_train), \
        (x_test, y_test), \
        (num_x, num_y), \
        (w2i_x, w2i_y), (i2w_x, i2w_y) = \
        load_small_parallel_enja(to_ja=True)

    train_dataloader = ParallelDataLoader((x_train, y_train),
                                          shuffle=True)
    valid_dataloader = ParallelDataLoader((x_test, y_test))
    test_dataloader = ParallelDataLoader((x_test, y_test),
                                         batch_size=1,
                                         shuffle=True)

    '''
    Build model
    '''
    model = Transformer(num_x,
                        num_y,
                        N=3,
                        h=4,
                        d_model=128,
                        d_ff=256,
                        max_len=20,
                        device=device).to(device)

    criterion = nn.CrossEntropyLoss(reduction='sum', ignore_index=0)
    optimizer = optimizers.Adam(model.parameters())

    '''
    Train model
    '''
    epochs = 20

    for epoch in range(epochs):
        print('-' * 20)
        print('Epoch: {}'.format(epoch + 1))

        train_loss = 0.
        valid_loss = 0.

        for idx, (source, target) in enumerate(train_dataloader):
            source, target = source.to(device), target.to(device)
            loss, _ = train_step(source, target)
            train_loss += loss.item()
        train_loss /= len(train_dataloader)

        for (source, target) in valid_dataloader:
            source, target = source.to(device), target.to(device)
            loss, _ = valid_step(source, target)
            valid_loss += loss.item()
        valid_loss /= len(valid_dataloader)
        print('Valid loss: {:.3}'.format(valid_loss))

        for idx, (source, target) in enumerate(test_dataloader):
            source, target = source.to(device), target.to(device)
            out = test_step(source)
            out = out.view(-1).tolist()
            out = ' '.join(ids_to_sentence(out, i2w_y))
            source = ' '.join(ids_to_sentence(source.view(-1).tolist(), i2w_x))
            target = ' '.join(ids_to_sentence(target.view(-1).tolist(), i2w_y))
            print('>', source)
            print('=', target)
            print('<', out)
            print()

            if idx >= 10:
                break

NOTICE: This implementation is not working.
(Referenced on Stack Overflow.)