Created
April 14, 2023 22:38
-
-
Save calebrob6/983f90c3c752d2cc54e659a16ba01303 to your computer and use it in GitHub Desktop.
An example notebook that shows how to do SSL pretraining with the lightly library and TorchGeo
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import lightning.pytorch as pl\n", | |
"from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint\n", | |
"from lightning.pytorch.loggers import TensorBoardLogger\n", | |
"\n", | |
"import kornia.augmentation as K\n", | |
"\n", | |
"import torch\n", | |
"import torchvision\n", | |
"from torch import nn\n", | |
"import torch.nn.functional as F\n", | |
"from torch.optim.lr_scheduler import ReduceLROnPlateau\n", | |
"from torchgeo.datasets import RESISC45\n", | |
"import matplotlib.pyplot as plt\n", | |
"\n", | |
"from lightly.data import LightlyDataset\n", | |
"from lightly.data.multi_view_collate import MultiViewCollate\n", | |
"from lightly.loss import NegativeCosineSimilarity\n", | |
"from lightly.models.modules import SimSiamPredictionHead, SimSiamProjectionHead\n", | |
"\n", | |
"from torchmetrics import MeanMetric\n", | |
"\n", | |
"from torchvision.models.feature_extraction import create_feature_extractor" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
class TorchgeoLightlyDataset(LightlyDataset):
    """Wrapper class to adapt TorchGeo style datasets to the format lightly expects.

    The wrapped dataset is indexed for dict samples with ``"image"`` and
    ``"label"`` keys; this wrapper re-packs each one as an
    ``(image, label, filename)`` tuple, using the stringified index as the
    filename.

    Args:
        dataset: the dataset to wrap. Its samples must be dicts containing
            ``"image"`` and ``"label"``.
        transform: optional callable assigned to the ``transforms`` attribute
            of ``dataset``. If ``dataset`` is a ``ConcatDataset`` the callable
            is also assigned to every dataset it concatenates (recursively),
            because an attribute set only on the concatenation would not be
            visible to the child datasets that actually load the samples.
    """

    def __init__(self, dataset, transform=None):
        # NOTE(review): presumably passing ``None`` as the input directory
        # makes lightly skip its own folder-based loading -- confirm against
        # the installed lightly version.
        super().__init__(None, index_to_filename=None)
        self.dataset = dataset
        if transform is not None:
            self._set_transforms(self.dataset, transform)

    @staticmethod
    def _set_transforms(dataset, transform):
        """Assign ``transform`` to ``dataset`` and to any datasets it concatenates."""
        # NOTE: in TorchGeo we use `dataset.transforms`, not `.transform`
        dataset.transforms = transform
        if isinstance(dataset, torch.utils.data.ConcatDataset):
            for child in dataset.datasets:
                TorchgeoLightlyDataset._set_transforms(child, transform)

    def __getitem__(self, index: int):
        """Return ``(image, label, filename)`` for the sample at ``index``."""
        sample = self.dataset[index]
        # There is no real file name for these samples; the index stands in for it.
        return sample["image"], sample["label"], str(index)
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
class ResnetFeatureExtractorWrapper(nn.Module):
    """ResNet-18 wrapper that returns the ``avgpool`` activations as flat features.

    ``self.model`` holds the torchvision ResNet-18 and ``self.backbone`` is a
    feature extractor built from it that exposes the ``"avgpool"`` node.
    """

    def __init__(self, *args, **kwargs):
        # ``*args`` / ``**kwargs`` are accepted for call-site compatibility but
        # are not used.
        super().__init__()
        self.model = torchvision.models.resnet18()
        self.backbone = create_feature_extractor(self.model, return_nodes=["avgpool"])

    def forward(self, x):
        """Return one flat feature vector per input image.

        Args:
            x: batch of images accepted by the ResNet-18 backbone.

        Returns:
            A 2-D tensor whose first dimension is the batch dimension
            (expected to be ``(batch, 512)`` for ResNet-18 -- TODO confirm).
        """
        pooled = self.backbone(x)["avgpool"]
        # flatten(start_dim=1) always keeps the batch dimension. A bare
        # squeeze() would also remove it for a batch of one, which breaks
        # callers that subsequently flatten from dim 1.
        return pooled.flatten(start_dim=1)
