Created
March 20, 2020 12:48
-
-
Save taoyds/73348b95a46c563f7fea7d2174756458 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Example: English -> German translation with a pretrained fairseq
# transformer fetched through torch.hub.
import torch

# List available models
torch.hub.list('pytorch/fairseq')  # [..., 'transformer.wmt16.en-de', ... ]

# Load a transformer trained on WMT'16 En-De
en2de = torch.hub.load(
    'pytorch/fairseq',
    'transformer.wmt16.en-de',
    tokenizer='moses',
    bpe='subword_nmt',
)
en2de.eval()  # disable dropout

# The isinstance check below names a fairseq class, so the module has to be
# bound in this namespace (the original snippet raised NameError here).
# NOTE(review): imported after the hub load on purpose -- presumably the hub
# entry point has already imported fairseq by this point, so this also works
# when fairseq is not pip-installed; confirm, or move it to the top of the file.
import fairseq.models.transformer

# The underlying model is available under the *models* attribute
assert isinstance(en2de.models[0], fairseq.models.transformer.TransformerModel)

# Move model to GPU for faster translation; stay on CPU when no CUDA device
# is available instead of crashing.
if torch.cuda.is_available():
    en2de.cuda()

# Translate a sentence
en2de.translate('Hello world!')
# 'Hallo Welt!'

# Batched translation
en2de.translate(['Hello world!', 'The cat sat on the mat.'])
# ['Hallo Welt!', 'Die Katze saß auf der Matte.']
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment