This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from fairseq.models.transformer import TransformerModel

# Load a Chinese->English Transformer from a local checkpoint directory,
# together with the binarized data directory and the subword-nmt BPE codes
# it was trained with.
zh2en = TransformerModel.from_pretrained(
    '/path/to/checkpoints',
    checkpoint_file='checkpoint_best.pt',
    data_name_or_path='data-bin/wmt17_zh_en_full',
    bpe='subword_nmt',
    bpe_codes='data-bin/wmt17_zh_en_full/zh.code',
)

# translate() returns the translated text. In an interactive session a bare
# expression is echoed, but a script silently discards it, so print it.
# NOTE(review): the input is space-separated ('你好 世界'), presumably to match
# word-segmented training data; confirm against the preprocessing used.
print(zh2en.translate('你好 世界'))
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch

# List the entry points exposed by the fairseq hub repository and show them.
# The list includes 'transformer.wmt16.en-de' among others. In a script the
# bare call's result would be discarded, so print it.
print(torch.hub.list('pytorch/fairseq'))

# Load a Transformer trained on WMT'16 English->German, using Moses
# tokenization and subword-nmt BPE.
en2de = torch.hub.load(
    'pytorch/fairseq',
    'transformer.wmt16.en-de',
    tokenizer='moses',
    bpe='subword_nmt',
)
en2de.eval()  # inference mode: disables dropout

# The underlying model(s) wrapped by this hub interface are available via
# the ``models`` attribute, i.e. ``en2de.models``.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| @register_task('classification') | |
| class ClassificationTask(FairseqTask): (...) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Switch the model into training mode, then record how many updates have
# been performed so far.
model.train()
model.set_num_updates(update_num)

# The criterion runs the model on the sample and returns the loss together
# with the sample size and a logging dictionary.
loss, sample_size, logging_output = criterion(model, sample)

# When this batch's gradient should be ignored, zero the loss in place. The
# backward pass below still runs, but the resulting gradients are zeroed out.
if ignore_grad:
    loss *= 0

optimizer.backward(loss)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Create the task object; this step is where e.g. dictionaries get loaded.
task = fairseq.tasks.setup_task(args)

# Both the model and the training criterion are built by the task from the
# same argument namespace.
model = task.build_model(args)
criterion = task.build_criterion(args)

# Load the training split first, then the validation split.
for split in ('train', 'valid'):
    task.load_dataset(split)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Simplified training loop.
#
# ``num_updates`` counts optimizer updates across ALL epochs. Using
# ``enumerate(itr)`` here would restart the count at 0 every epoch, so the
# per-update learning-rate schedule (e.g. warmup) would repeat each epoch.
num_updates = 0
for epoch in range(num_epochs):
    itr = task.get_batch_iterator(task.dataset('train'))
    for batch in itr:
        # Clear gradients from the previous update; backward() accumulates
        # into existing gradients rather than overwriting them.
        optimizer.zero_grad()
        # ``update_num`` is the number of updates completed before this one.
        task.train_step(batch, model, criterion, optimizer, update_num=num_updates)
        average_and_clip_gradients()
        optimizer.step()
        num_updates += 1
        # Per-update LR schedule, keyed on the global update count.
        lr_scheduler.step_update(num_updates)
    # Per-epoch LR schedule.
    lr_scheduler.step(epoch)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| if args.joined_dictionary: | |
| assert not args.srcdict or not args.tgtdict, \ | |
| "cannot use both --srcdict and --tgtdict with --joined-dictionary" | |
| if args.srcdict: | |
| src_dict = task.load_dictionary(args.srcdict) | |
| elif args.tgtdict: | |
| src_dict = task.load_dictionary(args.tgtdict) | |
| else: | |
| assert args.trainpref, "--trainpref must be set if --srcdict is not specified" |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def cli_main():
    """Command-line entry point: parse preprocessing arguments and run ``main``."""
    main(options.get_preprocessing_parser().parse_args())


if __name__ == "__main__":
    cli_main()