"\n", | |
class FastSiam(pl.LightningModule):
    """Multi-view Siamese self-supervised model built on SimSiam heads.

    Each view's prediction is compared (negative cosine similarity) against
    the mean of the detached projections of all *other* views.

    Args:
        backbone: module mapping an image batch to per-image features; its
            output is flattened from dim 1 and fed to a projection head that
            takes 512 input features.
        learning_rate: learning rate passed to AdamW.
        learning_rate_schedule_patience: ``patience`` passed to
            ``ReduceLROnPlateau``.
    """

    def __init__(self, backbone, learning_rate=0.001, learning_rate_schedule_patience=10):
        super().__init__()

        # The backbone is a module rather than a scalar setting, so it is
        # excluded from the stored hyperparameters.
        self.save_hyperparameters(ignore=['backbone'])

        self.backbone = backbone
        self.projection_head = SimSiamProjectionHead(512, 512, 128)
        self.prediction_head = SimSiamPredictionHead(128, 64, 128)
        self.criterion = NegativeCosineSimilarity()
        # Running (exponential moving) average of the prediction std,
        # updated on every training step.
        self.avg_output_std = 0

    def forward(self, x):
        """Return ``(projection, prediction)`` for a batch; the projection is detached."""
        features = self.backbone(x).flatten(start_dim=1)
        projection = self.projection_head(features)
        prediction = self.prediction_head(projection)
        # The projection only ever serves as a target, so it carries no gradient.
        return projection.detach(), prediction

    def training_step(self, batch, batch_idx):
        """Compute the multi-view loss and log it with the output-std statistic."""
        views, _, _ = batch
        num_views = len(views)

        projections, predictions = [], []
        for view in views:
            projection, prediction = self.forward(view)
            projections.append(projection)
            predictions.append(prediction)
        zs = torch.stack(projections)
        ps = torch.stack(predictions)

        view_ids = torch.arange(num_views, device=self.device)
        loss = 0.0
        for i in range(num_views):
            # Target for view i: mean projection over every other view.
            others = view_ids != i
            target = zs[others].mean(dim=0)
            loss = loss + self.criterion(ps[i], target) / num_views

        # Per-dimension std of the L2-normalised predictions of the first
        # view, averaged over dimensions (presumably a collapse diagnostic).
        normalized = F.normalize(ps[0].detach(), dim=1)
        output_std = torch.std(normalized, 0).mean()

        w = 0.9
        self.avg_output_std = w * self.avg_output_std + (1 - w) * output_std.item()

        self.log("train_avg_output_std", self.avg_output_std)
        self.log("train_loss_ssl", loss)
        return loss

    def configure_optimizers(self):
        """Return AdamW plus a ``ReduceLROnPlateau`` scheduler monitoring the SSL loss."""
        optimizer = torch.optim.AdamW(
            self.parameters(), lr=self.hparams.learning_rate
        )
        scheduler = ReduceLROnPlateau(
            optimizer,
            patience=self.hparams.learning_rate_schedule_patience,
        )
        lr_scheduler_config = {
            "scheduler": scheduler,
            "monitor": "train_loss_ssl",
        }
        return {"optimizer": optimizer, "lr_scheduler": lr_scheduler_config}
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
def get_transform_fn(num_views=4):
    """Build a sample transform that produces several augmented views of an image.

    Args:
        num_views: how many independently augmented copies to create per image.

    Returns:
        A callable taking a sample dict. It replaces ``sample["image"]`` with
        a list of ``num_views`` augmented tensors and returns the same dict
        (the dict is modified in place).
    """
    # Zero mean and std of 255 rescale the pixel values by 1/255
    # (assumes 3-channel images holding 0-255 values -- TODO confirm).
    mean = torch.tensor([0., 0., 0.])
    std = torch.tensor([255., 255., 255.])

    train_augs = nn.Sequential(
        K.Normalize(mean=mean, std=std),
        K.RandomRotation(p=0.5, degrees=90),
        K.RandomHorizontalFlip(p=0.5),
        K.RandomVerticalFlip(p=0.5),
        K.RandomSharpness(p=0.5),
        K.ColorJitter(p=0.5, brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
    )

    def transform(sample):
        image = sample["image"]
        # Kornia augmentations are expected to return a batched tensor with a
        # leading dimension of 1. Remove only that dimension: a bare squeeze()
        # would also drop any other size-1 dimension of the image.
        sample["image"] = [
            train_augs(image).squeeze(0)
            for _ in range(num_views)
        ]
        return sample

    return transform
"\n", | |
"\n", | |
# One multi-view augmentation function, shared by both RESISC45 splits.
transforms = get_transform_fn(4)

# The train and test splits are concatenated into a single pretraining dataset.
train_split = RESISC45("data/resisc45/", split="train", transforms=transforms)
test_split = RESISC45("data/resisc45/", split="test", transforms=transforms)
ds = train_split + test_split

# Re-pack the TorchGeo dict samples as the tuples lightly's collate function consumes.
dataset = TorchgeoLightlyDataset(ds)
collate_fn = MultiViewCollate()

loader_options = dict(
    batch_size=64,
    collate_fn=collate_fn,
    shuffle=True,
    drop_last=True,
    num_workers=8,
)
dataloader = torch.utils.data.DataLoader(dataset, **loader_options)
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
# ResNet-18 feature extractor that the self-supervised model trains.
backbone = ResnetFeatureExtractorWrapper()
model = FastSiam(backbone=backbone, learning_rate=0.001)
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"GPU available: True (cuda), used: True\n", | |
"TPU available: False, using: 0 TPU cores\n", | |
"IPU available: False, using: 0 IPUs\n", | |
"HPU available: False, using: 0 HPUs\n" | |
] | |
} | |
], | |
"source": [ | |
experiment_name = "resisc45-ssl2-2"
tb_logger = TensorBoardLogger("logs/resisc45/", name=experiment_name)

# GPU index this run was originally pinned to. It is used whenever that
# device exists, so behaviour is unchanged on the original multi-GPU machine.
PREFERRED_GPU_INDEX = 5
if torch.cuda.device_count() > PREFERRED_GPU_INDEX:
    trainer_accelerator, trainer_devices = "gpu", [PREFERRED_GPU_INDEX]
else:
    # Fewer CUDA devices than the pinned index (or none at all): fall back to
    # a single device chosen by Lightning instead of requesting a GPU index
    # that is not present.
    trainer_accelerator, trainer_devices = "auto", 1

trainer = pl.Trainer(
    logger=[tb_logger],
    max_epochs=100,
    devices=trainer_devices,
    accelerator=trainer_accelerator,
    default_root_dir=f"output/{experiment_name}/"
)
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"Missing logger folder: logs/resisc45/resisc45-ssl2-2\n", | |
"LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]\n", | |
"\n", | |
" | Name | Type | Params\n", | |
"------------------------------------------------------------------\n", | |
"0 | backbone | ResnetFeatureExtractorWrapper | 11.7 M\n", | |
"1 | projection_head | SimSiamProjectionHead | 591 K \n", | |
"2 | prediction_head | SimSiamPredictionHead | 16.6 K\n", | |
"3 | criterion | NegativeCosineSimilarity | 0 \n", | |
"------------------------------------------------------------------\n", | |
"12.3 M Trainable params\n", | |
"0 Non-trainable params\n", | |
"12.3 M Total params\n", | |
"49.192 Total estimated model params size (MB)\n" | |
] | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "127da30461b24cb3b8508f8d2dce6a99", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"Training: 0it [00:00, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
# Run self-supervised pretraining; only a training dataloader is supplied.
trainer.fit(model, train_dataloaders=dataloader)
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 15, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
# Save the state dict of the plain torchvision ResNet-18 held by the wrapper
# (`backbone.model`), not the feature-extractor wrapper itself.
resnet18_state_dict = backbone.model.state_dict()
torch.save(resnet18_state_dict, "resnet18-simsiam-resis45.pt")
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3 (ipykernel)", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.10.10" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment