Created
December 9, 2021 07:43
-
-
Save Chris-hughes10/3cb7b46b24da2dfb5874a39abcfce9d2 to your computer and use it in GitHub Desktop.
Recommender Blog: Torchmetrics recommender callback
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import math

from pytorch_accelerated.callbacks import TrainerCallback
import torchmetrics
class RecommenderMetricsCallback(TrainerCallback):
    """Record MSE, MAE and RMSE for each evaluation epoch.

    Predictions and targets are accumulated batch by batch in a
    ``torchmetrics.MetricCollection``. At the end of every evaluation epoch
    the aggregated values are written to ``trainer.run_history`` under the
    names ``"mae"``, ``"mse"`` and ``"rmse"``, and the collection is reset.
    """

    def __init__(self):
        self.metrics = torchmetrics.MetricCollection(
            {
                "mse": torchmetrics.MeanSquaredError(),
                "mae": torchmetrics.MeanAbsoluteError(),
            }
        )

    def _move_to_device(self, trainer):
        """Move the metric collection to ``trainer.device``."""
        self.metrics.to(trainer.device)

    def on_training_run_start(self, trainer, **kwargs):
        """Place the metrics on the trainer's device when training starts."""
        self._move_to_device(trainer)

    def on_evaluation_run_start(self, trainer, **kwargs):
        """Place the metrics on the trainer's device when evaluation starts."""
        self._move_to_device(trainer)

    def on_eval_step_end(self, trainer, batch, batch_output, **kwargs):
        """Fold one evaluation batch into the running metrics.

        NOTE(review): assumes ``batch[1]`` holds the targets and that
        ``batch_output["model_outputs"]`` has a matching shape -- confirm
        against the dataset and the trainer's eval step.
        """
        preds = batch_output["model_outputs"]
        self.metrics.update(preds, batch[1])

    def on_eval_epoch_end(self, trainer, **kwargs):
        """Publish the epoch's metrics to the run history and reset state."""
        metrics = self.metrics.compute()
        mse = metrics["mse"].cpu()
        trainer.run_history.update_metric("mae", metrics["mae"].cpu())
        trainer.run_history.update_metric("mse", mse)
        # RMSE is derived from the aggregated MSE rather than tracked as a
        # separate metric; requires the module-level ``import math``.
        trainer.run_history.update_metric("rmse", math.sqrt(mse))
        # Clear accumulated state so the next epoch starts from scratch.
        self.metrics.reset()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment