
@nahidalam
Created July 24, 2020 17:54
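'''
Reusable set of functions to convert a pair (tuple of two strings) into tensors
Reference: https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html
'''
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# End-of-sentence token index; the reference tutorial reserves 0 for SOS and 1 for EOS
EOS_token = 1


def indexesFromSentence(lang, sentence):
    # Map each space-separated word to its index in the language's vocabulary
    return [lang.word2index[word] for word in sentence.split(' ')]


def tensorFromSentence(lang, sentence):
    # Append EOS so the model can tell where the sentence ends,
    # then shape the indexes as a (seq_len, 1) column tensor of longs
    indexes = indexesFromSentence(lang, sentence)
    indexes.append(EOS_token)
    return torch.tensor(indexes, dtype=torch.long, device=device).view(-1, 1)


def tensorsFromPair(pair):
    # input_lang and output_lang are not defined here: they must exist as globals
    # in this module beforehand (the reference tutorial builds them with prepareData)
    input_tensor = tensorFromSentence(input_lang, pair[0])
    target_tensor = tensorFromSentence(output_lang, pair[1])
    return (input_tensor, target_tensor)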
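A minimal usage sketch, assuming a hypothetical `Lang` class that provides only the `word2index` mapping these functions read (the reference tutorial's own `Lang` class tracks more, such as word counts and `index2word`). The block re-declares the three functions so it runs standalone, and its `tensorsFromPair` variant takes the two vocabularies as arguments instead of reading the `input_lang`/`output_lang` globals.

```python
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
SOS_token = 0  # reserved start-of-sentence index (unused here, kept for the vocabulary layout)
EOS_token = 1  # end-of-sentence index appended to every sentence


class Lang:
    """Hypothetical minimal vocabulary: word -> index, with 0/1 reserved for SOS/EOS."""

    def __init__(self, name):
        self.name = name
        self.word2index = {}
        self.n_words = 2  # SOS and EOS already occupy indexes 0 and 1

    def addSentence(self, sentence):
        # Assign the next free index to every word not seen before
        for word in sentence.split(' '):
            if word not in self.word2index:
                self.word2index[word] = self.n_words
                self.n_words += 1


def indexesFromSentence(lang, sentence):
    return [lang.word2index[word] for word in sentence.split(' ')]


def tensorFromSentence(lang, sentence):
    indexes = indexesFromSentence(lang, sentence)
    indexes.append(EOS_token)
    return torch.tensor(indexes, dtype=torch.long, device=device).view(-1, 1)


def tensorsFromPair(pair, input_lang, output_lang):
    # Hypothetical variant: vocabularies are passed in rather than read from globals
    return (tensorFromSentence(input_lang, pair[0]),
            tensorFromSentence(output_lang, pair[1]))


pair = ("je suis content", "i am happy")
input_lang, output_lang = Lang("fra"), Lang("eng")
input_lang.addSentence(pair[0])
output_lang.addSentence(pair[1])

input_tensor, target_tensor = tensorsFromPair(pair, input_lang, output_lang)
print(input_tensor.shape)              # torch.Size([4, 1]): three words + EOS
print(input_tensor.view(-1).tolist())  # [2, 3, 4, 1]
```

Passing the vocabularies explicitly is only a choice made for this sketch so it has no hidden dependencies; the original `tensorsFromPair(pair)` keeps a one-argument signature and relies on the module-level `input_lang` and `output_lang` instead.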