# Gist: pytorch_accelerated_blog_metrics_in_callback
# Gist id: Chris-hughes10/2888c6641622364d2581895056692453
# Created: November 24, 2021 11:28
# Note: GitHub flagged this file as containing bidirectional Unicode text that
# may be interpreted or compiled differently than it appears. To review, open
# the file in an editor that reveals hidden Unicode characters.
# https://github.com/Chris-hughes10/pytorch-accelerated/blob/main/examples/metrics/train_with_metrics_in_callback.py
import os | |
from torch import nn, optim | |
from torch.utils.data import random_split | |
from torchmetrics import MetricCollection, Accuracy, Precision, Recall | |
from torchvision import transforms | |
from torchvision.datasets import MNIST | |
from pytorch_accelerated import Trainer | |
from pytorch_accelerated.callbacks import TrainerCallback | |
from pytorch_accelerated.trainer import DEFAULT_CALLBACKS | |
class MNISTModel(nn.Module):
    """Fully connected classifier: 784 -> 128 -> 64 -> 10 with ReLU in between."""

    def __init__(self):
        super().__init__()
        # Layers are listed in application order. The container attribute stays
        # named ``main`` so parameter names (``main.0.weight``, ...) are unchanged.
        layers = [
            nn.Linear(in_features=784, out_features=128),
            nn.ReLU(),
            nn.Linear(in_features=128, out_features=64),
            nn.ReLU(),
            nn.Linear(in_features=64, out_features=10),
        ]
        self.main = nn.Sequential(*layers)

    def forward(self, x):
        """Flatten each sample to a vector and return the 10 output scores.

        The first dimension of ``x`` is kept as the batch dimension; all
        remaining dimensions are collapsed, so each sample must contain 784
        elements to match the first linear layer.
        """
        batch_size = x.shape[0]
        flattened = x.view(batch_size, -1)
        return self.main(flattened)
class ClassificationMetricsCallback(TrainerCallback):
    """Accumulate accuracy, precision and recall during evaluation.

    Predictions are collected after every evaluation step; at the end of each
    evaluation epoch the aggregated values are written to the trainer's run
    history and the metric state is cleared.
    """

    def __init__(self, num_classes):
        """Create one metric per tracked quantity.

        Args:
            num_classes: forwarded as ``num_classes`` to every metric.
        """
        metric_types = {
            "accuracy": Accuracy,
            "precision": Precision,
            "recall": Recall,
        }
        self.metrics = MetricCollection(
            {
                name: metric_cls(num_classes=num_classes)
                for name, metric_cls in metric_types.items()
            }
        )

    def _move_to_device(self, trainer):
        """Place the metric collection on the trainer's device."""
        self.metrics.to(trainer.device)

    def on_training_run_start(self, trainer, **kwargs):
        """Move the metrics to the trainer's device when a training run starts."""
        self._move_to_device(trainer)

    def on_evaluation_run_start(self, trainer, **kwargs):
        """Move the metrics to the trainer's device when an evaluation run starts."""
        self._move_to_device(trainer)

    def on_eval_step_end(self, trainer, batch, batch_output, **kwargs):
        """Update the metrics with the predictions for one evaluation batch."""
        # The predicted class is the index of the largest value along the last
        # dimension of the model outputs.
        predictions = batch_output["model_outputs"].argmax(dim=-1)
        # batch[1] is used as the targets (assumes batches are
        # (inputs, targets) pairs -- confirm against the dataset in use).
        targets = batch[1]
        self.metrics.update(predictions, targets)

    def on_eval_epoch_end(self, trainer, **kwargs):
        """Record the epoch's metric values in the run history, then reset."""
        computed = self.metrics.compute()
        for name in ("accuracy", "precision", "recall"):
            # Values are moved to the CPU before being stored.
            trainer.run_history.update_metric(name, computed[name].cpu())
        # Clear accumulated state so the next evaluation epoch starts fresh.
        self.metrics.reset()
def main():
    """Train and evaluate ``MNISTModel`` on MNIST, tracking classification metrics.

    MNIST is downloaded into the current working directory and split into
    50,000 training, 5,000 validation and 5,000 test samples. The model is
    trained for two epochs and then evaluated on the test split.
    """
    mnist = MNIST(os.getcwd(), download=True, transform=transforms.ToTensor())
    num_classes = len(mnist.class_to_idx)

    split_sizes = [50000, 5000, 5000]
    train_split, validation_split, test_split = random_split(mnist, split_sizes)

    model = MNISTModel()
    optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
    loss_func = nn.CrossEntropyLoss()

    # The metrics callback is listed ahead of the default callbacks.
    metrics_callback = ClassificationMetricsCallback(num_classes=num_classes)
    callbacks = (metrics_callback, *DEFAULT_CALLBACKS)

    trainer = Trainer(
        model,
        loss_func=loss_func,
        optimizer=optimizer,
        callbacks=callbacks,
    )

    trainer.train(
        train_dataset=train_split,
        eval_dataset=validation_split,
        num_epochs=2,
        per_device_batch_size=32,
    )

    trainer.evaluate(dataset=test_split, per_device_batch_size=64)
# Run the example only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
# GitHub page footer: "Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment."