Created
May 5, 2025 16:42
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"id": "b29ff959", | |
"metadata": {}, | |
"source": [ | |
"# Custom Metrics and Handlers\n", | |
"\n", | |
"This notebook discusses how to construct custom metrics and handlers for use with the engine classes in MONAI. MONAI \n", | |
"provides many metrics and can also access many from MetricsReloaded. Within the current framework,\n", | |
"metric classes represent the calculation of metric values for items, their storage, and then their aggregation across \n", | |
"iterations or epochs. The handlers used here trigger these metrics when events happen within the engine,\n", | |
"such as when an iteration or epoch finishes, to drive the computation of metric values over these events. The handlers\n", | |
"are also responsible for aggregating results across multiple processes. Some of the nomenclature in MONAI is a bit\n", | |
"confusing for historical reasons, but this tutorial should help clarify how these components are used.\n", | |
"\n", | |
"The first step is to import the required modules and create a synthetic dataset in memory:" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "affafc4b", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import torch\n", | |
"import monai\n", | |
"import monai.transforms as mt\n", | |
"from monai.data import create_test_image_3d, Dataset, DataLoader\n", | |
"from monai.handlers import (\n", | |
"    MeanDice,\n", | |
"    ValidationHandler,\n", | |
"    StatsHandler,\n", | |
"    IgniteMetricHandler,\n", | |
"    MetricLogger,\n", | |
"    MetricLoggerKeys,\n", | |
"    from_engine,\n", | |
")\n", | |
"from monai.metrics import LossMetric\n", | |
"from monai.utils.enums import CommonKeys\n", | |
"from monai.engines import SupervisedEvaluator, SupervisedTrainer\n", | |
"\n", | |
"\n", | |
"both = (CommonKeys.IMAGE, CommonKeys.LABEL)\n", | |
"\n", | |
"data = []\n", | |
"\n", | |
"for i in range(50):\n", | |
"    im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=0)\n", | |
"    data.append({CommonKeys.IMAGE: im, CommonKeys.LABEL: seg})\n", | |
"\n", | |
"\n", | |
"train_transforms = mt.Compose(\n", | |
"    [\n", | |
"        mt.ScaleIntensityd(keys=CommonKeys.IMAGE),\n", | |
"        mt.RandRotate90d(keys=both, prob=0.5, spatial_axes=[0, 2]),\n", | |
"    ]\n", | |
")\n", | |
"val_transforms = mt.Compose([mt.ScaleIntensityd(keys=CommonKeys.IMAGE)])\n", | |
"\n", | |
"train_ds = Dataset(data=data[:40], transform=train_transforms)\n", | |
"val_ds = Dataset(data=data[40:], transform=val_transforms)\n", | |
"train_loader = DataLoader(train_ds, batch_size=5, shuffle=True, num_workers=0)\n", | |
"val_loader = DataLoader(val_ds, batch_size=5, shuffle=False, num_workers=0)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "7fa1cbaa", | |
"metadata": {}, | |
"source": [ | |
"Next the network and training components are defined. Their exact choice matters little since the data is synthetic:" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"id": "aa15bfd9", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", | |
"net = monai.networks.nets.UNet(\n", | |
"    spatial_dims=3,\n", | |
"    in_channels=1,\n", | |
"    out_channels=1,\n", | |
"    channels=(8, 16, 32, 64),\n", | |
"    strides=(2, 2, 2),\n", | |
"    num_res_units=2,\n", | |
").to(device)\n", | |
"\n", | |
"loss = monai.losses.DiceLoss(sigmoid=True)\n", | |
"opt = torch.optim.Adam(net.parameters(), lr=2e-4)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "dd86231a", | |
"metadata": {}, | |
"source": [ | |
"The first important class is `SupervisedEvaluator`, which is responsible for passing all data from its given loader\n", | |
"through the network to give handlers a chance to respond to the data. The expectation is that handlers like `IgniteMetricHandler`\n", | |
"will compute useful things during these events, e.g. compute a metric score for each iteration and then aggregate these at the\n", | |
"end of an epoch. We'll define a simple one which has only one metric handler class, `MeanDice` (which is not a metric class\n", | |
"despite its name):" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"id": "a441584b", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"val_post_transform = mt.Compose(  # used to post-process outputs from the network\n", | |
"    [mt.Activationsd(keys=CommonKeys.PRED, sigmoid=True), mt.AsDiscreted(keys=CommonKeys.PRED, threshold=0.5)]\n", | |
")\n", | |
"\n", | |
"val_handlers = [StatsHandler(output_transform=lambda x: None)]  # used to generate printout values\n", | |
"\n", | |
"output_transform = from_engine([CommonKeys.PRED, CommonKeys.LABEL])  # defines how to pull data from the engine\n", | |
"\n", | |
"# the object keyed to \"val_mean_dice\" will get attached to the evaluator and knows which events to register itself to\n", | |
"val_metrics = {\n", | |
"    \"val_mean_dice\": MeanDice(include_background=True, output_transform=output_transform),\n", | |
"}\n", | |
"\n", | |
"evaluator = SupervisedEvaluator(\n", | |
"    device=device,\n", | |
"    val_data_loader=val_loader,\n", | |
"    network=net,\n", | |
"    postprocessing=val_post_transform,\n", | |
"    key_val_metric=val_metrics,\n", | |
"    val_handlers=val_handlers,\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "baa270f7", | |
"metadata": {}, | |
"source": [ | |
"The evaluator can be run on its own, and it will also be used during training:" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"id": "e70677ea", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"2025-05-05 17:29:46,714 - INFO - Epoch[1] Metrics -- val_mean_dice: 0.3506 \n", | |
"2025-05-05 17:29:46,715 - INFO - Key metric: val_mean_dice best value: 0.35056644678115845 at epoch: 1\n", | |
"{'val_mean_dice': 0.35056644678115845}\n" | |
] | |
} | |
], | |
"source": [ | |
"evaluator.run()\n", | |
"print(evaluator.state.metrics)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "0b43bf24", | |
"metadata": {}, | |
"source": [ | |
"The network is untrained, so the metric is poor, as expected. The evaluator can be combined with a `SupervisedTrainer` to \n", | |
"train the network and trigger validation at selected intervals:" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"id": "c2fd932e", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"2025-05-05 17:29:48,048 - INFO - Epoch[1] Metrics -- val_mean_dice: 0.6221 \n", | |
"2025-05-05 17:29:48,049 - INFO - Key metric: val_mean_dice best value: 0.6220798492431641 at epoch: 1\n", | |
"2025-05-05 17:29:49,313 - INFO - Epoch[2] Metrics -- val_mean_dice: 0.7896 \n", | |
"2025-05-05 17:29:49,314 - INFO - Key metric: val_mean_dice best value: 0.7896313071250916 at epoch: 2\n", | |
"2025-05-05 17:29:50,599 - INFO - Epoch[3] Metrics -- val_mean_dice: 0.8213 \n", | |
"2025-05-05 17:29:50,600 - INFO - Key metric: val_mean_dice best value: 0.8212922811508179 at epoch: 3\n", | |
"2025-05-05 17:29:51,862 - INFO - Epoch[4] Metrics -- val_mean_dice: 0.8323 \n", | |
"2025-05-05 17:29:51,863 - INFO - Key metric: val_mean_dice best value: 0.832302451133728 at epoch: 4\n", | |
"2025-05-05 17:29:53,131 - INFO - Epoch[5] Metrics -- val_mean_dice: 0.8390 \n", | |
"2025-05-05 17:29:53,132 - INFO - Key metric: val_mean_dice best value: 0.8390257954597473 at epoch: 5\n", | |
"2025-05-05 17:29:54,404 - INFO - Epoch[6] Metrics -- val_mean_dice: 0.8396 \n", | |
"2025-05-05 17:29:54,405 - INFO - Key metric: val_mean_dice best value: 0.8395887613296509 at epoch: 6\n", | |
"2025-05-05 17:29:55,662 - INFO - Epoch[7] Metrics -- val_mean_dice: 0.8431 \n", | |
"2025-05-05 17:29:55,663 - INFO - Key metric: val_mean_dice best value: 0.8430854678153992 at epoch: 7\n", | |
"2025-05-05 17:29:56,929 - INFO - Epoch[8] Metrics -- val_mean_dice: 0.8467 \n", | |
"2025-05-05 17:29:56,931 - INFO - Key metric: val_mean_dice best value: 0.8466619253158569 at epoch: 8\n", | |
"2025-05-05 17:29:58,198 - INFO - Epoch[9] Metrics -- val_mean_dice: 0.8560 \n", | |
"2025-05-05 17:29:58,199 - INFO - Key metric: val_mean_dice best value: 0.8559762835502625 at epoch: 9\n", | |
"2025-05-05 17:29:59,445 - INFO - Epoch[10] Metrics -- val_mean_dice: 0.8620 \n", | |
"2025-05-05 17:29:59,446 - INFO - Key metric: val_mean_dice best value: 0.8619985580444336 at epoch: 10\n" | |
] | |
} | |
], | |
"source": [ | |
"train_handlers = [\n", | |
" ValidationHandler(validator=evaluator, interval=1, epoch_level=True), # triggers evaluator at selected times\n", | |
"]\n", | |
"trainer = SupervisedTrainer(\n", | |
" device=device,\n", | |
" max_epochs=10,\n", | |
" train_data_loader=train_loader,\n", | |
" network=net,\n", | |
" optimizer=opt,\n", | |
" loss_function=loss,\n", | |
" key_train_metric=None,\n", | |
" train_handlers=train_handlers,\n", | |
")\n", | |
"\n", | |
"trainer.run()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "3b992b7d", | |
"metadata": {}, | |
"source": [ | |
"The `val_metrics` dictionary can contain other metrics keyed to different names. The first one is shown as a\n", | |
"\"key metric\" whose improvement is used by other handlers (e.g. `CheckpointSaver`) to decide when to save the best weights \n", | |
"for the network being trained. Other metrics can be added and these will be included in the log information. \n", | |
"\n", | |
"There are a number of ways of defining a custom metric to be computed. The more difficult way is to define a metric\n", | |
"class and a handler class to make use of it. A second approach is to use `IgniteMetricHandler` to accept a callable\n", | |
"object as a loss function and use this as a decreasing metric. The evaluator can thus be augmented with this handler:" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"id": "9d2f3044", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"2025-05-05 17:29:59,652 - INFO - Epoch[1] Metrics -- l1_loss: 0.0365 val_mean_dice: 0.8620 \n", | |
"2025-05-05 17:29:59,653 - INFO - Key metric: val_mean_dice best value: 0.8619985580444336 at epoch: 1\n" | |
] | |
} | |
], | |
"source": [ | |
"val_metrics = {\n", | |
" \"val_mean_dice\": MeanDice(include_background=True, output_transform=output_transform),\n", | |
" \"l1_loss\": IgniteMetricHandler(loss_fn=torch.nn.functional.l1_loss, output_transform=output_transform),\n", | |
"}\n", | |
"\n", | |
"evaluator = SupervisedEvaluator(\n", | |
" device=device,\n", | |
" val_data_loader=val_loader,\n", | |
" network=net,\n", | |
" postprocessing=val_post_transform,\n", | |
" key_val_metric=val_metrics,\n", | |
" val_handlers=val_handlers,\n", | |
")\n", | |
"\n", | |
"evaluator.run()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "e00baa87", | |
"metadata": {}, | |
"source": [ | |
"Although the `loss_fn` argument is typed as `_Loss`, it can be any callable:" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"id": "04d33550", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"2025-05-05 17:29:59,847 - INFO - Epoch[1] Metrics -- l1_loss: 0.0365 my_l1_loss: 0.0365 val_mean_dice: 0.8620 \n", | |
"2025-05-05 17:29:59,847 - INFO - Key metric: val_mean_dice best value: 0.8619985580444336 at epoch: 1\n" | |
] | |
} | |
], | |
"source": [ | |
"val_metrics = {\n", | |
" \"val_mean_dice\": MeanDice(include_background=True, output_transform=output_transform),\n", | |
" \"l1_loss\": IgniteMetricHandler(loss_fn=torch.nn.functional.l1_loss, output_transform=output_transform),\n", | |
" \"my_l1_loss\": IgniteMetricHandler(\n", | |
" loss_fn=lambda pred, label: pred.sub(label).abs_().mean(dim=[1, 2, 3, 4]), # same as L1 loss above\n", | |
" output_transform=output_transform,\n", | |
" ),\n", | |
"}\n", | |
"\n", | |
"evaluator = SupervisedEvaluator(\n", | |
" device=device,\n", | |
" val_data_loader=val_loader,\n", | |
" network=net,\n", | |
" postprocessing=val_post_transform,\n", | |
" key_val_metric=val_metrics,\n", | |
" val_handlers=val_handlers,\n", | |
")\n", | |
"\n", | |
"evaluator.run()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "ec68b5a5", | |
"metadata": {}, | |
"source": [ | |
"A function with increasing value as the network improves can be used instead of an actual loss function, thus allowing\n", | |
"the `IgniteMetricHandler` to essentially integrate any callable computing a metric into MONAI's engine classes. The \n", | |
"typing and naming aren't strictly accurate, which will be addressed in later versions of the library. In the following\n", | |
"example, a function `equal_voxel_metric` computes the fraction of voxels that match (to within 0.1) between the prediction and target,\n", | |
"which should increase as the network improves:" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"id": "2d991eb7", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"2025-05-05 17:30:01,147 - INFO - Epoch[1] Metrics -- eq_vox_metric: 0.9652 eq_vox_metric_alt: 0.9652 l1_loss: 0.0348 my_l1_loss: 0.0348 val_mean_dice: 0.8678 \n", | |
"2025-05-05 17:30:01,148 - INFO - Key metric: val_mean_dice best value: 0.8677994012832642 at epoch: 1\n", | |
"2025-05-05 17:30:02,435 - INFO - Epoch[2] Metrics -- eq_vox_metric: 0.9669 eq_vox_metric_alt: 0.9669 l1_loss: 0.0331 my_l1_loss: 0.0331 val_mean_dice: 0.8734 \n", | |
"2025-05-05 17:30:02,436 - INFO - Key metric: val_mean_dice best value: 0.8733541369438171 at epoch: 2\n", | |
"2025-05-05 17:30:03,694 - INFO - Epoch[3] Metrics -- eq_vox_metric: 0.9680 eq_vox_metric_alt: 0.9680 l1_loss: 0.0320 my_l1_loss: 0.0320 val_mean_dice: 0.8770 \n", | |
"2025-05-05 17:30:03,695 - INFO - Key metric: val_mean_dice best value: 0.876957893371582 at epoch: 3\n", | |
"2025-05-05 17:30:04,968 - INFO - Epoch[4] Metrics -- eq_vox_metric: 0.9690 eq_vox_metric_alt: 0.9690 l1_loss: 0.0310 my_l1_loss: 0.0310 val_mean_dice: 0.8806 \n", | |
"2025-05-05 17:30:04,969 - INFO - Key metric: val_mean_dice best value: 0.8806190490722656 at epoch: 4\n", | |
"2025-05-05 17:30:06,242 - INFO - Epoch[5] Metrics -- eq_vox_metric: 0.9701 eq_vox_metric_alt: 0.9701 l1_loss: 0.0299 my_l1_loss: 0.0299 val_mean_dice: 0.8845 \n", | |
"2025-05-05 17:30:06,242 - INFO - Key metric: val_mean_dice best value: 0.8844831585884094 at epoch: 5\n", | |
"2025-05-05 17:30:07,539 - INFO - Epoch[6] Metrics -- eq_vox_metric: 0.9709 eq_vox_metric_alt: 0.9709 l1_loss: 0.0291 my_l1_loss: 0.0291 val_mean_dice: 0.8872 \n", | |
"2025-05-05 17:30:07,539 - INFO - Key metric: val_mean_dice best value: 0.8872219920158386 at epoch: 6\n", | |
"2025-05-05 17:30:08,818 - INFO - Epoch[7] Metrics -- eq_vox_metric: 0.9721 eq_vox_metric_alt: 0.9721 l1_loss: 0.0279 my_l1_loss: 0.0279 val_mean_dice: 0.8914 \n", | |
"2025-05-05 17:30:08,819 - INFO - Key metric: val_mean_dice best value: 0.8914018869400024 at epoch: 7\n", | |
"2025-05-05 17:30:10,126 - INFO - Epoch[8] Metrics -- eq_vox_metric: 0.9726 eq_vox_metric_alt: 0.9726 l1_loss: 0.0274 my_l1_loss: 0.0274 val_mean_dice: 0.8933 \n", | |
"2025-05-05 17:30:10,127 - INFO - Key metric: val_mean_dice best value: 0.8933206796646118 at epoch: 8\n", | |
"2025-05-05 17:30:11,379 - INFO - Epoch[9] Metrics -- eq_vox_metric: 0.9734 eq_vox_metric_alt: 0.9734 l1_loss: 0.0266 my_l1_loss: 0.0266 val_mean_dice: 0.8959 \n", | |
"2025-05-05 17:30:11,380 - INFO - Key metric: val_mean_dice best value: 0.8958585858345032 at epoch: 9\n", | |
"2025-05-05 17:30:12,641 - INFO - Epoch[10] Metrics -- eq_vox_metric: 0.9742 eq_vox_metric_alt: 0.9742 l1_loss: 0.0258 my_l1_loss: 0.0258 val_mean_dice: 0.8989 \n", | |
"2025-05-05 17:30:12,642 - INFO - Key metric: val_mean_dice best value: 0.8988854289054871 at epoch: 10\n" | |
] | |
} | |
], | |
"source": [ | |
"def equal_voxel_metric(y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:\n", | |
" diff = y_pred.sub(y).abs_()\n", | |
" return diff.le(0.1).sum(dim=[1, 2, 3, 4]) / y[0].numel()\n", | |
"\n", | |
"\n", | |
"val_metrics = {\n", | |
" \"val_mean_dice\": MeanDice(include_background=True, output_transform=output_transform),\n", | |
" \"l1_loss\": IgniteMetricHandler(loss_fn=torch.nn.functional.l1_loss, output_transform=output_transform),\n", | |
" \"my_l1_loss\": IgniteMetricHandler(\n", | |
" loss_fn=lambda pred, label: pred.sub(label).abs().mean(dim=[1, 2, 3, 4]), # same as L1 loss above\n", | |
" output_transform=output_transform,\n", | |
" ),\n", | |
" \"eq_vox_metric\": IgniteMetricHandler(loss_fn=equal_voxel_metric, output_transform=output_transform),\n", | |
" # functionally identical to the above, just wrapping the function in a metric class and using metric_fn argument\n", | |
" \"eq_vox_metric_alt\": IgniteMetricHandler(\n", | |
" metric_fn=LossMetric(loss_fn=equal_voxel_metric), output_transform=output_transform\n", | |
" ),\n", | |
"}\n", | |
"\n", | |
"evaluator = SupervisedEvaluator(\n", | |
" device=device,\n", | |
" val_data_loader=val_loader,\n", | |
" network=net,\n", | |
" postprocessing=val_post_transform,\n", | |
" key_val_metric=val_metrics,\n", | |
" val_handlers=val_handlers,\n", | |
")\n", | |
"\n", | |
"train_handlers = [\n", | |
" ValidationHandler(validator=evaluator, interval=1, epoch_level=True),\n", | |
"]\n", | |
"trainer = SupervisedTrainer(\n", | |
" device=device,\n", | |
" max_epochs=10,\n", | |
" train_data_loader=train_loader,\n", | |
" network=net,\n", | |
" optimizer=opt,\n", | |
" loss_function=loss,\n", | |
" key_train_metric=None,\n", | |
" train_handlers=train_handlers,\n", | |
")\n", | |
"trainer.run()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "e4106680", | |
"metadata": {}, | |
"source": [ | |
"Metric values are typically reduced using the mean for each validation epoch. To get the score for each item in the validation set instead, the \"reduction\" argument for the metric can be set to \"none\". This produces a tensor of results rather than a single value, which can't be used to choose better weights but can be captured using a `MetricLogger` handler in the trainer. The logger collects the per-iteration loss value during training as well as whatever metrics are produced by running the evaluator, to which it also attaches. If these are tensors of results they can be stacked to produce a matrix of results:" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "b964e855", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"2025-05-05 17:30:13,949 - INFO - Epoch[1] Metrics -- mean_dice_unreduced: tensor([0.8950, 0.9043, 0.9037, 0.9214, 0.8540, 0.9374, 0.9192, 0.9079, 0.8385,\n", | |
" 0.9233])val_mean_dice: 0.9005 \n", | |
"2025-05-05 17:30:13,950 - INFO - Key metric: val_mean_dice best value: 0.9004715085029602 at epoch: 1\n", | |
"2025-05-05 17:30:15,211 - INFO - Epoch[2] Metrics -- mean_dice_unreduced: tensor([0.8971, 0.9062, 0.9056, 0.9229, 0.8568, 0.9384, 0.9208, 0.9097, 0.8414,\n", | |
" 0.9248])val_mean_dice: 0.9024 \n", | |
"2025-05-05 17:30:15,212 - INFO - Key metric: val_mean_dice best value: 0.9023711085319519 at epoch: 2\n", | |
"2025-05-05 17:30:16,470 - INFO - Epoch[3] Metrics -- mean_dice_unreduced: tensor([0.8996, 0.9085, 0.9080, 0.9247, 0.8597, 0.9397, 0.9226, 0.9118, 0.8447,\n", | |
" 0.9264])val_mean_dice: 0.9046 \n", | |
"2025-05-05 17:30:16,471 - INFO - Key metric: val_mean_dice best value: 0.9045838117599487 at epoch: 3\n", | |
"2025-05-05 17:30:17,737 - INFO - Epoch[4] Metrics -- mean_dice_unreduced: tensor([0.9016, 0.9104, 0.9100, 0.9262, 0.8624, 0.9409, 0.9243, 0.9137, 0.8475,\n", | |
" 0.9281])val_mean_dice: 0.9065 \n", | |
"2025-05-05 17:30:17,739 - INFO - Key metric: val_mean_dice best value: 0.9065226316452026 at epoch: 4\n", | |
"2025-05-05 17:30:18,995 - INFO - Epoch[5] Metrics -- mean_dice_unreduced: tensor([0.9036, 0.9123, 0.9117, 0.9277, 0.8651, 0.9419, 0.9258, 0.9155, 0.8504,\n", | |
" 0.9293])val_mean_dice: 0.9083 \n", | |
"2025-05-05 17:30:18,996 - INFO - Key metric: val_mean_dice best value: 0.9083343744277954 at epoch: 5\n", | |
"2025-05-05 17:30:20,238 - INFO - Epoch[6] Metrics -- mean_dice_unreduced: tensor([0.9050, 0.9136, 0.9131, 0.9288, 0.8668, 0.9427, 0.9268, 0.9166, 0.8524,\n", | |
" 0.9303])val_mean_dice: 0.9096 \n", | |
"2025-05-05 17:30:20,238 - INFO - Key metric: val_mean_dice best value: 0.909624457359314 at epoch: 6\n", | |
"2025-05-05 17:30:21,494 - INFO - Epoch[7] Metrics -- mean_dice_unreduced: tensor([0.9062, 0.9148, 0.9141, 0.9297, 0.8687, 0.9435, 0.9277, 0.9178, 0.8543,\n", | |
" 0.9313])val_mean_dice: 0.9108 \n", | |
"2025-05-05 17:30:21,495 - INFO - Key metric: val_mean_dice best value: 0.9108176231384277 at epoch: 7\n", | |
"2025-05-05 17:30:22,756 - INFO - Epoch[8] Metrics -- mean_dice_unreduced: tensor([0.9077, 0.9162, 0.9155, 0.9309, 0.8705, 0.9444, 0.9290, 0.9192, 0.8565,\n", | |
" 0.9325])val_mean_dice: 0.9122 \n", | |
"2025-05-05 17:30:22,757 - INFO - Key metric: val_mean_dice best value: 0.9122360944747925 at epoch: 8\n", | |
"2025-05-05 17:30:24,002 - INFO - Epoch[9] Metrics -- mean_dice_unreduced: tensor([0.9084, 0.9170, 0.9162, 0.9316, 0.8717, 0.9450, 0.9297, 0.9199, 0.8577,\n", | |
" 0.9328])val_mean_dice: 0.9130 \n", | |
"2025-05-05 17:30:24,003 - INFO - Key metric: val_mean_dice best value: 0.9129890203475952 at epoch: 9\n", | |
"2025-05-05 17:30:25,268 - INFO - Epoch[10] Metrics -- mean_dice_unreduced: tensor([0.9104, 0.9187, 0.9180, 0.9331, 0.8740, 0.9461, 0.9312, 0.9217, 0.8601,\n", | |
" 0.9344])val_mean_dice: 0.9148 \n", | |
"2025-05-05 17:30:25,269 - INFO - Key metric: val_mean_dice best value: 0.9147800207138062 at epoch: 10\n", | |
"(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80)\n", | |
"80 tensor([0.6078, 0.5158, 0.5304, 0.5080, 0.5131, 0.5352, 0.4862, 0.5256, 0.5391,\n", | |
" 0.5270, 0.5423, 0.5704, 0.4621, 0.5293, 0.5144, 0.5249, 0.5524, 0.5583,\n", | |
" 0.4687, 0.5701, 0.5122, 0.5322, 0.4755, 0.5272, 0.4996, 0.5954, 0.5391,\n", | |
" 0.4912, 0.5403, 0.4943, 0.5372, 0.4869, 0.5045, 0.5060, 0.5126, 0.5200,\n", | |
" 0.5702, 0.5663, 0.5033, 0.4887, 0.5495, 0.4799, 0.5893, 0.4794, 0.5138,\n", | |
" 0.4992, 0.5294, 0.5185, 0.4669, 0.5457, 0.5104, 0.5062, 0.5591, 0.4863,\n", | |
" 0.5724, 0.4996, 0.5101, 0.5532, 0.5191, 0.5161, 0.5185, 0.4770, 0.5111,\n", | |
" 0.5288, 0.5263, 0.4828, 0.5375, 0.5140, 0.4622, 0.5711, 0.4987, 0.5289,\n", | |
" 0.4907, 0.5026, 0.5316, 0.5341, 0.4872, 0.5007, 0.5383, 0.5240])\n", | |
"(8, 16, 24, 32, 40, 48, 56, 64, 72, 80)\n", | |
"10 tensor([0.9005, 0.9024, 0.9046, 0.9065, 0.9083, 0.9096, 0.9108, 0.9122, 0.9130,\n", | |
" 0.9148])\n", | |
"10 tensor([[0.8950, 0.9043, 0.9037, 0.9214, 0.8540, 0.9374, 0.9192, 0.9079, 0.8385,\n", | |
" 0.9233],\n", | |
" [0.8971, 0.9062, 0.9056, 0.9229, 0.8568, 0.9384, 0.9208, 0.9097, 0.8414,\n", | |
" 0.9248],\n", | |
" [0.8996, 0.9085, 0.9080, 0.9247, 0.8597, 0.9397, 0.9226, 0.9118, 0.8447,\n", | |
" 0.9264],\n", | |
" [0.9016, 0.9104, 0.9100, 0.9262, 0.8624, 0.9409, 0.9243, 0.9137, 0.8475,\n", | |
" 0.9281],\n", | |
" [0.9036, 0.9123, 0.9117, 0.9277, 0.8651, 0.9419, 0.9258, 0.9155, 0.8504,\n", | |
" 0.9293],\n", | |
" [0.9050, 0.9136, 0.9131, 0.9288, 0.8668, 0.9427, 0.9268, 0.9166, 0.8524,\n", | |
" 0.9303],\n", | |
" [0.9062, 0.9148, 0.9141, 0.9297, 0.8687, 0.9435, 0.9277, 0.9178, 0.8543,\n", | |
" 0.9313],\n", | |
" [0.9077, 0.9162, 0.9155, 0.9309, 0.8705, 0.9444, 0.9290, 0.9192, 0.8565,\n", | |
" 0.9325],\n", | |
" [0.9084, 0.9170, 0.9162, 0.9316, 0.8717, 0.9450, 0.9297, 0.9199, 0.8577,\n", | |
" 0.9328],\n", | |
" [0.9104, 0.9187, 0.9180, 0.9331, 0.8740, 0.9461, 0.9312, 0.9217, 0.8601,\n", | |
" 0.9344]])\n" | |
] | |
} | |
], | |
"source": [ | |
"val_metrics = {\n", | |
" \"val_mean_dice\": MeanDice(include_background=True, output_transform=output_transform),\n", | |
" \"mean_dice_unreduced\": MeanDice(include_background=True, output_transform=output_transform, reduction=\"none\"),\n", | |
"}\n", | |
"\n", | |
"evaluator = SupervisedEvaluator(\n", | |
" device=device,\n", | |
" val_data_loader=val_loader,\n", | |
" network=net,\n", | |
" postprocessing=val_post_transform,\n", | |
" key_val_metric=val_metrics,\n", | |
" val_handlers=val_handlers,\n", | |
")\n", | |
"\n", | |
"train_handlers = [\n", | |
" MetricLogger(evaluator=evaluator),\n", | |
" ValidationHandler(validator=evaluator, interval=1, epoch_level=True),\n", | |
"]\n", | |
"trainer = SupervisedTrainer(\n", | |
" device=device,\n", | |
" max_epochs=10,\n", | |
" train_data_loader=train_loader,\n", | |
" network=net,\n", | |
" optimizer=opt,\n", | |
" loss_function=loss,\n", | |
" key_train_metric=None,\n", | |
" train_handlers=train_handlers,\n", | |
")\n", | |
"\n", | |
"trainer.run()\n", | |
"\n", | |
"log_state = train_handlers[0].state_dict()\n", | |
"iteration, per_iter_losses = zip(*log_state[MetricLoggerKeys.LOSS]) # items are (iteration, value) pairs\n", | |
"viteration, per_epoch_dice = zip(*log_state[MetricLoggerKeys.METRICS][\"val_mean_dice\"])\n", | |
"viteration, per_epoch_dice_unreduced = zip(*log_state[MetricLoggerKeys.METRICS][\"mean_dice_unreduced\"])\n", | |
"\n", | |
"print(iteration)\n", | |
"print(len(per_iter_losses), torch.as_tensor(per_iter_losses))\n", | |
"print(viteration)\n", | |
"print(len(per_epoch_dice), torch.as_tensor(per_epoch_dice))\n", | |
"print(len(per_epoch_dice_unreduced), torch.stack(per_epoch_dice_unreduced))" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "monai", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.10.14" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
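
The notebook describes metric classes as computing values per item, storing them, and aggregating them across iterations or epochs, with a "reduction" setting that selects between a single mean and the per-item results. The following is a minimal, framework-free sketch of that pattern. `ToyMetric`, its `update`/`aggregate`/`reset` methods, and `abs_error` are hypothetical names chosen for illustration; this is not MONAI's implementation and it ignores tensors, devices, and multi-process aggregation.

```python
from statistics import mean
from typing import Callable, List, Sequence, Union


class ToyMetric:
    """Hypothetical stand-in for a metric class: compute per item, store, aggregate."""

    def __init__(self, fn: Callable[[float, float], float], reduction: str = "mean") -> None:
        if reduction not in ("mean", "none"):
            raise ValueError(f"unsupported reduction: {reduction}")
        self.fn = fn
        self.reduction = reduction
        self._buffer: List[float] = []  # per-item scores gathered during the epoch

    def update(self, preds: Sequence[float], labels: Sequence[float]) -> None:
        # A handler would call this once per iteration with one batch of items.
        self._buffer.extend(self.fn(p, y) for p, y in zip(preds, labels))

    def aggregate(self) -> Union[float, List[float]]:
        # A handler would call this when the epoch finishes.
        if self.reduction == "mean":
            return mean(self._buffer)
        return list(self._buffer)

    def reset(self) -> None:
        # Clear stored scores before the next epoch.
        self._buffer.clear()


def abs_error(pred: float, label: float) -> float:
    return abs(pred - label)


# Two "iterations", each a batch of (predictions, labels).
batches = [([0.0, 1.0], [0.0, 0.5]), ([1.0, 1.0], [1.0, 0.0])]

reduced = ToyMetric(abs_error)                     # one value per epoch
per_item = ToyMetric(abs_error, reduction="none")  # one value per item

for preds, labels in batches:
    reduced.update(preds, labels)
    per_item.update(preds, labels)

print(reduced.aggregate())
print(per_item.aggregate())
```

Only the reduced form yields a single number that can be compared between epochs, which is why the per-item variant is suited to logging rather than to selecting the best weights.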