Skip to content

Instantly share code, notes, and snippets.

@movefast
Created April 21, 2018 06:37
Show Gist options
  • Save movefast/7d3a8f60fb970221bba02a1ecf732d47 to your computer and use it in GitHub Desktop.
Save movefast/7d3a8f60fb970221bba02a1ecf732d47 to your computer and use it in GitHub Desktop.
Seq2Seq RNN with attention and beam search
class Seq2SeqAttnRNN(nn.Module):
    """GRU encoder / attention GRU decoder with greedy or beam-search decoding.

    The encoder embeds the source tokens and runs a multi-layer GRU. Its final
    hidden state is projected to the decoder size and used as the decoder's
    initial hidden state. At every output step the decoder:

    1. attends over all encoder outputs with additive attention,
       ``softmax_over_src(tanh(enc_out @ W1 + l2(h_top)) @ V)``;
    2. mixes the attended context with the embedding of the previous token (``l3``);
    3. runs one GRU step and maps the result to vocabulary logits through ``out``,
       whose weight is tied to the decoder embedding.

    Args:
        vecs_enc, itos_enc, em_sz_enc: passed to ``create_emb`` (defined elsewhere)
            to build the source embedding. ``em_sz_enc`` is the encoder GRU input size.
        vecs_dec, itos_dec, em_sz_dec: same for the target side. ``em_sz_dec`` is also
            the decoder GRU hidden size.
        nh: encoder GRU hidden size.
        out_sl: maximum number of decoding steps.
        nl: number of GRU layers in both encoder and decoder.
    """

    def __init__(self, vecs_enc, itos_enc, em_sz_enc, vecs_dec, itos_dec, em_sz_dec, nh, out_sl, nl=2):
        super().__init__()
        self.emb_enc = create_emb(vecs_enc, itos_enc, em_sz_enc)
        self.nl, self.nh, self.out_sl = nl, nh, out_sl
        self.gru_enc = nn.GRU(em_sz_enc, nh, num_layers=nl, dropout=0.25)
        # Projects the encoder's final hidden state to the decoder hidden size.
        self.out_enc = nn.Linear(nh, em_sz_dec, bias=False)
        self.emb_dec = create_emb(vecs_dec, itos_dec, em_sz_dec)
        self.gru_dec = nn.GRU(em_sz_dec, em_sz_dec, num_layers=nl, dropout=0.1)
        self.emb_enc_drop = nn.Dropout(0.15)
        self.out_drop = nn.Dropout(0.35)
        # The decoder GRU emits em_sz_dec features, and the weight is replaced by the
        # decoder embedding matrix right below. in_features is therefore em_sz_dec; the
        # former em_sz_dec*2 was overwritten anyway, so saved weights keep their shape.
        self.out = nn.Linear(em_sz_dec, len(itos_dec))
        self.out.weight.data = self.emb_dec.weight.data
        # Additive-attention parameters.
        self.W1 = rand_p(nh, em_sz_dec)
        self.l2 = nn.Linear(em_sz_dec, em_sz_dec)
        self.l3 = nn.Linear(em_sz_dec + nh, em_sz_dec)
        self.V = rand_p(em_sz_dec)
        # Teacher-forcing probability, used only when `y` is passed to forward().
        # Presumably overwritten by the training loop; this default only prevents an
        # AttributeError when nothing sets it.
        self.pr_force = 1.

    def forward(self, inp, y=None, ret_attn=False, s=2, beam=False):
        """Encode `inp` and decode up to `out_sl` tokens.

        Args:
            inp: source token ids indexed [src_step, batch].
            y: optional target ids indexed [step, batch]. Used for teacher forcing
                in the greedy path only.
            ret_attn: also return the attention weights.
            s: beam width (beam path only).
            beam: use beam search. Honoured only when the module is in eval mode;
                otherwise greedy decoding is used.

        Returns:
            Greedy path: logits of shape (steps, batch, vocab). With `ret_attn`, the
            tuple (logits, attention of shape (steps, src_len, batch)).
            Beam path: token ids of shape (batch, steps, s), beams sorted best-first.
            With `ret_attn`, the tuple (ids, list of per-step attention tensors).
            The list is not stacked because step 0 covers `batch` rows and later
            steps cover `s * batch` rows.
        """
        sl, bs = inp.size()
        h = self.initHidden(bs)
        emb = self.emb_enc_drop(self.emb_enc(inp))
        enc_out, h = self.gru_enc(emb, h)
        h = self.out_enc(h)
        # Every output sequence starts from token id 0.
        dec_inp = V(torch.zeros(bs).long())
        # Encoder half of the attention score; constant across decoding steps.
        w1e = enc_out @ self.W1
        if beam and not self.training:
            return self._beam_search(w1e, enc_out, dec_inp, h, bs, s, ret_attn)
        return self._greedy_decode(w1e, enc_out, dec_inp, h, y, ret_attn)

    def _decode_step(self, w1e, enc_out, dec_inp, h):
        """Run one attention + GRU decoder step.

        Returns (logits (rows, vocab), new hidden state, attention (src_len, rows)).
        """
        w2h = self.l2(h[-1])  # query from the top decoder layer
        u = F.tanh(w1e + w2h)
        a = F.softmax(u @ self.V, 0)  # normalise over source positions
        Xa = (a.unsqueeze(2) * enc_out).sum(0)  # attention-weighted encoder context
        # Decoder inputs are target-vocabulary ids, so they use the decoder embedding.
        emb = self.emb_dec(dec_inp)
        wgt_enc = self.l3(torch.cat([emb, Xa], 1))
        outp, h = self.gru_dec(wgt_enc.unsqueeze(0), h)
        outp = self.out(self.out_drop(outp[0]))
        return outp, h, a

    def _greedy_decode(self, w1e, enc_out, dec_inp, h, y, ret_attn):
        """Feed back the argmax token each step, or the target under teacher forcing."""
        res, attns = [], []
        for i in range(self.out_sl):
            outp, h, a = self._decode_step(w1e, enc_out, dec_inp, h)
            attns.append(a)
            res.append(outp)
            dec_inp = V(outp.data.max(1)[1])
            # Stop once every sequence produced id 1 (appears to be padding/end; TODO confirm).
            if (dec_inp == 1).all(): break
            if (y is not None) and (random.random() < self.pr_force):
                if i >= len(y): break
                dec_inp = y[i]
        res = torch.stack(res)
        if ret_attn: res = res, torch.stack(attns)
        return res

    def _beam_search(self, w1e, enc_out, dec_inp, h, bs, s, ret_attn):
        """Beam search of width `s`, returning token ids (bs, steps, s), best beam first.

        After step 0, hypotheses are laid out beam-major along the batch axis:
        row ``k*bs + b`` of the decoder input and hidden state is beam ``k`` of
        sequence ``b``.
        """
        attns = []
        res = None
        cur_probs = None  # (bs, s) cumulative log-probabilities of the live beams
        for i in range(self.out_sl):
            outp, h, a = self._decode_step(w1e, enc_out, dec_inp, h)
            attns.append(a)
            # log_softmax is the numerically stable form of log(softmax(...)).
            top_p, top_i = F.log_softmax(outp, dim=1).data.topk(s, dim=1)
            if i == 0:
                # One hypothesis per sequence so far; its s best tokens seed the beams.
                cur_probs = top_p
                new_tok = top_i
                res = top_i.view(bs, 1, s)
                # Give every beam its own copy of the per-sequence state (beam-major rows).
                h = h.repeat(1, s, 1)
                w1e, enc_out = w1e.repeat(1, s, 1), enc_out.repeat(1, s, 1)
            else:
                # top_p/top_i are (s*bs, s): row k*bs+b is beam k of sequence b.
                # Regroup to (bs, s*s) so that column k*s+j is candidate j of beam k.
                top_p = top_p.view(s, bs, s).transpose(0, 1).contiguous().view(bs, s * s)
                top_i = top_i.view(s, bs, s).transpose(0, 1).contiguous().view(bs, s * s)
                # Each beam's score must repeat s times *consecutively* to line up
                # with columns k*s+j. Tiling with repeat(1, s) would pair beam j's
                # score with beam k's candidates.
                prev = cur_probs.unsqueeze(2).expand(bs, s, s).contiguous().view(bs, s * s)
                cur_probs, best = (prev + top_p).topk(s, dim=1)  # sorted best-first
                src_beam = best // s  # old beam that each new beam extends
                new_tok = top_i.gather(1, best)  # (bs, s) token appended to each new beam
                # Reorder the token history to follow the surviving beams, then append.
                res = res.gather(2, src_beam.unsqueeze(1).expand(bs, res.size(1), s))
                res = torch.cat((res, new_tok.unsqueeze(1)), dim=1)
                # Reorder hidden states too: new row k*bs+b comes from old row
                # src_beam[b, k]*bs + b.
                h_idxes = to_gpu(torch.arange(bs)).long().repeat(s) + src_beam.t().contiguous().view(-1) * bs
                h.data = h.data.index_select(1, h_idxes)
            dec_inp = V(new_tok.t().contiguous().view(-1))
        if ret_attn: return res, attns
        return res

    def initHidden(self, bs): return V(torch.zeros(self.nl, bs, self.nh))
# Qualitative check on one validation batch: print the source, the reference, and
# the greedy and beam-search translations side by side for a few sentences.
learn.model.eval()  # forward() only runs beam search when the model is not training
x, y = next(iter(val_dl))
probs, attns = learn.model(V(x), ret_attn=True)
preds = to_np(probs.max(2)[1])  # greedy token ids, indexed [step, sentence]
# Beam output is indexed [sentence, step, beam] with beams sorted best-first. Move it
# to numpy like `preds` rather than iterating a (possibly GPU) tensor element by element.
preds_beam = to_np(learn.model(V(x), beam=True, s=4))
PAD_IDX = 1  # token id skipped when printing (presumably padding; TODO confirm)
for i in range(170, 180):
    print(' '.join([fr_itos[o] for o in x[:, i] if o != PAD_IDX]))
    print(' '.join([en_itos[o] for o in y[:, i] if o != PAD_IDX]))
    print(' '.join([en_itos[o] for o in preds[:, i] if o != PAD_IDX]))
    print(' '.join([en_itos[o] for o in preds_beam[i, :, 0] if o != PAD_IDX]))
    print()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment