A batched version of running RoBERTa inference with fairseq (using the roberta.large.mnli model).
import torch
import numpy as np
from tqdm import tqdm
from fairseq.models.roberta import RobertaModel
from fairseq.data.data_utils import collate_tokens
from torch.utils.data import DataLoader, SequentialSampler

# Load the pre-trained RoBERTa-large model fine-tuned on MNLI.
roberta = torch.hub.load('pytorch/fairseq', 'roberta.large.mnli')
roberta.eval()  # disable dropout for deterministic inference
roberta.cuda()

batch_of_pairs = [
    ['Roberta is a heavily optimized version of BERT.', 'Roberta is not very optimized.'],
    ['Roberta is a heavily optimized version of BERT.', 'Roberta is based on BERT.'],
    ['potatoes are awesome.', 'I like to run.'],
    ['Mars is very far from earth.', 'Mars is very close.'],
]

# Encode each (premise, hypothesis) pair and right-pad all sequences to a
# common length (pad_idx=1 is RoBERTa's <pad> token id).
eval_dataset = collate_tokens(
    [roberta.encode(pair[0], pair[1]) for pair in batch_of_pairs], pad_idx=1
)

eval_sampler = SequentialSampler(eval_dataset)
eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=32)

# Run inference batch by batch and accumulate the per-class log-probabilities.
preds = None
for batch in tqdm(eval_dataloader, desc="Evaluating"):
    logprobs = roberta.predict('mnli', batch)
    if preds is None:
        preds = logprobs.detach().cpu().numpy()
    else:
        preds = np.append(preds, logprobs.detach().cpu().numpy(), axis=0)

print(preds)
print(preds.argmax(axis=1))
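As a usage note, the class indices from argmax can be mapped back to string labels. The mapping below (0 = contradiction, 1 = neutral, 2 = entailment) follows fairseq's published example for roberta.large.mnli; treat it as an assumption to verify against your own checkpoint. Wrapping the loop above in torch.no_grad() would also skip building the autograd graph and reduce GPU memory use during inference.

# Map MNLI class indices to human-readable labels
# (order per fairseq's roberta.large.mnli example; verify for other checkpoints).
label_map = {0: 'contradiction', 1: 'neutral', 2: 'entailment'}
predicted_labels = [label_map[i] for i in preds.argmax(axis=1)]
print(predicted_labels)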