Created
March 1, 2022 12:51
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"id": "56a0919e", | |
"metadata": {}, | |
"source": [ | |
"# How to save all your trained model weights locally after every epoch\n", | |
"> This notebook provides working code for a Checkpoint Saver for the report - [How to save all your trained model weights locally after every epoch](https://wandb.ai/amanarora/melanoma/reports/How-to-save-all-your-trained-model-weights-locally-after-every-epoch--VmlldzoxNTkzNjY1)." | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "7be546a8", | |
"metadata": {}, | |
"source": [ | |
"## Download the Imagenette dataset\n", | |
"> Uncomment the first time when you are running this notebook." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"id": "cfbf5f4d", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# get dataset\n", | |
"# !mkdir data && cd data \n", | |
"# !wget https://s3.amazonaws.com/fast-ai-imageclas/imagenette2-160.tgz\n", | |
"# !tar -xvf imagenette2-160.tgz" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "53076c24", | |
"metadata": {}, | |
"source": [ | |
"## Imports" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"id": "8ed12334", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import os\n", | |
"import torch\n", | |
"import torchvision\n", | |
"import timm\n", | |
"import torch.nn as nn\n", | |
"from tqdm.notebook import tqdm\n", | |
"import albumentations\n", | |
"from torchvision import transforms\n", | |
"import numpy as np \n", | |
"import os\n", | |
"import wandb \n", | |
"\n", | |
"# set logging\n", | |
"import logging\n", | |
"logging.getLogger().setLevel(logging.INFO)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "ff8ad7d2", | |
"metadata": {}, | |
"source": [ | |
"## Config" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"id": "7cf91300", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"IMG_SIZE = 160 \n", | |
"MODEL_NAME = \"resnet34\"\n", | |
"LR = 1e-4\n", | |
"EPOCHS = 5" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"id": "3f867f1e", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"train_aug = transforms.Compose(\n", | |
" [\n", | |
" transforms.RandomCrop(IMG_SIZE),\n", | |
" transforms.RandomHorizontalFlip(p=0.5),\n", | |
" transforms.ToTensor(),\n", | |
" transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),\n", | |
" ]\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"id": "56a3e4df", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"val_aug = transforms.Compose(\n", | |
" [\n", | |
" transforms.CenterCrop(IMG_SIZE),\n", | |
" transforms.ToTensor(),\n", | |
" transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),\n", | |
" ]\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "9fd4b81e", | |
"metadata": {}, | |
"source": [ | |
"Below, we initialize the Weights and Biases experiment by passing in the config values. " | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"id": "36cfc6c8", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mamanarora\u001b[0m (use `wandb login --relogin` to force relogin)\n", | |
"\u001b[34m\u001b[1mwandb\u001b[0m: wandb version 0.12.10 is available! To upgrade, please run:\n", | |
"\u001b[34m\u001b[1mwandb\u001b[0m: $ pip install wandb --upgrade\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/html": [ | |
"\n", | |
" Syncing run <strong><a href=\"https://wandb.ai/amanarora/melanoma-artifact/runs/33vyx0zh\" target=\"_blank\">eager-frost-1</a></strong> to <a href=\"https://wandb.ai/amanarora/melanoma-artifact\" target=\"_blank\">Weights & Biases</a> (<a href=\"https://docs.wandb.com/integrations/jupyter.html\" target=\"_blank\">docs</a>).<br/>\n", | |
"\n", | |
" " | |
], | |
"text/plain": [ | |
"<IPython.core.display.HTML object>" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"run = wandb.init(project=\"melanoma-artifact\", config={\n", | |
" 'image size': IMG_SIZE, \n", | |
" 'model name': MODEL_NAME, \n", | |
" 'learning rate': LR, \n", | |
" 'epochs': EPOCHS, \n", | |
" 'training augmentation': train_aug, \n", | |
" 'valid augmentation': val_aug\n", | |
"})" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"id": "f372a6c7", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"(Compose(\n", | |
" RandomCrop(size=(160, 160), padding=None)\n", | |
" RandomHorizontalFlip(p=0.5)\n", | |
" ToTensor()\n", | |
" Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))\n", | |
" ),\n", | |
" Compose(\n", | |
" CenterCrop(size=(160, 160))\n", | |
" ToTensor()\n", | |
" Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))\n", | |
" ))" | |
] | |
}, | |
"execution_count": 7, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"train_aug, val_aug" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "c2a46514", | |
"metadata": {}, | |
"source": [ | |
"## Checkpoint Saver with W&B artifacts integration\n", | |
"> For a complete explanation of how the below code works, please refer to the report - [How to save all your trained model weights locally after every epoch](https://wandb.ai/amanarora/melanoma/reports/How-to-save-all-your-trained-model-weights-locally-after-every-epoch--VmlldzoxNTkzNjY1)." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"id": "a6293500", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"class CheckpointSaver:\n", | |
" def __init__(self, dirpath, decreasing=True, top_n=5):\n", | |
" \"\"\"\n", | |
" dirpath: Directory path where to store all model weights \n", | |
" decreasing: If decreasing is `True`, then lower metric is better\n", | |
" top_n: Total number of models to track based on validation metric value\n", | |
" \"\"\"\n", | |
" if not os.path.exists(dirpath): os.makedirs(dirpath)\n", | |
" self.dirpath = dirpath\n", | |
" self.top_n = top_n \n", | |
" self.decreasing = decreasing\n", | |
" self.top_model_paths = []\n", | |
" self.best_metric_val = np.Inf if decreasing else -np.Inf\n", | |
" \n", | |
" def __call__(self, model, epoch, metric_val):\n", | |
" model_path = os.path.join(self.dirpath, model.__class__.__name__ + f'_epoch{epoch}.pt')\n", | |
" save = metric_val<self.best_metric_val if self.decreasing else metric_val>self.best_metric_val\n", | |
" if save: \n", | |
" logging.info(f\"Current metric value better than {metric_val} better than best {self.best_metric_val}, saving model at {model_path}, & logging model weights to W&B.\")\n", | |
" self.best_metric_val = metric_val\n", | |
" torch.save(model.state_dict(), model_path)\n", | |
" self.log_artifact(f'model-ckpt-epoch-{epoch}.pt', model_path, metric_val)\n", | |
" self.top_model_paths.append({'path': model_path, 'score': metric_val})\n", | |
" self.top_model_paths = sorted(self.top_model_paths, key=lambda o: o['score'], reverse=not self.decreasing)\n", | |
" if len(self.top_model_paths)>self.top_n: \n", | |
" self.cleanup()\n", | |
" \n", | |
" def log_artifact(self, filename, model_path, metric_val):\n", | |
" artifact = wandb.Artifact(filename, type='model', metadata={'Validation score': metric_val})\n", | |
" artifact.add_file(model_path)\n", | |
" wandb.run.log_artifact(artifact) \n", | |
" \n", | |
" def cleanup(self):\n", | |
" to_remove = self.top_model_paths[self.top_n:]\n", | |
" logging.info(f\"Removing extra models.. {to_remove}\")\n", | |
" for o in to_remove:\n", | |
" os.remove(o['path'])\n", | |
" self.top_model_paths = self.top_model_paths[:self.top_n]" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "1132517a", | |
"metadata": {}, | |
"source": [ | |
"## Model training" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"id": "d1868e77", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def train_fn(model, train_data_loader, optimizer, epoch, device='cuda'):\n", | |
" model.train()\n", | |
" fin_loss = 0.0\n", | |
" tk = tqdm(train_data_loader, desc=\"Epoch\" + \" [TRAIN] \" + str(epoch + 1))\n", | |
"\n", | |
" for t, data in enumerate(tk):\n", | |
" data[0] = data[0].to(device)\n", | |
" data[1] = data[1].to(device)\n", | |
"\n", | |
" optimizer.zero_grad()\n", | |
" out = model(data[0])\n", | |
" loss = nn.CrossEntropyLoss()(out, data[1])\n", | |
" loss.backward()\n", | |
" optimizer.step()\n", | |
"\n", | |
" fin_loss += loss.item()\n", | |
" tk.set_postfix(\n", | |
" {\n", | |
" \"loss\": \"%.6f\" % float(fin_loss / (t + 1)),\n", | |
" \"LR\": optimizer.param_groups[0][\"lr\"],\n", | |
" }\n", | |
" )\n", | |
" return fin_loss / len(train_data_loader), optimizer.param_groups[0][\"lr\"]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"id": "64e26c13", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def eval_fn(model, eval_data_loader, epoch, device='cuda'):\n", | |
" model.eval()\n", | |
" fin_loss = 0.0\n", | |
" tk = tqdm(eval_data_loader, desc=\"Epoch\" + \" [VALID] \" + str(epoch + 1))\n", | |
"\n", | |
" with torch.no_grad():\n", | |
" for t, data in enumerate(tk):\n", | |
" data[0] = data[0].to(device)\n", | |
" data[1] = data[1].to(device)\n", | |
" out = model(data[0])\n", | |
" loss = nn.CrossEntropyLoss()(out, data[1])\n", | |
" fin_loss += loss.item()\n", | |
" tk.set_postfix({\"loss\": \"%.6f\" % float(fin_loss / (t + 1))})\n", | |
" return fin_loss / len(eval_data_loader)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"id": "f51bb9a3", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def train(train_dir, test_dir):\n", | |
" train_dataset = torchvision.datasets.ImageFolder(\n", | |
" train_dir, transform=train_aug\n", | |
" )\n", | |
" eval_dataset = torchvision.datasets.ImageFolder(\n", | |
" test_dir, transform=val_aug\n", | |
" )\n", | |
" train_dataloader = torch.utils.data.DataLoader(\n", | |
" train_dataset,\n", | |
" batch_size=128,\n", | |
" shuffle=True,\n", | |
" num_workers=4\n", | |
" )\n", | |
" eval_dataloader = torch.utils.data.DataLoader(\n", | |
" eval_dataset, batch_size=64, num_workers=4\n", | |
" )\n", | |
"\n", | |
" # model\n", | |
" model = timm.create_model(MODEL_NAME, pretrained=True)\n", | |
" model = model.cuda()\n", | |
"\n", | |
" # optimizer\n", | |
" optimizer = torch.optim.Adam(model.parameters(), lr=LR)\n", | |
"\n", | |
" # checkpoint saver\n", | |
" checkpoint_saver = CheckpointSaver(dirpath='./model_weights', decreasing=True, top_n=5)\n", | |
" for epoch in range(EPOCHS):\n", | |
" avg_loss_train, lr = train_fn(\n", | |
" model, train_dataloader, optimizer, epoch, device='cuda'\n", | |
" )\n", | |
" avg_loss_eval = eval_fn(model, eval_dataloader, epoch, device='cuda')\n", | |
" checkpoint_saver(model, epoch, avg_loss_eval)\n", | |
" wandb.run.log({'epoch': epoch, 'train loss': avg_loss_train, 'eval loss': avg_loss_eval})\n", | |
" print(\n", | |
" f\"EPOCH = {epoch} | TRAIN_LOSS = {avg_loss_train} | EVAL_LOSS = {avg_loss_eval}\"\n", | |
" )" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"id": "65b86867", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "b3528fa6f1374456a4fe94a373173806", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"Epoch [TRAIN] 1: 0%| | 0/74 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "7d9173db65c947c8983f6ccaa44365de", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"Epoch [VALID] 1: 0%| | 0/62 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"INFO:root:Current metric value better than 0.19492560036240086 better than best inf, saving model at ./model_weights/ResNet_epoch0.pt, & logging model weights to W&B.\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"EPOCH = 0 | TRAIN_LOSS = 1.375664096527003 | EVAL_LOSS = 0.19492560036240086\n" | |
] | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "9bf581db089147e59b607cacea9bab23", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"Epoch [TRAIN] 2: 0%| | 0/74 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "315a7ec27aac4c62beb1fa46bd0f226e", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"Epoch [VALID] 2: 0%| | 0/62 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"INFO:root:Current metric value better than 0.1360785181619107 better than best 0.19492560036240086, saving model at ./model_weights/ResNet_epoch1.pt, & logging model weights to W&B.\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"EPOCH = 1 | TRAIN_LOSS = 0.11170286700330875 | EVAL_LOSS = 0.1360785181619107\n" | |
] | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "ad6b12add6c44de6b802cbbaf05e08d8", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"Epoch [TRAIN] 3: 0%| | 0/74 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "cee955c8ff1d401b9a1f7287c451b8d4", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"Epoch [VALID] 3: 0%| | 0/62 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"EPOCH = 2 | TRAIN_LOSS = 0.04712816304527223 | EVAL_LOSS = 0.1487596299078676\n" | |
] | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "d30cb6aa313d453393dc2a85dceb6885", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"Epoch [TRAIN] 4: 0%| | 0/74 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "e4f212615e8f488c914e8a2e9ed6676d", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"Epoch [VALID] 4: 0%| | 0/62 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"INFO:root:Current metric value better than 0.12546487889748306 better than best 0.1360785181619107, saving model at ./model_weights/ResNet_epoch3.pt, & logging model weights to W&B.\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"EPOCH = 3 | TRAIN_LOSS = 0.034359857650800935 | EVAL_LOSS = 0.12546487889748306\n" | |
] | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "c1061c739866458ea7d8cbba27393b64", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"Epoch [TRAIN] 5: 0%| | 0/74 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "18f689461ea84e9695c423d09fbd6b3c", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"Epoch [VALID] 5: 0%| | 0/62 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"EPOCH = 4 | TRAIN_LOSS = 0.027094046788202045 | EVAL_LOSS = 0.13780122134278738\n" | |
] | |
} | |
], | |
"source": [ | |
"train(train_dir='./data/imagenette2-160/train/', test_dir='./data/imagenette2-160/val/')" | |
] | |
} | |
], | |
"metadata": { | |
"gist": { | |
"data": { | |
"description": "reports/How to save all your trained model weights locally after every epoch.ipynb", | |
"public": true | |
}, | |
"id": "" | |
}, | |
"kernelspec": { | |
"display_name": "Python 3 (ipykernel)", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.9.7" | |
}, | |
"toc": { | |
"base_numbering": 1, | |
"nav_menu": {}, | |
"number_sections": true, | |
"sideBar": true, | |
"skip_h1_title": false, | |
"title_cell": "Table of Contents", | |
"title_sidebar": "Contents", | |
"toc_cell": false, | |
"toc_position": {}, | |
"toc_section_display": true, | |
"toc_window_display": false | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
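
As a rough illustration of the bookkeeping in the notebook's `CheckpointSaver` cell, the sketch below reproduces only the save-on-improvement and top-n pruning logic, using the Python standard library. `TopNTracker`, `run_demo`, the file names, and the loss values are hypothetical and not part of the notebook; a small text file stands in for `torch.save(model.state_dict(), ...)`, and the W&B artifact logging is left out.

```python
import os
import tempfile


class TopNTracker:
    """Hypothetical stand-in for the CheckpointSaver bookkeeping (no torch, no wandb)."""

    def __init__(self, dirpath, decreasing=True, top_n=5):
        os.makedirs(dirpath, exist_ok=True)
        self.dirpath = dirpath
        self.decreasing = decreasing
        self.top_n = top_n
        self.top_paths = []  # dicts with 'path' and 'score', kept best-first
        self.best = float("inf") if decreasing else float("-inf")

    def __call__(self, epoch, metric_val):
        # Save only when the metric improves on the best value seen so far.
        improved = metric_val < self.best if self.decreasing else metric_val > self.best
        if not improved:
            return False
        self.best = metric_val
        path = os.path.join(self.dirpath, f"epoch{epoch}.pt")
        # Placeholder for torch.save(model.state_dict(), path).
        with open(path, "w") as f:
            f.write(str(metric_val))
        self.top_paths.append({"path": path, "score": metric_val})
        self.top_paths.sort(key=lambda o: o["score"], reverse=not self.decreasing)
        # Delete files that fell outside the top_n window.
        for stale in self.top_paths[self.top_n:]:
            os.remove(stale["path"])
        self.top_paths = self.top_paths[: self.top_n]
        return True


def run_demo():
    # Made-up validation losses, not the notebook's recorded results.
    losses = [0.19, 0.14, 0.15, 0.13, 0.14]
    with tempfile.TemporaryDirectory() as tmp:
        tracker = TopNTracker(tmp, decreasing=True, top_n=2)
        saved = [tracker(epoch, loss) for epoch, loss in enumerate(losses)]
        kept = sorted(os.listdir(tmp))
    return saved, kept


saved_flags, kept_files = run_demo()
print(saved_flags)  # [True, True, False, True, False]
print(kept_files)   # ['epoch1.pt', 'epoch3.pt']
```

With `top_n=2`, the run writes a file at epochs 0, 1 and 3 and prunes the epoch-0 file once a third checkpoint arrives, so two files remain. As in the notebook, a checkpoint is written only when the metric improves on the best value so far, not unconditionally after every epoch.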