Skip to content

Instantly share code, notes, and snippets.

@yuchenlin
Last active May 23, 2020 06:09
Show Gist options
  • Save yuchenlin/208745b2ad71f0fca289fe97b52025af to your computer and use it in GitHub Desktop.
Batched version for using RoBERTa to do inference
import torch
import numpy as np
from tqdm import tqdm
from fairseq.models.roberta import RobertaModel
from fairseq.data.data_utils import collate_tokens
from torch.utils.data import DataLoader, SequentialSampler
# Load RoBERTa-large fine-tuned on MNLI; switch to eval mode and move to GPU.
roberta = torch.hub.load('pytorch/fairseq', 'roberta.large.mnli')
roberta.eval()
roberta.cuda()

# Premise/hypothesis pairs to classify (MNLI labels: contradiction / neutral / entailment).
batch_of_pairs = [
['Roberta is a heavily optimized version of BERT.', 'Roberta is not very optimized.'],
['Roberta is a heavily optimized version of BERT.', 'Roberta is based on BERT.'],
['potatoes are awesome.', 'I like to run.'],
['Mars is very far from earth.', 'Mars is very close.'],
]

# Encode each pair into one BPE token sequence and right-pad into a
# rectangular tensor. Take the pad index from the model's own dictionary
# instead of hard-coding 1, so this stays correct if the vocab changes.
pad_idx = roberta.task.source_dictionary.pad()
eval_dataset = collate_tokens(
[roberta.encode(pair[0], pair[1]) for pair in batch_of_pairs],
pad_idx=pad_idx,
)
eval_sampler = SequentialSampler(eval_dataset)
eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=32)

# Run inference batch by batch. torch.no_grad() skips autograd bookkeeping
# (saves memory and time during pure inference). Accumulate per-batch results
# in a list and concatenate once at the end: np.append inside the loop copies
# the whole accumulated array every iteration (accidentally O(n^2)).
all_logprobs = []
with torch.no_grad():
    for batch in tqdm(eval_dataloader, desc="Evaluating"):
        logprobs = roberta.predict('mnli', batch)
        all_logprobs.append(logprobs.detach().cpu().numpy())
preds = np.concatenate(all_logprobs, axis=0)

print(preds)
# argmax over the label dimension gives the predicted class index per pair.
print(preds.argmax(axis=1))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment