Skip to content

Instantly share code, notes, and snippets.

@wxyang007
Forked from calebrob6/LEVIR-CD+ example.ipynb
Created August 13, 2024 05:11
Show Gist options
  • Save wxyang007/222937d48a712e8b33e96e92e21146cc to your computer and use it in GitHub Desktop.
Save wxyang007/222937d48a712e8b33e96e92e21146cc to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"id": "6a5b65df",
"metadata": {},
"source": [
"# LEVIR-CD+ change detection example notebook\n",
"\n",
"We start off by installing torchgeo. If you are running this on Colab, then you will need to restart your runtime after this step."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4627b902",
"metadata": {},
"outputs": [],
"source": [
"# Pinned to the release this notebook was last run with (see the version cell below);\n",
"# %pip (rather than !pip) installs into the environment of the running kernel.\n",
"%pip install torchgeo==0.5.1"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "475f3715",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"import torchgeo\n",
"from torchgeo.datasets import LEVIRCDPlus\n",
"from torchgeo.datasets.utils import unbind_samples\n",
"from torchgeo.trainers import SemanticSegmentationTask\n",
"from torchgeo.datamodules.utils import dataset_split\n",
"\n",
"import lightning.pytorch as pl\n",
"from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint\n",
"from lightning.pytorch import Trainer, seed_everything\n",
"from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger\n",
"from lightning.pytorch import LightningDataModule\n",
"\n",
"import torch\n",
"from torch.utils.data import DataLoader\n",
"import kornia.augmentation as K\n",
"\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import torchvision\n",
"from torchvision.transforms import Compose\n",
"from tqdm import tqdm\n",
"\n",
"from sklearn.metrics import precision_score, recall_score"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "2ae75c6f",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"('0.5.1', '2.1.3', '2.0.1+cu117')"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"torchgeo.__version__, pl.__version__, torch.__version__"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "daedd8ce",
"metadata": {},
"outputs": [],
"source": [
"torch.cuda.is_available()"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "0b012728",
"metadata": {},
"outputs": [],
"source": [
"# some experiment parameters\n",
"\n",
"experiment_name = \"experiment_test\"\n",
"experiment_dir = f\"results/{experiment_name}\"\n",
"os.makedirs(experiment_dir, exist_ok=True)\n",
"\n",
"# Seed Python, NumPy and PyTorch once so that the random train/val split and the\n",
"# random augmentations are repeatable after a kernel restart.\n",
"seed = 0\n",
"seed_everything(seed, workers=True)\n",
"\n",
"batch_size = 8\n",
"learning_rate = 0.0001\n",
"gpu_id = 0\n",
"# Fall back to the CPU so the manual evaluation cells can still run without CUDA\n",
"device = torch.device(f\"cuda:{gpu_id}\" if torch.cuda.is_available() else \"cpu\")\n",
"num_dataloader_workers = 2\n",
"patch_size = 256\n",
"val_split_pct = 0.1 # how much of our training set to hold out as a validation set"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "ca211445",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Files already downloaded and verified\n",
"Files already downloaded and verified\n"
]
},
{
"data": {
"text/plain": [
"(637, 348)"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Fetch (and checksum) both LEVIR-CD+ splits, then report how many samples each one holds\n",
"\n",
"train_dataset, test_dataset = [\n",
"    LEVIRCDPlus(root=\"data/LEVIRCDPlus\", split=split_name, download=True, checksum=True)\n",
"    for split_name in (\"train\", \"test\")\n",
"]\n",
"len(train_dataset), len(test_dataset)"
]
},
{
"cell_type": "markdown",
"id": "8d7e6981",
"metadata": {},
"source": [
"## Exercise 1\n",
"\n",
"Plot some examples from the `train_dataset` (note: torchgeo will help you out here)."
]
},
{
"cell_type": "markdown",
"id": "8127d129",
"metadata": {},
"source": [
"## Define a PyTorch Lightning module and datamodule\n",
"\n",
"PyTorch Lightning organizes the steps required for training deep learning models in `LightningModules`, and organizes dataset handling and dataloader creation in `LightningDataModules`. TorchGeo provides pre-built LightningDataModules for a handful of datasets, and pre-built \"trainers\" (i.e. LightningModules) for a variety of different types of tasks.\n",
"\n",
"For this tutorial, we will lightly extend TorchGeo's `SemanticSegmentationTask` (just to add some custom plotting code) and create a new LightningDataModule for the LEVIR-CD+ dataset."
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "26f62ac5",
"metadata": {},
"outputs": [],
"source": [
"class CustomSemanticSegmentationTask(SemanticSegmentationTask):\n",
"    \"\"\"SemanticSegmentationTask variant that logs its own example figures.\n",
"\n",
"    The training and validation steps follow the parent task's steps; the only\n",
"    intended difference is that example figures are rendered with `plot` below\n",
"    (image pair, mask and prediction side by side).\n",
"    \"\"\"\n",
"\n",
"    # Only the first batches of every epoch are rendered, to keep logging cheap.\n",
"    num_plotted_batches = 10\n",
"\n",
"    def plot(self, sample):\n",
"        \"\"\"Return a 1x4 figure: image 1, image 2, ground-truth mask, prediction.\n",
"\n",
"        Args:\n",
"            sample: dict with an `image` tensor whose first three channels are the\n",
"                first image and whose remaining channels are the second image,\n",
"                plus `mask` and `prediction` tensors that `imshow` can draw\n",
"                directly. All tensors are expected to live on the CPU.\n",
"\n",
"        Returns:\n",
"            The matplotlib figure (the caller is responsible for closing it).\n",
"        \"\"\"\n",
"        panels = [\n",
"            (\"Image 1\", sample[\"image\"][:3].permute(1, 2, 0)),\n",
"            (\"Image 2\", sample[\"image\"][3:].permute(1, 2, 0)),\n",
"            (\"Mask\", sample[\"mask\"]),\n",
"            (\"Prediction\", sample[\"prediction\"]),\n",
"        ]\n",
"\n",
"        fig, axs = plt.subplots(nrows=1, ncols=4, figsize=(4 * 5, 5))\n",
"        for ax, (title, data) in zip(axs, panels):\n",
"            ax.imshow(data)\n",
"            ax.axis(\"off\")\n",
"            ax.set_title(title)\n",
"\n",
"        fig.tight_layout()\n",
"        return fig\n",
"\n",
"    def _log_example_figure(self, split, batch_idx, x, y, y_hat_hard):\n",
"        \"\"\"Plot the first sample of an early batch and hand it to the logger.\"\"\"\n",
"        if batch_idx >= self.num_plotted_batches:\n",
"            return\n",
"\n",
"        # Skip quietly when there is no logger, or one that cannot take figures,\n",
"        # instead of failing the whole step.\n",
"        experiment = getattr(self.logger, \"experiment\", None)\n",
"        if not hasattr(experiment, \"add_figure\"):\n",
"            return\n",
"\n",
"        # Copy only the plotted sample to the CPU and leave `batch` untouched.\n",
"        sample = {\n",
"            \"image\": x[0].detach().cpu(),\n",
"            \"mask\": y[0].detach().cpu(),\n",
"            \"prediction\": y_hat_hard[0].detach().cpu(),\n",
"        }\n",
"        fig = self.plot(sample)\n",
"        experiment.add_figure(\n",
"            f\"image/{split}/{batch_idx}\", fig, global_step=self.global_step\n",
"        )\n",
"        # Close this exact figure rather than whichever one happens to be current\n",
"        plt.close(fig)\n",
"\n",
"    def training_step(self, *args, **kwargs):\n",
"        \"\"\"Run one training step and return the loss.\n",
"\n",
"        Expects `args[0]` to be the batch dict (`image`, `mask`) and `args[1]`\n",
"        the batch index. Logs `train_loss` per step, updates and logs the train\n",
"        metrics, and logs an example figure for the first batches.\n",
"        \"\"\"\n",
"        batch, batch_idx = args[0], args[1]\n",
"        x = batch[\"image\"]\n",
"        y = batch[\"mask\"]\n",
"        y_hat = self(x)\n",
"        y_hat_hard = y_hat.argmax(dim=1)\n",
"\n",
"        loss = self.criterion(y_hat, y)\n",
"\n",
"        self.log(\"train_loss\", loss, on_step=True, on_epoch=False)\n",
"        self.train_metrics(y_hat_hard, y)\n",
"        # Without this the metrics are updated but never reported anywhere\n",
"        self.log_dict(self.train_metrics)\n",
"\n",
"        self._log_example_figure(\"train\", batch_idx, x, y, y_hat_hard)\n",
"\n",
"        return loss\n",
"\n",
"    def validation_step(self, *args, **kwargs):\n",
"        \"\"\"Run one validation step.\n",
"\n",
"        Expects `args[0]` to be the batch dict (`image`, `mask`) and `args[1]`\n",
"        the batch index. Logs the epoch-level `val_loss`, updates and logs the\n",
"        validation metrics, and logs an example figure for the first batches.\n",
"        \"\"\"\n",
"        batch, batch_idx = args[0], args[1]\n",
"        x = batch[\"image\"]\n",
"        y = batch[\"mask\"]\n",
"        y_hat = self(x)\n",
"        y_hat_hard = y_hat.argmax(dim=1)\n",
"\n",
"        loss = self.criterion(y_hat, y)\n",
"\n",
"        self.log(\"val_loss\", loss, on_step=False, on_epoch=True)\n",
"        self.val_metrics(y_hat_hard, y)\n",
"        # Without this the metrics are updated but never reported anywhere\n",
"        self.log_dict(self.val_metrics)\n",
"\n",
"        self._log_example_figure(\"val\", batch_idx, x, y, y_hat_hard)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "f420887f",
"metadata": {},
"outputs": [],
"source": [
"class LEVIRCDPlusDataModule(pl.LightningDataModule):\n",
"    \"\"\"LightningDataModule wrapping the LEVIR-CD+ dataset.\n",
"\n",
"    Args:\n",
"        batch_size: number of samples per batch in every dataloader.\n",
"        num_workers: number of worker processes used by every dataloader.\n",
"        val_split_pct: fraction of the train split held out for validation. With\n",
"            0 the full train split is reused as the validation set.\n",
"        patch_size: size of the random crop applied to training batches.\n",
"        **kwargs: forwarded unchanged to every `LEVIRCDPlus` constructor\n",
"            (e.g. `root`).\n",
"    \"\"\"\n",
"\n",
"    def __init__(\n",
"        self,\n",
"        batch_size=32,\n",
"        num_workers=0,\n",
"        val_split_pct=0.2,\n",
"        patch_size=(256, 256),\n",
"        **kwargs,\n",
"    ):\n",
"        super().__init__()\n",
"        self.batch_size = batch_size\n",
"        self.num_workers = num_workers\n",
"        self.val_split_pct = val_split_pct\n",
"        self.patch_size = patch_size\n",
"        self.kwargs = kwargs\n",
"\n",
"        # Built once here instead of being re-created for every training batch.\n",
"        self.train_augmentations = K.AugmentationSequential(\n",
"            K.RandomRotation(p=0.5, degrees=90),\n",
"            K.RandomHorizontalFlip(p=0.5),\n",
"            K.RandomVerticalFlip(p=0.5),\n",
"            K.RandomCrop(self.patch_size),\n",
"            K.RandomSharpness(p=0.5),\n",
"            data_keys=[\"input\", \"mask\"],\n",
"        )\n",
"\n",
"    def on_after_batch_transfer(\n",
"        self, batch, batch_idx\n",
"    ):\n",
"        \"\"\"Augment training batches after they were moved to the target device.\n",
"\n",
"        Batches are returned unchanged unless an attached trainer is currently\n",
"        training. NOTE(review): Lightning documents the second argument of this\n",
"        hook as the dataloader index; it is unused here and the parameter name is\n",
"        kept as-is for compatibility.\n",
"        \"\"\"\n",
"        trainer = getattr(self, \"trainer\", None)\n",
"        if trainer is None or not getattr(trainer, \"training\", False):\n",
"            return batch\n",
"\n",
"        # Kornia expects masks to be floats with a channel dimension\n",
"        x = batch[\"image\"]\n",
"        y = batch[\"mask\"].float().unsqueeze(1)\n",
"\n",
"        x, y = self.train_augmentations(x, y)\n",
"\n",
"        # torchmetrics expects masks to be longs without a channel dimension\n",
"        batch[\"image\"] = x\n",
"        batch[\"mask\"] = y.squeeze(1).long()\n",
"\n",
"        return batch\n",
"\n",
"    def preprocess(self, sample):\n",
"        \"\"\"Per-sample transform applied by the datasets.\n",
"\n",
"        Divides the image by 255 and casts it to float, merges its first two\n",
"        dimensions (presumably the two stacked images and their channels, so the\n",
"        pair becomes one multi-channel image - confirm against the dataset), and\n",
"        casts the mask to long.\n",
"        \"\"\"\n",
"        sample[\"image\"] = (sample[\"image\"] / 255.0).float()\n",
"        sample[\"image\"] = torch.flatten(sample[\"image\"], 0, 1)\n",
"        sample[\"mask\"] = sample[\"mask\"].long()\n",
"        return sample\n",
"\n",
"    def prepare_data(self):\n",
"        \"\"\"Instantiate the train split once with the user-supplied kwargs.\"\"\"\n",
"        LEVIRCDPlus(split=\"train\", **self.kwargs)\n",
"\n",
"    def setup(self, stage=None):\n",
"        \"\"\"Create the train, validation and test datasets (`stage` is ignored).\"\"\"\n",
"        train_transforms = Compose([self.preprocess])\n",
"        test_transforms = Compose([self.preprocess])\n",
"\n",
"        train_dataset = LEVIRCDPlus(\n",
"            split=\"train\", transforms=train_transforms, **self.kwargs\n",
"        )\n",
"\n",
"        if self.val_split_pct > 0.0:\n",
"            self.train_dataset, self.val_dataset, _ = dataset_split(\n",
"                train_dataset, val_pct=self.val_split_pct, test_pct=0.0\n",
"            )\n",
"        else:\n",
"            # No hold-out requested: validate on the training data itself\n",
"            self.train_dataset = train_dataset\n",
"            self.val_dataset = train_dataset\n",
"\n",
"        self.test_dataset = LEVIRCDPlus(\n",
"            split=\"test\", transforms=test_transforms, **self.kwargs\n",
"        )\n",
"\n",
"    def _make_dataloader(self, dataset, shuffle):\n",
"        \"\"\"Build a DataLoader with the shared batch size and worker count.\"\"\"\n",
"        return DataLoader(\n",
"            dataset,\n",
"            batch_size=self.batch_size,\n",
"            num_workers=self.num_workers,\n",
"            shuffle=shuffle,\n",
"        )\n",
"\n",
"    def train_dataloader(self):\n",
"        \"\"\"Shuffled loader over the training dataset.\"\"\"\n",
"        return self._make_dataloader(self.train_dataset, shuffle=True)\n",
"\n",
"    def val_dataloader(self):\n",
"        \"\"\"Unshuffled loader over the validation dataset.\"\"\"\n",
"        return self._make_dataloader(self.val_dataset, shuffle=False)\n",
"\n",
"    def test_dataloader(self):\n",
"        \"\"\"Unshuffled loader over the test dataset.\"\"\"\n",
"        return self._make_dataloader(self.test_dataset, shuffle=False)"
]
},
{
"cell_type": "markdown",
"id": "d221e5db",
"metadata": {},
"source": [
"## Setting up a training run"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "97a5ff80",
"metadata": {},
"outputs": [],
"source": [
"# Loader settings come from the experiment parameters; `root` is forwarded to LEVIRCDPlus\n",
"datamodule = LEVIRCDPlusDataModule(\n",
"    batch_size=batch_size,\n",
"    num_workers=num_dataloader_workers,\n",
"    val_split_pct=val_split_pct,\n",
"    patch_size=(patch_size, patch_size),\n",
"    root=\"data/LEVIRCDPlus\",\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "82b472f5",
"metadata": {},
"outputs": [],
"source": [
"# TensorBoard run directory: logs/<experiment_name>\n",
"tb_logger = TensorBoardLogger(save_dir=\"logs/\", name=experiment_name)\n",
"\n",
"# Both callbacks watch the validation loss logged by the task\n",
"early_stopping_callback = EarlyStopping(monitor=\"val_loss\", min_delta=0.00, patience=10)\n",
"checkpoint_callback = ModelCheckpoint(\n",
"    dirpath=experiment_dir,\n",
"    monitor=\"val_loss\",\n",
"    save_top_k=1,\n",
"    save_last=True,\n",
")\n",
"\n",
"# 6 input channels (two stacked RGB images), 2 output classes\n",
"task = CustomSemanticSegmentationTask(\n",
"    model=\"unet\",\n",
"    backbone=\"resnet18\",\n",
"    weights=True,\n",
"    in_channels=6,\n",
"    num_classes=2,\n",
"    loss=\"ce\",\n",
"    ignore_index=None,\n",
"    lr=learning_rate,\n",
"    patience=10,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e54642fd",
"metadata": {},
"outputs": [],
"source": [
"%load_ext tensorboard"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "94fe9c6d",
"metadata": {},
"outputs": [],
"source": [
"%tensorboard --logdir logs/"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6fc5259c",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"GPU available: True (cuda), used: True\n",
"TPU available: False, using: 0 TPU cores\n",
"IPU available: False, using: 0 IPUs\n",
"HPU available: False, using: 0 HPUs\n",
"LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]\n",
"\n",
" | Name | Type | Params\n",
"---------------------------------------------------\n",
"0 | model | Unet | 14.3 M\n",
"1 | loss | CrossEntropyLoss | 0 \n",
"2 | train_metrics | MetricCollection | 0 \n",
"3 | val_metrics | MetricCollection | 0 \n",
"4 | test_metrics | MetricCollection | 0 \n",
"---------------------------------------------------\n",
"14.3 M Trainable params\n",
"0 Non-trainable params\n",
"14.3 M Total params\n",
"57.351 Total estimated model params size (MB)\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Sanity Checking: 0it [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/calebrobinson/.conda/envs/test/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:1892: PossibleUserWarning: The number of training batches (18) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.\n",
" rank_zero_warn(\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "a15abdbc468b44d0b1a43a18e285ae95",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Training: 0it [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Validation: 0it [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Validation: 0it [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Validation: 0it [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Validation: 0it [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Validation: 0it [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Validation: 0it [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Validation: 0it [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Validation: 0it [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Validation: 0it [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Validation: 0it [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Validation: 0it [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Validation: 0it [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Validation: 0it [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Validation: 0it [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Validation: 0it [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Validation: 0it [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Validation: 0it [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Validation: 0it [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Validation: 0it [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Validation: 0it [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Validation: 0it [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Validation: 0it [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Validation: 0it [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Validation: 0it [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Validation: 0it [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Validation: 0it [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Validation: 0it [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Validation: 0it [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Validation: 0it [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Validation: 0it [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Validation: 0it [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Validation: 0it [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Validation: 0it [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Validation: 0it [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "f2c8514f2d4a41e09b3bc88af1b40887",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Validation: 0it [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Single-GPU run with checkpointing, early stopping and TensorBoard logging\n",
"trainer = pl.Trainer(\n",
"    accelerator=\"gpu\",\n",
"    devices=[gpu_id],\n",
"    min_epochs=10,\n",
"    max_epochs=200,\n",
"    logger=[tb_logger],\n",
"    callbacks=[checkpoint_callback, early_stopping_callback],\n",
"    default_root_dir=experiment_dir,\n",
")\n",
"\n",
"_ = trainer.fit(model=task, datamodule=datamodule)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2cfacd81",
"metadata": {},
"outputs": [],
"source": [
"trainer.test(model=task, datamodule=datamodule)"
]
},
{
"cell_type": "markdown",
"id": "346e4afe",
"metadata": {},
"source": [
"## Custom test step to compute the precision, recall, and F1 metrics"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "b61db9fb",
"metadata": {},
"outputs": [],
"source": [
"# Example of how to load a trained task from a checkpoint file\n",
"# task = CustomSemanticSegmentationTask.load_from_checkpoint(\"results/...\")\n",
"# datamodule.setup(\"test\")"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "c9b7a93c",
"metadata": {},
"outputs": [],
"source": [
"model = task.model.to(device).eval()"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "0e545e06",
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 44/44 [00:21<00:00, 2.04it/s]\n"
]
}
],
"source": [
"# Gather labels and predictions for every 500th pixel of each test batch\n",
"y_preds, y_trues = [], []\n",
"for batch in tqdm(datamodule.test_dataloader()):\n",
"    y_trues.append(batch[\"mask\"].numpy().ravel()[::500])\n",
"\n",
"    images = batch[\"image\"].to(device)\n",
"    with torch.inference_mode():\n",
"        logits = model(images)\n",
"    y_pred = logits.argmax(dim=1).cpu().numpy().ravel()[::500]\n",
"    y_preds.append(y_pred)\n",
"\n",
"y_preds = np.concatenate(y_preds)\n",
"y_trues = np.concatenate(y_trues)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "8b5a6975",
"metadata": {},
"outputs": [],
"source": [
"precision = precision_score(y_trues, y_preds)\n",
"recall = recall_score(y_trues, y_preds)\n",
"# F1 is the harmonic mean of precision and recall; report 0.0 rather than NaN\n",
"# when both are zero (e.g. the model never predicts the positive class).\n",
"pr_sum = precision + recall\n",
"f1 = 2 * (precision * recall) / pr_sum if pr_sum > 0 else 0.0"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "bf25b1d4",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(0.7234695667426767, 0.5552638664512655, 0.6283037550460812)"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"precision, recall, f1"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "python3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment