Software 2.0
# Sequence-level exact-match accuracy that ignores padded positions
class Acc:
    def __init__(self, ignore_index=1):
        self.ignore_index = ignore_index
    def __call__(self, pred, tgt):
        # both pred and tgt have shape (bs,seq_len)
        mask = tgt != self.ignore_index
        pred = pred * mask # zero out padded positions so they always agree
        tgt = tgt * mask
        correct = torch.eq(pred, tgt).all(1).sum() # count sequences that match at every position
        return correct.item()
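A quick sanity check of the metric on a toy batch (a sketch; the token ids are made up, and 1 is the padding index):

import torch

acc = Acc(ignore_index=1)
pred = torch.tensor([[5, 7, 1, 1],   # matches the target once padding is masked out
                     [5, 8, 9, 1]])  # differs at the second position
tgt = torch.tensor([[5, 7, 1, 1],
                    [5, 7, 9, 1]])
assert acc(pred, tgt) == 1 # only the first sequence is an exact match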
assert roman_to_integer('MCXCIII') == 1193
max_len = 20
vocab = build_vocab_from_iterator(
    list(map(str, range(10)))            # Arabic symbols
    + ['I','V','X','L','C','D','M'],     # Roman symbols
    specials=['<bos>', '<pad>', '<eos>'] # Special symbols
)
vocab_size = len(vocab)
pad_idx = vocab['<pad>']
collate_fn = partial(collate, pad_idx=pad_idx, max_len=max_len)
proc = Processor(vocab)
train_ds = NumberDataset.from_file('train', processor=proc)
valid_ds = NumberDataset.from_file('valid', processor=proc)
train_dl = DataLoader(train_ds, batch_size=10, collate_fn=collate_fn)
valid_dl = DataLoader(valid_ds, batch_size=10, collate_fn=collate_fn)
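With special_first left at its torchtext default, the special tokens take the lowest indices, so pad_idx comes out as 1, matching the defaults used in collate and Acc:

print(vocab['<bos>'], vocab['<pad>'], vocab['<eos>']) # 0 1 2
print(vocab_size)                                     # 20: 10 digits + 7 Roman symbols + 3 specials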
import torch
import torch.nn as nn # needed for nn.ConstantPad1d in collate
from torchtext.vocab import build_vocab_from_iterator
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from functools import partial

class NumberDataset(Dataset):
    valid_targets = ['roman','integer']
    def __init__(self, lst, processor=None, target='roman'):
        assert target in self.valid_targets, f'Target needs to be one of {self.valid_targets}'
        self.target = target
        self.processor = processor
        self.lst = lst
    def __getitem__(self, idx):
        i, r = self.lst[idx] # (integer string, Roman-numeral string)
        t = (i, r) if self.target == 'roman' else (r, i)
        return list(map(self.processor.process, t))
    def __len__(self):
        return len(self.lst)
    @classmethod
    def from_file(cls, fn, root=None, extension='txt', **kwargs):
        if root is None: root = default_data_path # default_data_path (a pathlib.Path) is assumed to be defined elsewhere
        path = root/('.'.join([fn, extension]))
        with open(path) as f: lines = [line.split() for line in f]
        return cls(lines, **kwargs)

class Processor:
    def __init__(self, vocab):
        self.vocab = vocab
    def process(self, x):
        seq = ['<bos>'] + list(x) + ['<eos>']
        return [self.vocab[tok] for tok in seq]
    def deprocess(self, x):
        out = []
        for idx in x:
            tok = self.vocab.lookup_token(int(idx)) # cast so tensors work as well as ints
            if tok == '<bos>': continue
            elif tok == '<eos>': return ''.join(out)
            else: out.append(tok)
        return ''.join(out)

def collate(batch, max_len=20, pad_idx=1):
    src_lst, tgt_lst = [], []
    for src, tgt in batch:
        src, tgt = map(torch.tensor, [src, tgt])
        src_lst.append(src)
        tgt_lst.append(tgt)
    # pad the first sequence of each list to max_len so that pad_sequence
    # pads every batch to the same fixed length
    src_lst[0] = nn.ConstantPad1d((0, max_len-src_lst[0].size(0)), pad_idx)(src_lst[0])
    tgt_lst[0] = nn.ConstantPad1d((0, max_len-tgt_lst[0].size(0)), pad_idx)(tgt_lst[0])
    return list(map(partial(pad_sequence, padding_value=pad_idx, batch_first=True), [src_lst, tgt_lst]))
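A minimal walk-through of the data pipeline, assuming the vocab, proc and pad_idx defined above:

ids = proc.process('IV')    # token ids for <bos> I V <eos>
print(proc.deprocess(ids))  # 'IV' -- the special tokens are stripped back out

batch = [(proc.process('4'), proc.process('IV'))]
src, tgt = collate(batch, max_len=20, pad_idx=pad_idx)
print(src.shape, tgt.shape) # torch.Size([1, 20]) torch.Size([1, 20]) -- every batch is padded to max_len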
def evaluate(mdl, dl, loss_fn, metric):
    mdl.eval()
    epoch_loss = 0
    correct = 0
    with torch.no_grad():
        for i, (src, tgt) in enumerate(dl):
            out = mdl(src, tgt, teacher_forcing_proba=0) # turn off teacher forcing
            bs, seq_len, out_dim = out.shape
            out = out.view(-1, out_dim)
            tgt = tgt[:,1:].contiguous().view(-1)
            loss = loss_fn(out, tgt)
            epoch_loss += loss.item()
            pred = out.argmax(-1)
            m = metric(pred.view(bs, -1), tgt.view(bs, -1))
            correct += m
    n = (i+1)*bs
    return epoch_loss/n, correct/n
best_i2r_mdl = fit(100, i2r_mdl, train_dl, valid_dl, opt, criterion, metric, patience=3)
import json
import copy

def fit(epochs, mdl, train_dl, valid_dl, opt, criterion, metric, patience=2):
    fmt = lambda x: f'{x:.3f}'
    best_valid_loss = float('inf')
    best_mdl = None
    irritation = 0 # epochs without improvement (early-stopping counter)
    for epoch in range(epochs):
        print(f'Epoch: {epoch+1:02}')
        train_loss, train_metric = train(mdl, train_dl, opt, criterion, metric)
        valid_loss, valid_metric = evaluate(mdl, valid_dl, criterion, metric)
        print('\t' + json.dumps({
            'train': {
                'loss': fmt(train_loss),
                'metric': fmt(train_metric)
            },
            'valid': {
                'loss': fmt(valid_loss),
                'metric': fmt(valid_metric)
            }
        }, indent=4))
        if valid_loss < best_valid_loss:
            best_valid_loss = valid_loss
            best_mdl = copy.deepcopy(mdl) # snapshot the weights; a plain assignment would just alias the live model
            torch.save(mdl.state_dict(), 'model.pt')
            irritation = 0
        else:
            irritation += 1
            if irritation == patience: break
    return best_mdl
def integer_to_roman(n):
    # find the largest power of 10 not exceeding n
    div = 1
    while n >= div: div *= 10
    div //= 10
    out = []
    while n:
        d = n // div # get most significant digit via floor division by a power of 10
        if d < 4:
            o = i2r[div]*d                  # e.g. 3 -> 'III'
        elif d == 4:
            o = i2r[div] + i2r[div*5]       # e.g. 4 -> 'IV'
        elif d < 9:
            o = i2r[div*5] + (d-5)*i2r[div] # e.g. 7 -> 'VII'
        else:
            o = i2r[div] + i2r[div*10]      # e.g. 9 -> 'IX'
        out.append(o)
        n = n % div # the new integer is the remainder
        div //= 10
    return ''.join(out)
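For example, integer_to_roman(1193) proceeds one digit at a time:

# div=1000: d=1 -> 'M'   (n becomes 193)
# div=100:  d=1 -> 'C'   (n becomes 93)
# div=10:   d=9 -> 'XC'  (n becomes 3)
# div=1:    d=3 -> 'III'
assert integer_to_roman(1193) == 'MCXCIII'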
r2i = {
    'I': 1,
    'V': 5,
    'X': 10,
    'L': 50,
    'C': 100,
    'D': 500,
    'M': 1000
}
i2r = {v:k for k,v in r2i.items()} # reverse the mapping
import random
import torch
import torch.nn as nn
import torch.nn.functional as F

class Encoder(nn.Module):
    def __init__(self, vocab_size, hidden_dim, dropout):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, hidden_dim)
        self.rnn = nn.GRU(hidden_dim, hidden_dim, batch_first=True, dropout=dropout, bidirectional=True)
        self.dropout = nn.Dropout(dropout)
        self.project = nn.Linear(hidden_dim*2, hidden_dim) # map bidirectional states back to hidden_dim
    def forward(self, x):
        bs = x.size(0)
        x = self.dropout(self.emb(x))
        h, h_last = self.rnn(x)
        h_last = h_last.permute(1,0,2).contiguous().view(bs, -1) # concat forward/backward final states: (bs,hidden_dim*2)
        h, h_last = map(self.project, [h, h_last])
        h_last = h_last.unsqueeze(0)
        return h, h_last # (bs,seq_len,hidden_dim), (1,bs,hidden_dim)

class Attention(nn.Module):
    def __init__(self, encoder_hidden_dim, decoder_hidden_dim):
        super().__init__()
        self.decoder_hidden_dim = torch.tensor(decoder_hidden_dim, dtype=torch.float) # float so torch.sqrt below is well-defined
        self.w = nn.Parameter(torch.FloatTensor(decoder_hidden_dim, encoder_hidden_dim).uniform_(-0.1, 0.1))
    def forward(self, query, values):
        # query: (bs,decoder_hidden_dim), values: (bs,seq_len,encoder_hidden_dim)
        score = (query.unsqueeze(1) @ self.w @ values.permute(0,2,1))/torch.sqrt(self.decoder_hidden_dim) # (bs,1,seq_len)
        attention_weights = F.softmax(score, -1) # normalize over the source positions
        context = attention_weights @ values # (bs,1,encoder_hidden_dim)
        return context

class Decoder(nn.Module):
    def __init__(self, vocab_size, hidden_dim, dropout):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, hidden_dim)
        self.rnn = nn.GRU(hidden_dim*2, hidden_dim, batch_first=True, dropout=dropout)
        self.dropout = nn.Dropout(dropout)
    def forward(self, y, h_in, context):
        y = self.dropout(self.emb(y.unsqueeze(1)))
        y = torch.cat([context, y], -1) # (bs,1,hidden_dim*2)
        h, h_last = self.rnn(y, h_in) # (bs,1,hidden_dim), (1,bs,hidden_dim)
        return h.squeeze(1), h_last

class Seq2SeqWithAttention(nn.Module):
    def __init__(self, vocab_size, hidden_dim=20, dropout=.5):
        super().__init__()
        self.encode = Encoder(vocab_size, hidden_dim, dropout)
        self.attend = Attention(hidden_dim, hidden_dim)
        self.decode = Decoder(vocab_size, hidden_dim, dropout)
        self.project = nn.Linear(hidden_dim, vocab_size)
    def forward(self, src, tgt, teacher_forcing_proba=.5):
        bs, tgt_len = tgt.shape
        h, h_last = self.encode(src)
        s = h_last.squeeze(0) # (bs,hidden_dim)
        y = tgt[:,0] # start decoding from <bos>
        logits = []
        for t in range(1, tgt_len):
            context = self.attend(s, h) # context: (bs,1,hidden_dim)
            s, h_last = self.decode(y, h_last, context)
            logit = self.project(s)
            logits.append(logit)
            teacher_force = random.random() < teacher_forcing_proba
            y = tgt[:,t] if teacher_force else logit.argmax(-1) # feed ground truth or the model's own prediction
        return torch.stack(logits, 1) # (bs,tgt_len-1,vocab_size)
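A quick shape check on an untrained model (a sketch; 20 is the vocab size built earlier, i.e. 10 digits + 7 Roman symbols + 3 specials):

mdl = Seq2SeqWithAttention(vocab_size=20, hidden_dim=30, dropout=.3)
src = torch.randint(0, 20, (4, 12)) # (bs, src_len)
tgt = torch.randint(0, 20, (4, 20)) # (bs, tgt_len)
out = mdl(src, tgt, teacher_forcing_proba=0)
print(out.shape) # torch.Size([4, 19, 20]) -- one logit vector per target position after <bos>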
import torch.optim as optim

i2r_mdl = Seq2SeqWithAttention(vocab_size, hidden_dim=30, dropout=.3)
opt = optim.Adam(i2r_mdl.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss(ignore_index=pad_idx)
metric = Acc(ignore_index=pad_idx)
to_predict = [i for i,r in test_set]
preds = predict(best_i2r_mdl, to_predict, proc, collate_fn)
for (i,r),pred in zip(test_set,preds):
    print(f'{i} -> {pred} ({r})')
to_predict = [r for i,r in test_set]
preds = predict(best_r2i_mdl, to_predict, proc, collate_fn)
for (i,r),pred in zip(test_set,preds):
    print(f'{r} -> {pred} ({i})')
def predict(mdl, tests, proc, collate):
    mdl.eval()
    tests = list(zip(tests, tests)) # use the source as a dummy target: only its length matters since teacher forcing is off
    test_ds = NumberDataset(tests, processor=proc)
    test_dl = DataLoader(test_ds, batch_size=len(tests), collate_fn=collate)
    with torch.no_grad():
        for src, tgt in test_dl:
            out = mdl(src, tgt, teacher_forcing_proba=0)
            bs, seq_len, out_dim = out.shape
            out = out.view(-1, out_dim)
            tgt = tgt[:,1:].contiguous().view(-1)
            pred = out.argmax(-1)
            return [proc.deprocess(seq) for seq in pred.view(bs, -1)]
train_ds = NumberDataset.from_file('train', processor=proc, target='integer')
valid_ds = NumberDataset.from_file('valid', processor=proc, target='integer')
train_dl = DataLoader(train_ds, batch_size=10, collate_fn=collate_fn)
valid_dl = DataLoader(valid_ds, batch_size=10, collate_fn=collate_fn)
r2i_mdl = Seq2SeqWithAttention(vocab_size, hidden_dim=30, dropout=.3)
opt = optim.Adam(r2i_mdl.parameters(), lr=1e-3)
best_r2i_mdl = fit(100, r2i_mdl, train_dl, valid_dl, opt, criterion, metric, patience=3)
def roman_to_integer(s):
    l = len(s)
    tot = 0
    prev_n = 0
    for i in range(l):
        current_n = r2i[s[i]]
        next_n = r2i[s[i+1]] if i+1 < l else 0
        if current_n >= next_n:
            tot += (current_n - prev_n) # subtract a smaller symbol seen just before (e.g. the 'I' in 'IV')
            prev_n = 0
        else:
            prev_n = current_n # a smaller symbol precedes a larger one: remember it for subtraction
    return tot
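The prev_n bookkeeping handles the subtractive cases; tracing roman_to_integer('XCIV'):

# 'X' < 'C'  -> prev_n = 10
# 'C' >= 'I' -> tot += 100 - 10 = 90, prev_n reset
# 'I' < 'V'  -> prev_n = 1
# 'V' (last) -> tot += 5 - 1 = 4
assert roman_to_integer('XCIV') == 94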
test_ds = NumberDataset.from_file('test', processor=proc)
test_dl = DataLoader(test_ds, batch_size=5, collate_fn=collate_fn)
_, test_metric = evaluate(best_i2r_mdl, test_dl, criterion, metric)
print(test_metric)
test_ds = NumberDataset.from_file('test', processor=proc, target='integer')
test_dl = DataLoader(test_ds, batch_size=5, collate_fn=collate_fn)
_, test_metric = evaluate(best_r2i_mdl, test_dl, criterion, metric)
print(test_metric)
test_set = [
    ('4', 'IV'),
    ('1193', 'MCXCIII'),
    ('548', 'DXLVIII'),
    ('3616', 'MMMDCXVI'),
    ('21', 'XXI')
]
for src,tgt in test_set:
    assert integer_to_roman(int(src)) == tgt
print('Success')
def train(mdl, dl, opt, loss_fn, metric):
    mdl.train()
    epoch_loss = 0
    correct = 0
    for i, (src, tgt) in enumerate(dl):
        opt.zero_grad()
        out = mdl(src, tgt)
        bs, seq_len, out_dim = out.shape
        assert out.size(1) == tgt.size(1)-1 # we skipped the first element in the output
        # collapse seq and batch dims
        out = out.view(-1, out_dim)
        tgt = tgt[:,1:].contiguous().view(-1) # skip the first element in the ground truth
        loss = loss_fn(out, tgt)
        loss.backward()
        opt.step()
        epoch_loss += loss.item()
        pred = out.argmax(-1)
        m = metric(pred.view(bs, -1), tgt.view(bs, -1))
        correct += m
        if i > 0 and i % 10_000 == 0:
            print(f'\t{i}: {epoch_loss/i:.3f}')
    n = (i+1)*bs
    return epoch_loss/n, correct/n