Skip to content

Instantly share code, notes, and snippets.

@Chris-hughes10
Created November 24, 2021 11:02
Show Gist options
  • Save Chris-hughes10/9dac5f5905719d5702497f9666028f87 to your computer and use it in GitHub Desktop.
Save Chris-hughes10/9dac5f5905719d5702497f9666028f87 to your computer and use it in GitHub Desktop.
pytorch_accelerated_blog_trainer_metrics_snippet
from torchmetrics import MetricCollection, Accuracy, Precision, Recall
class TrainerWithMetrics(Trainer):
    """Trainer that tracks classification metrics during evaluation.

    Accuracy, precision and recall are accumulated batch by batch in
    ``calculate_eval_batch_loss`` and written to ``run_history`` at the end of
    each evaluation epoch in ``eval_epoch_end``.

    NOTE(review): ``Trainer`` is presumably ``pytorch_accelerated.Trainer``;
    it is not imported in this snippet, so it must be brought into scope
    elsewhere.

    :param num_classes: number of target classes, passed to each metric.
    All remaining positional and keyword arguments are forwarded unchanged to
    ``Trainer.__init__``.
    """

    def __init__(self, num_classes, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # this will be moved to the correct device automatically by the
        # MoveModulesToDeviceCallback callback, which is used by default
        # NOTE(review): newer torchmetrics releases (>= 0.11) require a
        # ``task`` argument (e.g. ``task="multiclass"``) for these classes;
        # this constructor form targets older releases -- confirm the pinned
        # version. Also, under micro averaging (reportedly the older default),
        # precision and recall equal accuracy for single-label multiclass
        # problems; consider ``average="macro"`` if that is not intended.
        self.metrics = MetricCollection(
            {
                "accuracy": Accuracy(num_classes=num_classes),
                "precision": Precision(num_classes=num_classes),
                "recall": Recall(num_classes=num_classes),
            }
        )

    def calculate_eval_batch_loss(self, batch):
        """Compute the eval loss for ``batch`` and update the metric state.

        :param batch: indexable batch whose element ``1`` holds the targets
            (presumably ``(inputs, targets)`` -- confirm against the dataloader).
        :return: the unmodified result of the parent
            ``calculate_eval_batch_loss``.
        """
        batch_output = super().calculate_eval_batch_loss(batch)
        # Predicted class = index of the largest score along the last axis.
        preds = batch_output["model_outputs"].argmax(dim=-1)
        self.metrics.update(preds, batch[1])
        return batch_output

    def eval_epoch_end(self):
        """Record this epoch's metrics in ``run_history`` and reset them.

        The parent hook is called first so its default end-of-epoch
        bookkeeping is not lost by this override (presumably recording the
        average evaluation loss -- confirm against the Trainer docs).
        """
        super().eval_epoch_end()
        computed = self.metrics.compute()
        # Keys are the names given to MetricCollection in __init__
        # ("accuracy", "precision", "recall"), in insertion order.
        for name, value in computed.items():
            self.run_history.update_metric(name, value.cpu())
        # Clear accumulated state so the next epoch starts from scratch.
        self.metrics.reset()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment