Skip to content

Instantly share code, notes, and snippets.

@calebrob6
Created April 14, 2023 22:38
Show Gist options
  • Save calebrob6/983f90c3c752d2cc54e659a16ba01303 to your computer and use it in GitHub Desktop.
Save calebrob6/983f90c3c752d2cc54e659a16ba01303 to your computer and use it in GitHub Desktop.
An example notebook that shows how to do self-supervised learning (SSL) pretraining with the Lightly library and TorchGeo
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import lightning.pytorch as pl\n",
"from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint\n",
"from lightning.pytorch.loggers import TensorBoardLogger\n",
"\n",
"import kornia.augmentation as K\n",
"\n",
"import torch\n",
"import torchvision\n",
"from torch import nn\n",
"import torch.nn.functional as F\n",
"from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
"from torchgeo.datasets import RESISC45\n",
"import matplotlib.pyplot as plt\n",
"\n",
"from lightly.data import LightlyDataset\n",
"from lightly.data.multi_view_collate import MultiViewCollate\n",
"from lightly.loss import NegativeCosineSimilarity\n",
"from lightly.models.modules import SimSiamPredictionHead, SimSiamProjectionHead\n",
"\n",
"from torchmetrics import MeanMetric\n",
"\n",
"from torchvision.models.feature_extraction import create_feature_extractor"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"class TorchgeoLightlyDataset(LightlyDataset):\n",
" \"\"\"Wrapper class to adapt TorchGeo style datasets to a format that is expected by lightly's methods.\"\"\"\n",
"\n",
" def __init__(self, dataset, transform=None):\n",
" super().__init__(None, index_to_filename=None)\n",
" self.dataset = dataset\n",
" if transform is not None:\n",
" self.dataset.transforms = transform # NOTE: in TorchGeo we use `dataset.transforms`, not `.transform`\n",
"\n",
" def __getitem__(self, index: int):\n",
" fname = str(index)\n",
" sample = self.dataset.__getitem__(index)\n",
" return sample[\"image\"], sample[\"label\"], fname"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"class ResnetFeatureExtractorWrapper(nn.Module):\n",
" def __init__(self, *args, **kwargs):\n",
" super().__init__()\n",
" self.model = torchvision.models.resnet18()\n",
" self.backbone = create_feature_extractor(self.model, return_nodes=[\"avgpool\"])\n",
"\n",
" def forward(self, x):\n",
" return self.backbone(x)[\"avgpool\"].squeeze()\n",
"\n",
"class FastSiam(pl.LightningModule):\n",
" def __init__(self, backbone, learning_rate=0.001, learning_rate_schedule_patience=10):\n",
" super().__init__()\n",
" \n",
" self.save_hyperparameters(ignore=['backbone'])\n",
" \n",
" self.backbone = backbone\n",
" self.projection_head = SimSiamProjectionHead(512, 512, 128)\n",
" self.prediction_head = SimSiamPredictionHead(128, 64, 128)\n",
" self.criterion = NegativeCosineSimilarity()\n",
" self.avg_output_std = 0\n",
"\n",
" def forward(self, x):\n",
" f = self.backbone(x).flatten(start_dim=1)\n",
" z = self.projection_head(f)\n",
" p = self.prediction_head(z)\n",
" z = z.detach()\n",
" return z, p\n",
"\n",
" def training_step(self, batch, batch_idx):\n",
" views, _, _ = batch\n",
" features = [self.forward(view) for view in views]\n",
" zs = torch.stack([z for z, _ in features])\n",
" ps = torch.stack([p for _, p in features])\n",
"\n",
" loss = 0.0\n",
" for i in range(len(views)):\n",
" mask = torch.arange(len(views), device=self.device) != i\n",
" loss += self.criterion(ps[i], torch.mean(zs[mask], dim=0)) / len(views)\n",
"\n",
" \n",
" output = ps[0].detach()\n",
" output = F.normalize(output, dim=1)\n",
" output_std = torch.std(output, 0)\n",
" output_std = output_std.mean()\n",
" \n",
" w = 0.9\n",
" self.avg_output_std = w * self.avg_output_std + (1 - w) * output_std.item()\n",
"\n",
" self.log(\"train_avg_output_std\", self.avg_output_std)\n",
" self.log(\"train_loss_ssl\", loss)\n",
" return loss\n",
"\n",
" def configure_optimizers(self):\n",
" optimizer = torch.optim.AdamW(\n",
" self.parameters(), lr=self.hparams.learning_rate\n",
" )\n",
" return {\n",
" \"optimizer\": optimizer,\n",
" \"lr_scheduler\": {\n",
" \"scheduler\": ReduceLROnPlateau(\n",
" optimizer,\n",
" patience=self.hparams.learning_rate_schedule_patience,\n",
" ),\n",
" \"monitor\": \"train_loss_ssl\",\n",
" },\n",
" }"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"def get_transform_fn(num_views=4):\n",
" \n",
" mean = torch.tensor([0., 0., 0.])\n",
" std = torch.tensor([255., 255., 255.])\n",
" \n",
" train_augs = nn.Sequential(\n",
" K.Normalize(mean=mean, std=std),\n",
" K.RandomRotation(p=0.5, degrees=90),\n",
" K.RandomHorizontalFlip(p=0.5),\n",
" K.RandomVerticalFlip(p=0.5),\n",
" K.RandomSharpness(p=0.5),\n",
" K.ColorJitter(p=0.5, brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),\n",
" )\n",
" def transform(sample):\n",
" image = sample[\"image\"]\n",
" sample[\"image\"] = [\n",
" train_augs(image).squeeze()\n",
" for _ in range(num_views)\n",
" ] \n",
" return sample\n",
" \n",
" return transform\n",
"\n",
"\n",
"transforms = get_transform_fn(4)\n",
"\n",
"ds = (\n",
" RESISC45(\"data/resisc45/\", split=\"train\", transforms=transforms) +\n",
" RESISC45(\"data/resisc45/\", split=\"test\", transforms=transforms)\n",
")\n",
"dataset = TorchgeoLightlyDataset(ds)\n",
"collate_fn = MultiViewCollate()\n",
"\n",
"dataloader = torch.utils.data.DataLoader(\n",
" dataset,\n",
" batch_size=64,\n",
" collate_fn=collate_fn,\n",
" shuffle=True,\n",
" drop_last=True,\n",
" num_workers=8,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"backbone = ResnetFeatureExtractorWrapper()\n",
"model = FastSiam(backbone, learning_rate=0.001)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"GPU available: True (cuda), used: True\n",
"TPU available: False, using: 0 TPU cores\n",
"IPU available: False, using: 0 IPUs\n",
"HPU available: False, using: 0 HPUs\n"
]
}
],
"source": [
"experiment_name = \"resisc45-ssl2-2\"\n",
"tb_logger = TensorBoardLogger(\"logs/resisc45/\", name=experiment_name)\n",
"\n",
"trainer = pl.Trainer(\n",
" logger=[tb_logger],\n",
" max_epochs=100,\n",
" devices=[5],\n",
" accelerator=\"gpu\",\n",
" default_root_dir=f\"output/{experiment_name}/\"\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Missing logger folder: logs/resisc45/resisc45-ssl2-2\n",
"LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]\n",
"\n",
" | Name | Type | Params\n",
"------------------------------------------------------------------\n",
"0 | backbone | ResnetFeatureExtractorWrapper | 11.7 M\n",
"1 | projection_head | SimSiamProjectionHead | 591 K \n",
"2 | prediction_head | SimSiamPredictionHead | 16.6 K\n",
"3 | criterion | NegativeCosineSimilarity | 0 \n",
"------------------------------------------------------------------\n",
"12.3 M Trainable params\n",
"0 Non-trainable params\n",
"12.3 M Total params\n",
"49.192 Total estimated model params size (MB)\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "127da30461b24cb3b8508f8d2dce6a99",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Training: 0it [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"trainer.fit(model=model, train_dataloaders=dataloader)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"# save the weights from the original resnet18\n",
"torch.save(backbone.model.state_dict(), \"resnet18-simsiam-resis45.pt\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.10"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment