

@taoyds
Created March 20, 2020 12:48
import torch
# List available models
torch.hub.list('pytorch/fairseq') # [..., 'transformer.wmt16.en-de', ... ]
# Load a transformer trained on WMT'16 En-De
en2de = torch.hub.load('pytorch/fairseq', 'transformer.wmt16.en-de',
                       tokenizer='moses', bpe='subword_nmt')
en2de.eval()  # disable dropout
# The underlying model is available under the *models* attribute.
# torch.hub.load has already imported the fairseq package, so this import succeeds.
import fairseq
assert isinstance(en2de.models[0], fairseq.models.transformer.TransformerModel)
# Move model to GPU for faster translation (skipped if no GPU is available)
if torch.cuda.is_available():
    en2de.cuda()
# Translate a sentence
en2de.translate('Hello world!')
# 'Hallo Welt!'
# Batched translation
en2de.translate(['Hello world!', 'The cat sat on the mat.'])
# ['Hallo Welt!', 'Die Katze saß auf der Matte.']
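When translating a long list of sentences, it can be handy to feed them to `translate` in fixed-size pieces, for example to report progress or write results out incrementally. The sketch below is a hypothetical helper, `translate_in_chunks`, which is not part of fairseq or torch.hub. It assumes only that the translator is a callable that takes a list of sentences and returns a list of translations in the same order, as `en2de.translate` does in the batched example above. fairseq may already batch inputs internally, so treat this as a convenience wrapper rather than a required step.

```python
from typing import Callable, List, Sequence


def translate_in_chunks(
    translate: Callable[[List[str]], List[str]],
    sentences: Sequence[str],
    chunk_size: int = 32,
) -> List[str]:
    """Hypothetical helper: translate `sentences` in batches of `chunk_size`.

    `translate` is any callable mapping a list of sentences to a list of
    translations in the same order (e.g. en2de.translate).
    """
    if chunk_size < 1:
        raise ValueError("chunk_size must be a positive integer")
    results: List[str] = []
    # Walk the input in contiguous slices so output order matches input order.
    for start in range(0, len(sentences), chunk_size):
        batch = list(sentences[start:start + chunk_size])
        results.extend(translate(batch))
    return results
```

With the model loaded above, usage would look like `translate_in_chunks(en2de.translate, sentences, chunk_size=16)`.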