Created
April 20, 2023 12:33
-
-
Save tomaarsen/3c355ef47ea90ec370300629c4080e2c to your computer and use it in GitHub Desktop.
Logging losses for SetFit
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Few-shot SetFit training on SST-2, with per-step training-loss logging to W&B.
import functools
from dataclasses import dataclass
from typing import Callable

import wandb
from datasets import load_dataset
from sentence_transformers.losses import CosineSimilarityLoss
from setfit import SetFitModel, SetFitTrainer, sample_dataset
from torch import nn

# Pull the SST-2 sentiment dataset from the Hugging Face Hub.
dataset = load_dataset("sst2")

# Simulate the few-shot regime: keep only 8 labelled examples per class for
# training, while evaluating on the full validation split.
train_dataset = sample_dataset(dataset["train"], label_column="label", num_samples=8)
eval_dataset = dataset["validation"]

# Sentence-transformer backbone that SetFit will fine-tune.
model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-mpnet-base-v2")
@dataclass
class LoggingWrapper:
    """Wrap a sentence-transformers loss class so each forward pass logs its
    training loss to Weights & Biases.

    SetFit's trainer takes a loss *class* (not an instance) and constructs it
    itself, so this wrapper is callable: invoking it builds the real loss
    instance and monkey-patches its ``forward`` with a logging version.
    """

    # The loss class to instantiate on demand, e.g. CosineSimilarityLoss.
    loss_class: nn.Module

    def __call__(self, *args, **kwargs):
        # Initialize the W&B run lazily and only once: the trainer may
        # construct the loss class multiple times, and calling wandb.init()
        # on each construction would spawn/reset runs mid-training.
        if wandb.run is None:
            wandb.init(project="setfit")
        loss_class_instance = self.loss_class(*args, **kwargs)
        loss_class_instance.forward = self.log_forward(loss_class_instance.forward)
        return loss_class_instance

    def log_forward(self, forward_func: Callable):
        """Return a wrapped ``forward`` that logs every loss value to W&B."""

        @functools.wraps(forward_func)
        def log_wrapper_forward(*args, **kwargs):
            loss = forward_func(*args, **kwargs)
            wandb.log({"training_loss": loss})
            return loss

        return log_wrapper_forward
# Build the trainer: SetFit generates contrastive sentence pairs for
# fine-tuning (num_iterations pairs per example), mapping the dataset's
# "sentence" column onto the "text" field SetFit expects. The wrapped loss
# class reports every training step's loss to W&B.
trainer = SetFitTrainer(
    model=model,
    loss_class=LoggingWrapper(CosineSimilarityLoss),
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    column_mapping={"sentence": "text", "label": "label"},
    num_iterations=20,
)

# Fine-tune, then report metrics on the validation split.
trainer.train()
eval_metrics = trainer.evaluate()
print(eval_metrics)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment