Skip to content

Instantly share code, notes, and snippets.

@MLWhiz
Created August 26, 2020 05:20
Show Gist options
  • Select an option

  • Save MLWhiz/63e3858fdecbb16e2298549a04dba58a to your computer and use it in GitHub Desktop.

Select an option

Save MLWhiz/63e3858fdecbb16e2298549a04dba58a to your computer and use it in GitHub Desktop.
def greeedy_decode_sentence(model, sentence, maxlen=25, device=None):
    """Translate one sentence with greedy (argmax) decoding.

    The source text is tokenized with ``SRC.preprocess`` and numericalized
    with ``SRC.vocab.stoi``. Decoding starts from ``TGT.vocab.stoi[BOS_WORD]``.
    At each step the highest-scoring token for the last target position is
    appended. Decoding stops once the emitted word equals ``EOS_WORD`` or after
    ``maxlen`` steps.

    ``SRC``, ``TGT``, ``BOS_WORD`` and ``EOS_WORD`` are module-level globals
    defined outside this function. They are presumably torchtext ``Field``
    objects and their special tokens; confirm against the training script.

    Args:
        model: Module called as ``model(src, trg, tgt_mask=mask)``, where
            ``src`` is ``(src_len, 1)`` and ``trg`` is ``(tgt_len, 1)``
            (sequence-first). Its output is assumed to be
            ``(tgt_len, 1, vocab)`` scores (TODO confirm). The model is
            switched to eval mode and left in eval mode.
        sentence: Raw source text passed to ``SRC.preprocess``.
        maxlen: Maximum number of target tokens to generate (default 25).
        device: Device for all input tensors and masks. When ``None``, it
            defaults to the device of the model's first parameter. If the
            model has no parameters, it defaults to ``"cuda"``.

    Returns:
        The generated words, each prefixed with a single space, so the
        string starts with a space unless nothing was generated.
        ``EOS_WORD`` is included when it was produced.
    """
    model.eval()
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cuda")

    indexed = [SRC.vocab.stoi[tok] for tok in SRC.preprocess(sentence)]
    # (1, src_len) -> (src_len, 1): the model is fed sequence-first tensors.
    # This is loop-invariant, so it is built once.
    src = torch.tensor([indexed], dtype=torch.long, device=device).transpose(0, 1)
    trg = torch.tensor([[TGT.vocab.stoi[BOS_WORD]]], dtype=torch.long, device=device)

    words = []
    # Pure inference: don't record autograd state for any forward pass.
    with torch.no_grad():
        for _ in range(maxlen):
            size = trg.size(0)
            # Causal mask: 0.0 on/below the diagonal, -inf above it, so target
            # position i cannot attend to positions after i.
            tgt_mask = torch.triu(
                torch.full((size, size), float("-inf"), device=device), diagonal=1
            )
            pred = model(src, trg, tgt_mask=tgt_mask)
            # Best-scoring token at the last target position (batch of one).
            next_id = pred.argmax(dim=2)[-1].item()
            word = TGT.vocab.itos[next_id]
            words.append(word)
            if word == EOS_WORD:
                break
            next_tok = torch.tensor([[next_id]], dtype=torch.long, device=device)
            trg = torch.cat((trg, next_tok))  # grow along the sequence dim
    return "".join(" " + word for word in words)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment