Skip to content

Instantly share code, notes, and snippets.

@Chris-hughes10
Created November 24, 2021 11:22
Show Gist options
  • Save Chris-hughes10/0f7cb44d20d00b8c019b6c61b219b0c2 to your computer and use it in GitHub Desktop.
Save Chris-hughes10/0f7cb44d20d00b8c019b6c61b219b0c2 to your computer and use it in GitHub Desktop.
pytorch_accelerated_blog_metrics_callback_snippet
from torchmetrics import MetricCollection, Accuracy, Precision, Recall
from pytorch_accelerated.callbacks import TrainerCallback
class ClassificationMetricsCallback(TrainerCallback):
    """Trainer callback that tracks accuracy, precision and recall during evaluation.

    A torchmetrics ``MetricCollection`` is updated after every evaluation
    step. At the end of each evaluation epoch the computed values are written
    to ``trainer.run_history`` and the collection is reset.
    """

    def __init__(self, num_classes):
        """Build the metric collection.

        Args:
            num_classes: forwarded as ``num_classes`` to every metric
                constructor.
        """
        metric_types = {
            "accuracy": Accuracy,
            "precision": Precision,
            "recall": Recall,
        }
        self.metrics = MetricCollection(
            {
                label: metric_cls(num_classes=num_classes)
                for label, metric_cls in metric_types.items()
            }
        )

    def _move_to_device(self, trainer):
        """Move the metric collection to ``trainer.device``."""
        self.metrics.to(trainer.device)

    def on_training_run_start(self, trainer, **kwargs):
        """Place the metrics on the trainer's device when a training run starts."""
        self._move_to_device(trainer)

    def on_evaluation_run_start(self, trainer, **kwargs):
        """Place the metrics on the trainer's device when an evaluation run starts."""
        self._move_to_device(trainer)

    def on_eval_step_end(self, trainer, batch, batch_output, **kwargs):
        """Update the metrics with the predictions for one evaluation batch.

        Predictions are the argmax over the last dimension of
        ``batch_output["model_outputs"]``; they are compared against
        ``batch[1]``.
        """
        # NOTE(review): presumably model_outputs holds per-class scores on the
        # last dimension and batch[1] holds the target labels - confirm
        # against the dataset / model in use.
        model_outputs = batch_output["model_outputs"]
        predicted_labels = model_outputs.argmax(dim=-1)
        targets = batch[1]
        self.metrics.update(predicted_labels, targets)

    def on_eval_epoch_end(self, trainer, **kwargs):
        """Record the epoch's metric values in the run history, then reset."""
        computed = self.metrics.compute()
        for label in ("accuracy", "precision", "recall"):
            # Values are moved to CPU before being stored in the run history.
            trainer.run_history.update_metric(label, computed[label].cpu())
        # Clear accumulated state so the next epoch starts from scratch.
        self.metrics.reset()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment