Created
February 22, 2022 10:38
-
-
Save amaarora/e4346adde3225645f96ccf22a3267cc1 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"id": "3344492e", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# get dataset\n", | |
"# !mkdir data && cd data \n", | |
"# !wget https://s3.amazonaws.com/fast-ai-imageclas/imagenette2-160.tgz\n", | |
"# !tar -xvf imagenette2-160.tgz" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"id": "047564f0", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import os\n", | |
"import torch\n", | |
"import torchvision\n", | |
"import timm\n", | |
"import torch.nn as nn\n", | |
"from tqdm.notebook import tqdm\n", | |
"import albumentations\n", | |
"from torchvision import transforms\n", | |
"import numpy as np \n", | |
"import os\n", | |
"\n", | |
"# set logging\n", | |
"import logging\n", | |
"logging.getLogger().setLevel(logging.INFO)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"id": "dc1f748e", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"IMG_SIZE = 160 \n", | |
"MODEL_NAME = \"resnet34\"\n", | |
"LR = 1e-4\n", | |
"EPOCHS = 5" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"id": "cf6dc7d2", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"train_aug = transforms.Compose(\n", | |
" [\n", | |
" transforms.RandomCrop(IMG_SIZE),\n", | |
" transforms.RandomHorizontalFlip(p=0.5),\n", | |
" transforms.ToTensor(),\n", | |
" transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),\n", | |
" ]\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"id": "ac158b53", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"val_aug = transforms.Compose(\n", | |
" [\n", | |
" transforms.CenterCrop(IMG_SIZE),\n", | |
" transforms.ToTensor(),\n", | |
" transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),\n", | |
" ]\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"id": "d7860098", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"(Compose(\n", | |
" RandomCrop(size=(160, 160), padding=None)\n", | |
" RandomHorizontalFlip(p=0.5)\n", | |
" ToTensor()\n", | |
" Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))\n", | |
" ),\n", | |
" Compose(\n", | |
" CenterCrop(size=(160, 160))\n", | |
" ToTensor()\n", | |
" Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))\n", | |
" ))" | |
] | |
}, | |
"execution_count": 6, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"train_aug, val_aug" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"id": "7817c33f", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"class CheckpointSaver:\n", | |
" def __init__(self, dirpath, decreasing=True, top_n=5):\n", | |
" \"\"\"\n", | |
" dirpath: Directory path where to store all model weights \n", | |
" decreasing: If decreasing is `True`, then lower metric is better\n", | |
" top_n: Total number of models to track based on validation metric value\n", | |
" \"\"\"\n", | |
" if not os.path.exists(dirpath): os.makedirs(dirpath)\n", | |
" self.dirpath = dirpath\n", | |
" self.top_n = top_n \n", | |
" self.decreasing = decreasing\n", | |
" self.top_model_paths = []\n", | |
" self.best_metric_val = np.Inf if decreasing else -np.Inf\n", | |
" \n", | |
" def __call__(self, model, epoch, metric_val):\n", | |
" model_path = os.path.join(self.dirpath, model.__class__.__name__ + f'_epoch{epoch}.pt')\n", | |
" save = metric_val<self.best_metric_val if self.decreasing else metric_val>self.best_metric_val\n", | |
" if save: \n", | |
" logging.info(f\"Current metric value better than {metric_val} better than best {self.best_metric_val}, saving model at {model_path}\")\n", | |
" self.best_metric_val = metric_val\n", | |
" torch.save(model.state_dict(), model_path)\n", | |
" self.top_model_paths.append({'path': model_path, 'score': metric_val})\n", | |
" self.top_model_paths = sorted(self.top_model_paths, key=lambda o: o['score'], reverse=not self.decreasing)\n", | |
" if len(self.top_model_paths)>self.top_n: \n", | |
" self.cleanup()\n", | |
" \n", | |
" def cleanup(self):\n", | |
" to_remove = self.top_model_paths[self.top_n:]\n", | |
" logging.info(f\"Removing extra models.. {to_remove}\")\n", | |
" for o in to_remove:\n", | |
" os.remove(o['path'])\n", | |
" self.top_model_paths = self.top_model_paths[:self.top_n]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"id": "e1caf342", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def train_fn(model, train_data_loader, optimizer, epoch, device='cuda'):\n", | |
" model.train()\n", | |
" fin_loss = 0.0\n", | |
" tk = tqdm(train_data_loader, desc=\"Epoch\" + \" [TRAIN] \" + str(epoch + 1))\n", | |
"\n", | |
" for t, data in enumerate(tk):\n", | |
" data[0] = data[0].to(device)\n", | |
" data[1] = data[1].to(device)\n", | |
"\n", | |
" optimizer.zero_grad()\n", | |
" out = model(data[0])\n", | |
" loss = nn.CrossEntropyLoss()(out, data[1])\n", | |
" loss.backward()\n", | |
" optimizer.step()\n", | |
"\n", | |
" fin_loss += loss.item()\n", | |
" tk.set_postfix(\n", | |
" {\n", | |
" \"loss\": \"%.6f\" % float(fin_loss / (t + 1)),\n", | |
" \"LR\": optimizer.param_groups[0][\"lr\"],\n", | |
" }\n", | |
" )\n", | |
" return fin_loss / len(train_data_loader), optimizer.param_groups[0][\"lr\"]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"id": "565dfa4d", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def eval_fn(model, eval_data_loader, epoch, device='cuda'):\n", | |
" model.eval()\n", | |
" fin_loss = 0.0\n", | |
" tk = tqdm(eval_data_loader, desc=\"Epoch\" + \" [VALID] \" + str(epoch + 1))\n", | |
"\n", | |
" with torch.no_grad():\n", | |
" for t, data in enumerate(tk):\n", | |
" data[0] = data[0].to(device)\n", | |
" data[1] = data[1].to(device)\n", | |
" out = model(data[0])\n", | |
" loss = nn.CrossEntropyLoss()(out, data[1])\n", | |
" fin_loss += loss.item()\n", | |
" tk.set_postfix({\"loss\": \"%.6f\" % float(fin_loss / (t + 1))})\n", | |
" return fin_loss / len(eval_data_loader)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"id": "3aa6a0ac", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def train(train_dir, test_dir):\n", | |
" train_dataset = torchvision.datasets.ImageFolder(\n", | |
" train_dir, transform=train_aug\n", | |
" )\n", | |
" eval_dataset = torchvision.datasets.ImageFolder(\n", | |
" test_dir, transform=val_aug\n", | |
" )\n", | |
" train_dataloader = torch.utils.data.DataLoader(\n", | |
" train_dataset,\n", | |
" batch_size=128,\n", | |
" shuffle=True,\n", | |
" num_workers=4\n", | |
" )\n", | |
" eval_dataloader = torch.utils.data.DataLoader(\n", | |
" eval_dataset, batch_size=64, num_workers=4\n", | |
" )\n", | |
"\n", | |
" # model\n", | |
" model = timm.create_model(MODEL_NAME, pretrained=True)\n", | |
" model = model.cuda()\n", | |
"\n", | |
" # optimizer\n", | |
" optimizer = torch.optim.Adam(model.parameters(), lr=LR)\n", | |
"\n", | |
" # checkpoint saver\n", | |
" checkpoint_saver = CheckpointSaver(dirpath='./model_weights', decreasing=True, top_n=1)\n", | |
" for epoch in range(EPOCHS):\n", | |
" avg_loss_train, lr = train_fn(\n", | |
" model, train_dataloader, optimizer, epoch, device='cuda'\n", | |
" )\n", | |
" avg_loss_eval = eval_fn(model, eval_dataloader, epoch, device='cuda')\n", | |
" checkpoint_saver(model, epoch, avg_loss_eval)\n", | |
" print(\n", | |
" f\"EPOCH = {epoch} | TRAIN_LOSS = {avg_loss_train} | EVAL_LOSS = {avg_loss_eval}\"\n", | |
" )" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"id": "b13a5dd7", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "39c36eaaa13041d98fbb66848b68ec36", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"Epoch [TRAIN] 1: 0%| | 0/74 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "f8ea52598bb04371967d1454b288f9b5", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"Epoch [VALID] 1: 0%| | 0/62 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"INFO:root:Current metric value better than 0.1954631515958857 better than best inf, saving model at ./model_weights/ResNet_epoch0.pt\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"EPOCH = 0 | TRAIN_LOSS = 1.3544871056502736 | EVAL_LOSS = 0.1954631515958857\n" | |
] | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "1667f9b528524a368043fbdabfad2abf", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"Epoch [TRAIN] 2: 0%| | 0/74 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "1c24b86d8aeb4e15a64f4d2c89efc551", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"Epoch [VALID] 2: 0%| | 0/62 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"INFO:root:Current metric value better than 0.15000865837529062 better than best 0.1954631515958857, saving model at ./model_weights/ResNet_epoch1.pt\n", | |
"INFO:root:Removing extra models.. [{'path': './model_weights/ResNet_epoch0.pt', 'score': 0.1954631515958857}]\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"EPOCH = 1 | TRAIN_LOSS = 0.11298705174310787 | EVAL_LOSS = 0.15000865837529062\n" | |
] | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "e34713e9f7344440a4686310476d985f", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"Epoch [TRAIN] 3: 0%| | 0/74 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "2d52768e16114ae79f9204e6fd07731e", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"Epoch [VALID] 3: 0%| | 0/62 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"INFO:root:Current metric value better than 0.1338667555208949 better than best 0.15000865837529062, saving model at ./model_weights/ResNet_epoch2.pt\n", | |
"INFO:root:Removing extra models.. [{'path': './model_weights/ResNet_epoch1.pt', 'score': 0.15000865837529062}]\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"EPOCH = 2 | TRAIN_LOSS = 0.053521369734930026 | EVAL_LOSS = 0.1338667555208949\n" | |
] | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "33f3744f46234b1190d63e05eac2f638", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"Epoch [TRAIN] 4: 0%| | 0/74 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "6f17c6f08ac94326b17e699b42d359a5", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"Epoch [VALID] 4: 0%| | 0/62 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"INFO:root:Current metric value better than 0.12575743053551583 better than best 0.1338667555208949, saving model at ./model_weights/ResNet_epoch3.pt\n", | |
"INFO:root:Removing extra models.. [{'path': './model_weights/ResNet_epoch2.pt', 'score': 0.1338667555208949}]\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"EPOCH = 3 | TRAIN_LOSS = 0.04176709730165532 | EVAL_LOSS = 0.12575743053551583\n" | |
] | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "39e1ee19189b41d3b2390cda491f9399", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"Epoch [TRAIN] 5: 0%| | 0/74 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "48f96c8f0e864b2f9f8d65efcda42ce3", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"Epoch [VALID] 5: 0%| | 0/62 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"EPOCH = 4 | TRAIN_LOSS = 0.027766215419900174 | EVAL_LOSS = 0.12607890976810707\n" | |
] | |
} | |
], | |
"source": [ | |
"train(train_dir='./data/imagenette2-160/train/', test_dir='./data/imagenette2-160/val/')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "b2d2b4bd", | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"gist": { | |
"data": { | |
"description": "reports/How to save all your trained model weights locally after every epoch.ipynb", | |
"public": true | |
}, | |
"id": "" | |
}, | |
"kernelspec": { | |
"display_name": "Python 3 (ipykernel)", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.9.7" | |
}, | |
"toc": { | |
"base_numbering": 1, | |
"nav_menu": {}, | |
"number_sections": true, | |
"sideBar": true, | |
"skip_h1_title": false, | |
"title_cell": "Table of Contents", | |
"title_sidebar": "Contents", | |
"toc_cell": false, | |
"toc_position": {}, | |
"toc_section_display": true, | |
"toc_window_display": false | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment