import json
from collections import defaultdict
from pprint import pprint
from random import shuffle
from typing import List, Optional

import numpy as np
import torch
from datasets import load_dataset, Dataset as hf_Dataset
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from transformers import pipeline
from transformers.pipelines.zero_shot_classification import Scoring
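
# Zero-shot classification benchmark: each dataset below is wrapped in a small
# torch Dataset that turns raw examples into premises, verbalizes the gold label,
# and scores a `zero-shot-classification` pipeline by accuracy on the test split
# (AG News, Yelp polarity, Yelp full reviews, Yahoo Answers topics).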


class BenchmarkDataset(Dataset):
    def __init__(self, dataset: hf_Dataset):
        super().__init__()
        self.dataset = dataset
        if not isinstance(self.dataset, hf_Dataset):
            raise ValueError("Expected a `datasets.Dataset` instance (e.g. a single test split).")

    def __getitem__(self, idx):
        return self.dataset[idx]

    def __len__(self):
        return len(self.dataset)

    def collate_fn(self, batch):
        examples = [self._convert_examples_to_premises(example) for example in batch]
        labels = [self._verbalize(example["label"]) for example in batch]
        return {"sequences": examples, "labels": labels, "candidate_labels": self.id2label}

    def _convert_examples_to_premises(self, example):
        raise NotImplementedError("Abstract method.")

    def _verbalize(self, label):
        if not hasattr(self, "id2label"):
            raise AttributeError("This dataset has not implemented any verbalizer.")
        return self.id2label[label]

    def compute_metrics(self, predictions: List[str]):
        # Accuracy (in percent) by default.
        labels = np.array([self._verbalize(example["label"]) for example in self.dataset])
        predictions = np.array(predictions)
        assert len(labels) == len(predictions)
        return {"accuracy": 100 * np.mean(labels == predictions)}
    def run_benchmark(
        self,
        clf: pipeline,
        use_neutral: bool,
        batch_size: int = 1,
        hypothesis_template: Optional[str] = None,
        prefix: str = "mnli hypothesis: {hypothesis} premise: {premise}",
    ):
        if hypothesis_template is None and hasattr(self, "default_template"):
            hypothesis_template = self.default_template
        dataloader = DataLoader(self, batch_size=batch_size, collate_fn=self.collate_fn, shuffle=False)
        predictions = []
        for batch in tqdm(dataloader, total=len(dataloader)):
            output = clf(
                batch["sequences"],
                candidate_labels=self.id2label,
                hypothesis_template=hypothesis_template,
                prefix=prefix,
                scoring=Scoring.ENTAIL_CONTRADICT_NEUTRAL if use_neutral else Scoring.ENTAIL_CONTRADICT,
                multi_label=True,
            )
            if isinstance(output, dict):
                # The pipeline returns a single dict for a batch of one.
                output = [output]
            predictions.extend([out["labels"][0] for out in output])
        return self.compute_metrics(predictions)


class YelpBinaryBenchmark(BenchmarkDataset):
    id2label = ["Negative", "Positive"]
    default_template = "Polarity: {}."

    def _convert_examples_to_premises(self, example):
        # Strip the literal "\n" escape sequences present in the raw text.
        return example["text"].replace("\\n", " ")


class YelpBenchmark(BenchmarkDataset):
    id2label = ["1 star", "2 stars", "3 stars", "4 stars", "5 stars"]
    default_template = "Review rating: {}."

    def _convert_examples_to_premises(self, example):
        return example["text"].replace("\\n", " ")


class AGNewsBenchmark(BenchmarkDataset):
    id2label = ["World", "Sports", "Business", "Science & Technology"]
    default_template = "Topic: {}."

    def _convert_examples_to_premises(self, example):
        return example["text"].replace("\\", " ")


class YahooBenchmark(BenchmarkDataset):
    id2label = [
        "Society & Culture",
        "Science & Mathematics",
        "Health",
        "Education & Reference",
        "Computers & Internet",
        "Sports",
        "Business & Finance",
        "Entertainment & Music",
        "Family & Relationships",
        "Politics & Government",
    ]
    default_template = "Topic: {}."

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # yahoo_answers_topics stores the class index under "topic"; expose it as "label".
        self.dataset = self.dataset.map(lambda x: {"label": x["topic"]})

    def _convert_examples_to_premises(self, example):
        text = " ".join([example["question_title"], example["question_content"], example["best_answer"]])
        return text.replace("\\n", " ")
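

# Illustrative sketch (not part of the original benchmark suite): adding a new
# dataset only requires `id2label`, an optional `default_template`, and a premise
# builder. The "imdb" dataset name and its "text"/"label" fields are assumptions
# here, and this class is deliberately not registered in EVALUATION_DATASETS below.
#
# class IMDBBenchmark(BenchmarkDataset):
#     id2label = ["Negative", "Positive"]
#     default_template = "Sentiment: {}."
#
#     def _convert_examples_to_premises(self, example):
#         return example["text"].replace("<br />", " ")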


EVALUATION_DATASETS = {
    "ag_news": AGNewsBenchmark,
    "yelp_polarity": YelpBinaryBenchmark,
    "yelp_review_full": YelpBenchmark,
    "yahoo_answers_topics": YahooBenchmark,
}


def test_benchmark_datasets():
    # Sanity check: feeding the gold labels back as predictions must give 100%
    # accuracy, and a shuffled copy must not.
    for name, _class in EVALUATION_DATASETS.items():
        benchmark = _class(load_dataset(name, split="test"))
        dataloader = DataLoader(benchmark, batch_size=1, collate_fn=benchmark.collate_fn, shuffle=False)
        predictions = []
        for batch in tqdm(dataloader, total=len(benchmark)):
            predictions.extend(batch["labels"])
        assert benchmark.compute_metrics(predictions)["accuracy"] == 100.0
        print(benchmark.compute_metrics(predictions))
        shuffle(predictions)
        assert benchmark.compute_metrics(predictions)["accuracy"] < 100.0
        print(benchmark.compute_metrics(predictions))


def evaluate_models(models: List[tuple]):
    results = defaultdict(dict)
    for model, prefix, use_neutral in models:
        # Load the pipeline once per (model, prefix, use_neutral) configuration.
        nlp = pipeline("zero-shot-classification", model=model, device=0 if torch.cuda.is_available() else -1)
        for name, _class in EVALUATION_DATASETS.items():
            benchmark = _class(load_dataset(name, split="test"))
            results[f"{model}_{prefix}_{use_neutral}_default".replace("__", "_")][name] = benchmark.run_benchmark(
                nlp, batch_size=1, hypothesis_template="This example is {}.", use_neutral=use_neutral, prefix=prefix
            )
    with open("results.json", "wt") as f:
        json.dump(results, f, indent=4)
    pprint(results)


if __name__ == "__main__":
    evaluate_models(
        [
            ("facebook/bart-large-mnli", "", True),
            ("facebook/bart-large-mnli", "", False),
        ]
    )
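
# Usage sketch (assumptions: the file name benchmark_zero_shot.py is illustrative,
# and the installed transformers build exposes `Scoring` and accepts the `prefix`
# and `scoring` pipeline arguments imported above, which not every release provides):
#
#     pip install transformers datasets torch tqdm numpy
#     python benchmark_zero_shot.py
#
# Per-dataset accuracies are written to results.json and pretty-printed at the end.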