@Narsil
Created September 20, 2021 07:50
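# Zero-shot classification benchmark: wraps the test splits of AG News, Yelp
# (polarity and full review) and Yahoo Answers as torch Datasets, runs a
# transformers "zero-shot-classification" pipeline over them, and reports accuracy.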
from datasets import load_dataset, Dataset as hf_Dataset
from transformers import pipeline
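# NOTE: `Scoring` (and the `prefix`/`scoring` arguments passed to the pipeline below)
# are not part of the standard zero-shot pipeline API; this script assumes a
# transformers build that exposes them.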
from transformers.pipelines.zero_shot_classification import Scoring
from tqdm import tqdm
from random import shuffle
from typing import List
from pprint import pprint
from collections import defaultdict
import json
import numpy as np
from torch.utils.data import Dataset, DataLoader
import torch
class BenchmarkDataset(Dataset):
    """Wraps a Hugging Face dataset split so a zero-shot pipeline can be scored against it."""

    def __init__(self, dataset: hf_Dataset):
        super().__init__()
        self.dataset = dataset
        if not isinstance(self.dataset, hf_Dataset):
            raise ValueError("Expected a single `datasets.Dataset` (e.g. the test split), not a DatasetDict.")

    def __getitem__(self, idx):
        return self.dataset[idx]

    def __len__(self):
        return len(self.dataset)

    def collate_fn(self, batch):
        examples = [self._convert_examples_to_premises(example) for example in batch]
        labels = [self._verbalize(example["label"]) for example in batch]
        return {"sequences": examples, "labels": labels, "candidate_labels": self.id2label}

    def _convert_examples_to_premises(self, example):
        raise NotImplementedError("Abstract method.")

    def _verbalize(self, label):
        if not hasattr(self, "id2label"):
            raise AttributeError("This dataset has not implemented any verbalizer.")
        return self.id2label[label]

    def compute_metrics(self, predictions: List[str]):
        # Accuracy by default
        labels = np.array([self._verbalize(example["label"]) for example in self.dataset])
        predictions = np.array(predictions)
        assert len(labels) == len(predictions)
        return {"accuracy": 100 * np.mean(labels == predictions)}

    def run_benchmark(
        self,
        clf: pipeline,
        use_neutral: bool,
        batch_size: int = 1,
        hypothesis_template: str = None,
        prefix: str = "mnli hypothesis: {hypothesis} premise: {premise}",
    ):
        if hypothesis_template is None and hasattr(self, "default_template"):
            hypothesis_template = self.default_template
        dataloader = DataLoader(self, batch_size=batch_size, collate_fn=self.collate_fn, shuffle=False)
        predictions = []
        for batch in tqdm(dataloader, total=len(self) // batch_size):
            output = clf(
                batch["sequences"],
                candidate_labels=self.id2label,
                hypothesis_template=hypothesis_template,
                prefix=prefix,
                scoring=Scoring.ENTAIL_CONTRADICT_NEUTRAL if use_neutral else Scoring.ENTAIL_CONTRADICT,
                multi_label=True,
            )
            if isinstance(output, dict):
                output = [output]
            predictions.extend([out["labels"][0] for out in output])
        return self.compute_metrics(predictions)
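

# Concrete benchmarks: each subclass supplies the verbalized candidate labels,
# a default hypothesis template, and how to turn a raw example into a premise.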
class YelpBinaryBenchmark(BenchmarkDataset):
    id2label = ["Negative", "Positive"]
    default_template = "Polarity: {}."

    def _convert_examples_to_premises(self, example):
        return example["text"].replace("\\n", " ")


class YelpBenchmark(BenchmarkDataset):
    id2label = ["1 star", "2 stars", "3 stars", "4 stars", "5 stars"]
    default_template = "Review rating: {}."

    def _convert_examples_to_premises(self, example):
        return example["text"].replace("\\n", " ")


class AGNewsBenchmark(BenchmarkDataset):
    id2label = ["World", "Sports", "Business", "Science & Technology"]
    default_template = "Topic: {}."

    def _convert_examples_to_premises(self, example):
        return example["text"].replace("\\", " ")


class YahooBenchmark(BenchmarkDataset):
    id2label = [
        "Society & Culture",
        "Science & Mathematics",
        "Health",
        "Education & Reference",
        "Computers & Internet",
        "Sports",
        "Business & Finance",
        "Entertainment & Music",
        "Family & Relationships",
        "Politics & Government",
    ]
    default_template = "Topic: {}."

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # yahoo_answers_topics stores the class index in the "topic" column; expose it as "label".
        self.dataset = self.dataset.map(lambda x: {"label": x["topic"]})

    def _convert_examples_to_premises(self, example):
        text = " ".join([example["question_title"], example["question_content"], example["best_answer"]])
        return text.replace("\\n", " ")


EVALUATION_DATASETS = {
    "ag_news": AGNewsBenchmark,
    "yelp_polarity": YelpBinaryBenchmark,
    "yelp_review_full": YelpBenchmark,
    "yahoo_answers_topics": YahooBenchmark,
}
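

# Self-test: each wrapper should score its own verbalized gold labels perfectly.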
def test_benchmark_datasets():
    for name, _class in EVALUATION_DATASETS.items():
        benchmark = _class(load_dataset(name, split="test"))
        dataloader = DataLoader(benchmark, batch_size=1, collate_fn=benchmark.collate_fn, shuffle=False)
        predictions = []
        for batch in tqdm(dataloader, total=len(benchmark)):
            predictions.extend(batch["labels"])
        # compute_metrics reports accuracy on a 0-100 scale
        assert benchmark.compute_metrics(predictions)["accuracy"] == 100.0
        print(benchmark.compute_metrics(predictions))
        shuffle(predictions)
        assert benchmark.compute_metrics(predictions)["accuracy"] < 100.0
        print(benchmark.compute_metrics(predictions))


def evaluate_models(models: List[tuple]):
    results = defaultdict(dict)
    for model, prefix, use_neutral in models:
        # Load the pipeline
        nlp = pipeline("zero-shot-classification", model=model, device=0 if torch.cuda.is_available() else -1)
        for name, _class in EVALUATION_DATASETS.items():
            benchmark = _class(load_dataset(name, split="test"))
            results[f"{model}_{prefix}_{use_neutral}_default".replace("__", "_")][name] = benchmark.run_benchmark(
                nlp, batch_size=1, hypothesis_template="This example is {}.", use_neutral=use_neutral, prefix=prefix
            )
    with open("results.json", "wt") as f:
        json.dump(results, f, indent=4)
    pprint(results)


if __name__ == "__main__":
    evaluate_models(
        [
            ("facebook/bart-large-mnli", "", True),
            ("facebook/bart-large-mnli", "", False),
        ]
    )
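
# Usage (assumption: this file is saved as e.g. benchmark_zero_shot.py):
#   python benchmark_zero_shot.py
# Per-dataset accuracies are printed and written to results.json, keyed by
# "<model>_<prefix>_<use_neutral>_default" with "__" collapsed to "_".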