Skip to content

Instantly share code, notes, and snippets.

@tomaarsen
Created April 20, 2023 12:33
Show Gist options
  • Save tomaarsen/3c355ef47ea90ec370300629c4080e2c to your computer and use it in GitHub Desktop.
Logging losses for SetFit
import functools
from dataclasses import dataclass
from typing import Callable
from datasets import load_dataset
from sentence_transformers.losses import CosineSimilarityLoss
from torch import nn
import wandb
from setfit import SetFitModel, SetFitTrainer, sample_dataset
# Pull the SST-2 sentiment dataset from the Hugging Face Hub.
sst2 = load_dataset("sst2")

# Few-shot regime: keep only 8 labelled examples per class for training,
# while evaluating on the full validation split.
train_dataset = sample_dataset(sst2["train"], label_column="label", num_samples=8)
eval_dataset = sst2["validation"]

# Start from a pretrained sentence-transformer checkpoint on the Hub.
model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-mpnet-base-v2")
@dataclass
class LoggingWrapper:
    """Factory wrapper around a sentence-transformers loss class that
    reports every computed loss value to Weights & Biases.

    Calling the wrapper constructs the underlying loss and monkey-patches
    its ``forward`` so each loss is logged before being returned.

    NOTE(review): ``wandb.init`` runs when the loss is instantiated —
    this assumes the loss class is constructed once per training run.
    """

    # The loss class (not an instance) to instantiate and instrument.
    loss_class: nn.Module

    def __call__(self, *args, **kwargs):
        # Start the W&B run, build the real loss with the caller's
        # arguments, then swap in the logging ``forward``.
        wandb.init(project="setfit")
        instance = self.loss_class(*args, **kwargs)
        instance.forward = self.log_forward(instance.forward)
        return instance

    def log_forward(self, forward_func: Callable):
        """Return a wrapper around ``forward_func`` that logs each loss."""

        @functools.wraps(forward_func)
        def logged_forward(*args, **kwargs):
            loss = forward_func(*args, **kwargs)
            wandb.log({"training_loss": loss})
            return loss

        return logged_forward
# Assemble the trainer: 20 contrastive-pair iterations, the SST-2
# "sentence" column mapped onto the "text" field SetFit expects, and the
# cosine-similarity loss routed through the W&B logging wrapper.
trainer = SetFitTrainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    loss_class=LoggingWrapper(CosineSimilarityLoss),
    num_iterations=20,
    column_mapping={"sentence": "text", "label": "label"},
)
# Fine-tune on the few-shot training set (each forward pass is logged to
# W&B via LoggingWrapper), then score on the validation split.
trainer.train()
metrics = trainer.evaluate()  # NOTE(review): metric choice comes from SetFitTrainer defaults — confirm
print(metrics)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment