Created November 24, 2021 11:02
Save Chris-hughes10/9dac5f5905719d5702497f9666028f87 to your computer and use it in GitHub Desktop.
pytorch_accelerated_blog_trainer_metrics_snippet
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from torchmetrics import MetricCollection, Accuracy, Precision, Recall | |
class TrainerWithMetrics(Trainer):
    """``Trainer`` subclass that also records classification metrics during evaluation.

    Predictions from every evaluation batch are accumulated in a torchmetrics
    ``MetricCollection``. At the end of each evaluation epoch the collection is
    computed, every resulting value is written to ``run_history`` under its
    metric name, and the collection is reset for the next epoch.
    """

    def __init__(self, num_classes: int, *args, **kwargs):
        """Create the trainer and its metric collection.

        :param num_classes: number of target classes, forwarded to each metric.
        :param args: positional arguments forwarded to ``Trainer.__init__``.
        :param kwargs: keyword arguments forwarded to ``Trainer.__init__``.
        """
        super().__init__(*args, **kwargs)
        # this will be moved to the correct device automatically by the
        # MoveModulesToDeviceCallback callback, which is used by default
        # NOTE(review): the constructor arguments these metrics accept (e.g. a
        # ``task=`` argument, the default ``average`` mode) vary across
        # torchmetrics versions -- confirm against the pinned version.
        self.metrics = MetricCollection(
            {
                "accuracy": Accuracy(num_classes=num_classes),
                "precision": Precision(num_classes=num_classes),
                "recall": Recall(num_classes=num_classes),
            }
        )

    def calculate_eval_batch_loss(self, batch):
        """Compute the evaluation loss for ``batch`` and update the metrics.

        :param batch: the evaluation batch. ``batch[1]`` is passed to the
            metrics as the targets (presumably ``(inputs, targets)`` -- confirm
            against the dataloader).
        :return: the dict returned by ``Trainer.calculate_eval_batch_loss``,
            unchanged.
        """
        batch_output = super().calculate_eval_batch_loss(batch)
        # Take the argmax over the last dimension to get predicted class
        # indices (assumes that dimension holds per-class scores -- confirm).
        preds = batch_output["model_outputs"].argmax(dim=-1)
        self.metrics.update(preds, batch[1])
        return batch_output

    def eval_epoch_end(self):
        """Record every computed metric in ``run_history``, then reset the metrics.

        Each value is moved to the CPU before being recorded. The computed
        results are iterated over rather than named one by one, so any metric
        added to the collection in ``__init__`` is logged automatically.
        """
        computed = self.metrics.compute()
        for name, value in computed.items():
            self.run_history.update_metric(name, value.cpu())
        # Clear accumulated state so the next evaluation epoch starts fresh.
        self.metrics.reset()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.