An implementation of RetinaNet in PyTorch.
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"id": "a2c3541f-8c5f-490d-8ada-2d9264a72074", | |
"metadata": {}, | |
"source": [ | |
"# RetinaNet Implementation in PyTorch\n", | |
"\n", | |
"Implementation of the following paper: [Focal Loss for Dense Object Detection](https://arxiv.org/pdf/1708.02002.pdf)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "cc7a980f-fddb-458f-b48e-fd0f12de3e58", | |
"metadata": {}, | |
"source": [ | |
"## Imports" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"id": "8f124e21-1f07-4bfd-89e0-fdff09aa4e0d", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import numpy as np\n", | |
"import matplotlib.pyplot as plt\n", | |
"import math\n", | |
"import copy" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"id": "6e523509-e024-4df9-82eb-c64df1d60110", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import torch\n", | |
"from torch import nn, optim\n", | |
"from torch.nn import functional as F\n", | |
"from torch.utils.data import DataLoader" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"id": "c3bfdb2b-b2a2-4570-93c2-9a9a89a5615e", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import torchvision\n", | |
"from torchvision import transforms, datasets\n", | |
"from torchvision.transforms import functional as FT\n", | |
"from torchvision.transforms import transforms as T" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"id": "feb9f59e-1e07-4810-9408-58b23cf27c18", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from PIL import Image\n", | |
"import os\n", | |
"import cv2" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"id": "f13fa1a6-d707-4f58-9202-9f65fa67a41f", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from tqdm.notebook import tqdm" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"id": "8318d1c0-75da-44ab-9f01-521d7b0fd738", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"('1.9.0', '0.10.0')" | |
] | |
}, | |
"execution_count": 8, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"torch.__version__, torchvision.__version__" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "bd6b3cf8-65e3-436d-bf20-c553dd592f32", | |
"metadata": {}, | |
"source": [ | |
"## Transforms and Utilities" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"id": "df785a0e-d026-4e2d-b2d6-f95b81db6757", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"class Compose:\n", | |
" def __init__(self, transforms):\n", | |
" self.transforms = transforms\n", | |
"\n", | |
" def __call__(self, image, target):\n", | |
" for t in self.transforms:\n", | |
" image, target = t(image, target)\n", | |
" return image, target" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"id": "871e66f6-bc56-4095-90f0-2a5f603fd29e", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"class Normalizer(object):\n", | |
"\n", | |
" def __init__(self):\n", | |
" self.mean = [0.485, 0.456, 0.406]\n", | |
" self.std = [0.229, 0.224, 0.225]\n", | |
" self.normalize = T.Compose([T.Normalize(mean=self.mean, std=self.std)])\n", | |
"\n", | |
" def __call__(self, image, target):\n", | |
"\n", | |
" return self.normalize(image), target" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"id": "1d953822-6c79-4511-bd1c-19a84633f703", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"class Resize(object):\n", | |
" def __init__(self, size=400):\n", | |
" self.size = size\n", | |
" def __call__(self, img, target):\n", | |
" size = self.size\n", | |
" boxes = [t['bbox'] for t in target]\n", | |
" w, h = img.size\n", | |
" if isinstance(size, int):\n", | |
" size_min = min(w,h)\n", | |
" size_max = max(w,h)\n", | |
" sw = sh = float(size) / size_min\n", | |
" if sw * size_max > 800:\n", | |
" sw = sh = float(800) / size_max\n", | |
" ow = int(w * sw + 0.5)\n", | |
" oh = int(h * sh + 0.5)\n", | |
" else:\n", | |
" ow, oh = size\n", | |
" sw = float(ow) / w\n", | |
" sh = float(oh) / h\n", | |
" boxes = (torch.FloatTensor(boxes)*torch.Tensor([sw,sh,sw,sh])).tolist()\n", | |
" for t in range(len(target)):\n", | |
" target[t]['bbox'] = boxes[t]\n", | |
" return img.resize((ow,oh), Image.BILINEAR), target" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"id": "b3a9d7e4-9fee-4ed0-a205-892721488fb8", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"class ToTensor(nn.Module):\n", | |
" def forward(\n", | |
" self, image, target = None\n", | |
" ):\n", | |
" image = FT.pil_to_tensor(image)\n", | |
" image = FT.convert_image_dtype(image)\n", | |
" return image, target" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"id": "246f31f3-e29a-4b23-9f14-7c15468f94e8", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"class PILToTensor(nn.Module):\n", | |
" def forward(\n", | |
" self, image, target = None\n", | |
" ):\n", | |
" image = FT.pil_to_tensor(image)\n", | |
" return image, target" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "0a2ecfe0-6bda-4ff0-8de9-d8f5cfacb5ba", | |
"metadata": {}, | |
"source": [ | |
"## Dataset" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"id": "bb23610c-6a83-4236-9cae-0cfbad97d9de", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"#### COLAB LOADER ####\n", | |
"# !curl -L \"https://public.roboflow.com/ds/L6PD1uTSPF?key=Gq3tCeIqHA\" > roboflow.zip; unzip roboflow.zip; rm roboflow.zip\n", | |
"# Use for colab only #\n", | |
"######################" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 15, | |
"id": "57619344-33c5-40fc-8d34-ef7529487d00", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from pycocotools.coco import COCO" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 16, | |
"id": "c7bab354-9b31-497d-834c-1ebf5d0ecf9f", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"dataset_path = \"/Volumes/Samsung_T5/Documents/MachineLearning/machine_learning_notebooks/pytorch/aquarium-dataset/Aquarium Combined/\"" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 17, | |
"id": "96370ec5-bc2d-475c-9b94-8a228cf3c982", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"loading annotations into memory...\n", | |
"Done (t=0.02s)\n", | |
"creating index...\n", | |
"index created!\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"8" | |
] | |
}, | |
"execution_count": 17, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"coc = COCO(os.path.join(dataset_path, \"train\", \"_annotations.coco.json\"))\n", | |
"categories = coc.cats\n", | |
"n_classes = len(categories.keys())\n", | |
"n_classes" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 18, | |
"id": "12a71e96-e1cf-41ea-b8d0-5ce1b14d3ff7", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def xyxy_2_xywh(boxes):\n", | |
" a = torch.FloatTensor(boxes[:,:2])\n", | |
" b = torch.FloatTensor(boxes[:,2:])\n", | |
" return torch.cat([(a+b)/2,b-a+1], 1)\n", | |
" \n", | |
"def xywh_2_xyxy(boxes):\n", | |
" a = torch.FloatTensor(boxes[:,:2])\n", | |
" b = torch.FloatTensor(boxes[:,2:])\n", | |
" return torch.cat([a-b/2,a+b/2], 1)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 19, | |
"id": "826dfb96-824a-44e5-9268-249c0756e7ae", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def box_nms(bboxes, scores, threshold=0.5, mode='union'):\n", | |
" \n", | |
" x1 = bboxes[:,0]\n", | |
" y1 = bboxes[:,1]\n", | |
" x2 = bboxes[:,2]\n", | |
" y2 = bboxes[:,3]\n", | |
"\n", | |
" areas = (x2-x1+1) * (y2-y1+1)\n", | |
" _, order = scores.sort(0, descending=True)\n", | |
"\n", | |
" keep = []\n", | |
" while order.numel() > 0:\n", | |
" if order.numel() == 1:\n", | |
" keep.append(order.item())\n", | |
" break\n", | |
" \n", | |
" i = order[0]\n", | |
" keep.append(i)\n", | |
"\n", | |
" xx1 = x1[order[1:]].clamp(min=x1[i])\n", | |
" yy1 = y1[order[1:]].clamp(min=y1[i])\n", | |
" xx2 = x2[order[1:]].clamp(max=x2[i])\n", | |
" yy2 = y2[order[1:]].clamp(max=y2[i])\n", | |
"\n", | |
" w = (xx2-xx1+1).clamp(min=0)\n", | |
" h = (yy2-yy1+1).clamp(min=0)\n", | |
" inter = w*h\n", | |
"\n", | |
" if mode == 'union':\n", | |
" ovr = inter / (areas[i] + areas[order[1:]] - inter)\n", | |
" elif mode == 'min':\n", | |
" ovr = inter / areas[order[1:]].clamp(max=areas[i])\n", | |
" else:\n", | |
" raise TypeError('Unknown nms mode: %s.' % mode)\n", | |
"\n", | |
" ids = (ovr<=threshold).nonzero().squeeze()\n", | |
" if ids.numel() == 0:\n", | |
" break\n", | |
" order = order[ids+1]\n", | |
" return torch.LongTensor(keep)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 20, | |
"id": "88ac1227-a417-403a-8cac-6b04faadc851", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def iou(box1, box2, order=\"xyxy\"):\n", | |
" if order == \"xywh\":\n", | |
" box1 = xywh_2_xyxy(box1)\n", | |
" box2 = xywh_2_xyxy(box2)\n", | |
" N = box1.size(0)\n", | |
" M = box2.size(0)\n", | |
"\n", | |
" lt = torch.max(box1[:,None,:2], box2[:,:2]) # [N,M,2]\n", | |
" rb = torch.min(box1[:,None,2:], box2[:,2:]) # [N,M,2]\n", | |
"\n", | |
" wh = (rb-lt+1).clamp(min=0) # [N,M,2]\n", | |
" inter = wh[:,:,0] * wh[:,:,1] # [N,M]\n", | |
"\n", | |
" area1 = (box1[:,2]-box1[:,0]+1) * (box1[:,3]-box1[:,1]+1) # [N,]\n", | |
" area2 = (box2[:,2]-box2[:,0]+1) * (box2[:,3]-box2[:,1]+1) # [M,]\n", | |
" iou = inter / (area1[:,None] + area2 - inter)\n", | |
" return iou" | |
] | |
}, | |
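{
"cell_type": "markdown",
"metadata": {},
"source": [
"A tiny illustrative check of the `iou` helper on hand-made boxes (the coordinates are arbitrary example values): identical boxes should give an IoU of 1.0, and a partial overlap something smaller."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative sanity check of iou() on two hand-made boxes in xyxy format.\n",
"box_a = torch.FloatTensor([[0, 0, 10, 10]])\n",
"box_b = torch.FloatTensor([[0, 0, 10, 10], [5, 5, 15, 15]])\n",
"iou(box_a, box_b)  # first entry should be 1.0, second a partial overlap"
]
},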
{ | |
"cell_type": "markdown", | |
"id": "f313d08d-9c3b-400b-9b8b-28cbd5ccdec9", | |
"metadata": {}, | |
"source": [ | |
"### Anchor Boxes\n", | |
"\n", | |
"\"*Anchor boxes have areas of $32^2$ to $512^2$ on pyramid levels $P_3$ to $P_7$.*\" (Page 4, Focal Loss for Dense Object Detection)\n", | |
"\n", | |
"- Aspect ratios: $\\{1:2, 1:1, 2:1\\}$, translates to `[0.5, 1, 2]` in python\n", | |
"- Scales: $\\{2^0, 2^{1/3}, 2^{2/3}\\}$\n", | |
"\n", | |
"There should be a total of $A=9$ anchors per level" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 21, | |
"id": "359ef414-00f1-487d-9269-6d2316b10e80", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"class AnchorBox():\n", | |
" \"\"\"\n", | |
" Generate anchor boxes for level 3 to level 8\n", | |
" \"\"\"\n", | |
" def __init__(self):\n", | |
" self.ratios = [0.5, 1, 2]\n", | |
" self.scales = [1, 2**(1/3), 2**(2/3)]\n", | |
" \n", | |
" self.A = len(self.ratios) * len(self.scales) # number of anchors (from paper)\n", | |
" self.areas = [x**2 for x in [32, 64, 128, 256, 512]] # P3, P4, P5, P6, P7\n", | |
" self.strides = [2 ** i for i in range(3, 8)] # Each layer's feature map is 2^l smaller than the input\n", | |
" self.anchor_dims = self._anchor_dims()\n", | |
" ## for feature map sizes\n", | |
" \n", | |
" def _meshgrid(self, x, y, row_major=True):\n", | |
" a = torch.arange(0,x)\n", | |
" b = torch.arange(0,y)\n", | |
" xx = a.repeat(y).view(-1,1)\n", | |
" yy = b.view(-1,1).repeat(1,x).view(-1,1)\n", | |
" return torch.cat([xx,yy],1) if row_major else torch.cat([yy,xx],1)\n", | |
" \n", | |
" def _anchor_dims(self):\n", | |
" anchor_dims = []\n", | |
" for area in self.areas:\n", | |
" for ratio in self.ratios:\n", | |
" anchor_height = math.sqrt(area / ratio)\n", | |
" anchor_width = area / anchor_height\n", | |
" \n", | |
" for scale in self.scales:\n", | |
" anchor_width = anchor_width * scale\n", | |
" anchor_height = anchor_height * scale\n", | |
" anchor_dims.append([anchor_width, anchor_height])\n", | |
" return torch.FloatTensor(anchor_dims).view(len(self.areas), -1, 2)\n", | |
" \n", | |
" def generate_anchor_boxes(self, input_size):\n", | |
" \"\"\"\n", | |
" Generates Anchor Boxes\n", | |
" \n", | |
" input_size: torch.Tensor: (w, h)\n", | |
" \"\"\"\n", | |
" \n", | |
" num_feature_maps = len(self.areas)\n", | |
" feature_map_sizes = [(input_size / stride).ceil() for stride in self.strides] # calculating feature map sizes of p3 to p7\n", | |
" boxes = []\n", | |
" for i in range(num_feature_maps):\n", | |
" fm_size = feature_map_sizes[i]\n", | |
" grid_size = input_size / fm_size\n", | |
" fm_w, fm_h = int(fm_size[0]), int(fm_size[1])\n", | |
" xy = self._meshgrid(fm_w,fm_h) + 0.5 # [fm_h*fm_w, 2]\n", | |
" xy = (xy*grid_size).view(fm_h,fm_w,1,2).expand(fm_h,fm_w,9,2)\n", | |
" wh = self.anchor_dims[i].view(1,1,9,2).expand(fm_h,fm_w,9,2)\n", | |
" box = torch.cat([xy,wh], 3) # [x,y,w,h]\n", | |
" boxes.append(box.view(-1,4))\n", | |
" return torch.cat(boxes, 0)" | |
] | |
}, | |
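{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick sanity check of the anchor generator (illustrative sketch; the 300 x 300 size is only an example, matching the `Resize(size=300)` transform used later). Each level $l$ contributes 9 anchors at each of its `ceil(W / 2^l) * ceil(H / 2^l)` grid positions, so the total can be computed directly from the strides."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Compare the generated anchor count against the closed-form expectation.\n",
"anchor_box = AnchorBox()\n",
"anchors = anchor_box.generate_anchor_boxes(torch.Tensor([300, 300]))\n",
"expected = sum(anchor_box.A * int(np.ceil(300 / s)) ** 2 for s in anchor_box.strides)\n",
"anchors.shape, expected  # the first dimension of anchors should equal expected"
]
},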
{ | |
"cell_type": "code", | |
"execution_count": 22, | |
"id": "2d71e06b-3ce1-41f1-afae-034a62bc98b6", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"class Encoder:\n", | |
" def __init__(self):\n", | |
" self.anchor_box = AnchorBox()\n", | |
" def encode(self, boxes, labels, input_size):\n", | |
" input_size = torch.Tensor([input_size,input_size]) if isinstance(input_size, int) \\\n", | |
" else torch.Tensor(input_size)\n", | |
" anchor_boxes = self.anchor_box.generate_anchor_boxes(input_size)\n", | |
"# boxes = xyxy_2_xywh(boxes)\n", | |
" boxes = torch.FloatTensor(boxes)\n", | |
" \n", | |
" ious = iou(anchor_boxes, boxes, order=\"xywh\")\n", | |
" max_ious, max_ids = ious.max(1)\n", | |
" boxes = boxes[max_ids]\n", | |
" \n", | |
" loc_xy = (boxes[:,:2]-anchor_boxes[:,:2]) / anchor_boxes[:,2:]\n", | |
" loc_wh = torch.log(boxes[:,2:]/anchor_boxes[:,2:])\n", | |
" loc_targets = torch.cat([loc_xy,loc_wh], 1)\n", | |
" cls_targets = 1 + labels[max_ids]\n", | |
"\n", | |
" cls_targets[max_ious<0.5] = 0\n", | |
" ignore = (max_ious>0.4) & (max_ious<0.5) # ignore ious between [0.4,0.5]\n", | |
" cls_targets[ignore] = -1 # for now just mark ignored to -1\n", | |
" return loc_targets, cls_targets\n", | |
" \n", | |
" def decode(self, loc_preds, cls_preds, input_size):\n", | |
" input_size = torch.Tensor([input_size,input_size]) if isinstance(input_size, int) else torch.Tensor(input_size)\n", | |
" \n", | |
" CLS_THRESH = 0.5\n", | |
" NMS_THRESH = 0.5\n", | |
" anchor_boxes = self.anchor_box.generate_anchor_boxes(input_size)\n", | |
"\n", | |
" loc_xy = loc_preds[:,:2]\n", | |
" loc_wh = loc_preds[:,2:]\n", | |
"\n", | |
" xy = loc_xy * anchor_boxes[:,2:] + anchor_boxes[:,:2]\n", | |
" wh = loc_wh.exp() * anchor_boxes[:,2:]\n", | |
" boxes = torch.cat([xy-wh/2, xy+wh/2], 1)\n", | |
"\n", | |
" score, labels = cls_preds.sigmoid().max(1)\n", | |
" ids = score > CLS_THRESH\n", | |
" ids = ids.nonzero().squeeze()\n", | |
" keep = box_nms(boxes[ids], score[ids], threshold=NMS_THRESH)\n", | |
" return boxes[ids][keep], labels[ids][keep]" | |
] | |
}, | |
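{
"cell_type": "markdown",
"metadata": {},
"source": [
"For reference, the parameterization that `encode` applies to each ground-truth box $b$ and its matched anchor $a$ (both in center-size form), and that `decode` inverts, is\n",
"$$\n",
"t_{xy} = \frac{b_{xy} - a_{xy}}{a_{wh}}, \qquad t_{wh} = \log\frac{b_{wh}}{a_{wh}},\n",
"$$\n",
"$$\n",
"\hat{b}_{xy} = t_{xy} a_{wh} + a_{xy}, \qquad \hat{b}_{wh} = e^{t_{wh}} a_{wh}.\n",
"$$\n",
"Anchors whose best IoU is at least 0.5 keep the matched class (offset by 1 so that 0 can mean background), anchors below 0.5 are set to background, and those with IoU between 0.4 and 0.5 are marked $-1$ so the loss can ignore them."
]
},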
{ | |
"cell_type": "code", | |
"execution_count": 53, | |
"id": "91eb2fc1-4fe1-4558-8742-619b89286361", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"class AquariumDetection(datasets.VisionDataset):\n", | |
" def __init__(\n", | |
" self,\n", | |
" root: str,\n", | |
" split = \"train\",\n", | |
" transform= None,\n", | |
" target_transform = None,\n", | |
" transforms = None,\n", | |
" ) -> None:\n", | |
" super().__init__(root, transforms, transform, target_transform)\n", | |
" self.split = split\n", | |
" self.coco = COCO(os.path.join(root, split, \"_annotations.coco.json\"))\n", | |
" self.ids = list(sorted(self.coco.imgs.keys()))\n", | |
" self.ids = [id for id in self.ids if (len(self._load_target(id)) > 0)]\n", | |
"\n", | |
" def _load_image(self, id: int) -> Image.Image:\n", | |
" path = self.coco.loadImgs(id)[0][\"file_name\"]\n", | |
" img = Image.open(os.path.join(self.root, self.split, path)).convert(\"RGB\")\n", | |
" return img\n", | |
"\n", | |
" def _load_target(self, id: int):\n", | |
" return self.coco.loadAnns(self.coco.getAnnIds(id))\n", | |
"\n", | |
" def __getitem__(self, index: int):\n", | |
" id = self.ids[index]\n", | |
" image = self._load_image(id)\n", | |
" target = copy.deepcopy(self._load_target(id))\n", | |
"\n", | |
" if self.transforms is not None:\n", | |
" image, target = self.transforms(image, target)\n", | |
" \n", | |
" annot = [t[\"bbox\"] + [t[\"category_id\"]] for t in target]\n", | |
"\n", | |
" return image, annot\n", | |
"\n", | |
"\n", | |
" def __len__(self) -> int:\n", | |
" return len(self.ids)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 54, | |
"id": "3701d805-c03a-410c-a51a-0f7305d1cf35", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def collate_fn(batch):\n", | |
" \"\"\"\n", | |
" The images in the dataset will be of different sizes. This function takes the images and pads them. Then we encode the images.\n", | |
" \"\"\"\n", | |
" imgs = [x[0] for x in batch]\n", | |
" annots = np.array([x[1] for x in batch], dtype=object)\n", | |
"\n", | |
" widths = [int(s.shape[1]) for s in imgs]\n", | |
" heights = [int(s.shape[2]) for s in imgs]\n", | |
" batch_size = len(imgs)\n", | |
"\n", | |
" max_width = np.array(widths).max()\n", | |
" max_height = np.array(heights).max()\n", | |
"\n", | |
" padded_imgs = torch.zeros(batch_size, max_width, max_height, 3)\n", | |
"\n", | |
" for i in range(batch_size):\n", | |
" img = imgs[i]\n", | |
" padded_imgs[i, :int(img.shape[1]), :int(img.shape[2]), :] = img.permute(1, 2, 0)\n", | |
" padded_imgs = padded_imgs.permute(0, 3, 1, 2)\n", | |
" \n", | |
" ## Encode ##\n", | |
" encoder = Encoder()\n", | |
" loc_targets = []\n", | |
" cls_targets = []\n", | |
" for i in range(len(imgs)):\n", | |
" annot = annots[i]\n", | |
" boxes = np.array(annot)[:, 0:4]\n", | |
" labels = np.array(annot)[:, 4]\n", | |
" image = padded_imgs[i]\n", | |
" loc_target, cls_target = encoder.encode(boxes, labels, (image.shape[1], image.shape[2]))\n", | |
" loc_targets.append(torch.FloatTensor(loc_target))\n", | |
" cls_targets.append(torch.FloatTensor(cls_target))\n", | |
" return {'img': padded_imgs, 'loc_targets': torch.stack(loc_targets), 'cls_targets': torch.stack(cls_targets)}" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 55, | |
"id": "351a5ec0-56e9-4c56-a1b0-7adb6543e511", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def get_transform(train):\n", | |
" transforms = []\n", | |
" transforms.append(Resize(size=300))\n", | |
" transforms.append(ToTensor())\n", | |
" transforms.append(Normalizer())\n", | |
" return Compose(transforms)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 56, | |
"id": "c693c3ab-0176-4c06-b2a7-309d24d8bb31", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"loading annotations into memory...\n", | |
"Done (t=0.02s)\n", | |
"creating index...\n", | |
"index created!\n", | |
"loading annotations into memory...\n", | |
"Done (t=0.00s)\n", | |
"creating index...\n", | |
"index created!\n", | |
"loading annotations into memory...\n", | |
"Done (t=0.00s)\n", | |
"creating index...\n", | |
"index created!\n" | |
] | |
} | |
], | |
"source": [ | |
"train_dataset = AquariumDetection(root=dataset_path, transforms=get_transform(True))\n", | |
"val_dataset = AquariumDetection(root=dataset_path, split=\"valid\", transforms=get_transform(False))\n", | |
"test_dataset = AquariumDetection(root=dataset_path, split=\"test\", transforms=get_transform(False))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 57, | |
"id": "6a890ca3-b3a6-4f6a-a4f3-6fdf3a518cb4", | |
"metadata": { | |
"tags": [] | |
}, | |
"outputs": [], | |
"source": [ | |
"train_loader = DataLoader(train_dataset, batch_size=8, collate_fn=collate_fn)\n", | |
"val_loader = DataLoader(val_dataset, batch_size=8, collate_fn=collate_fn)\n", | |
"test_loader = DataLoader(test_dataset, batch_size=8, collate_fn=collate_fn)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 60, | |
"id": "a5e86435-ff29-4cbe-bc47-078dc15c8de7", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"(56, 16, 8)" | |
] | |
}, | |
"execution_count": 60, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"len(train_loader), len(val_loader), len(test_loader)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 59, | |
"id": "d9332820-f53f-4a3a-b6cc-aa8c0df89596", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"for i in range(len(train_dataset)):\n", | |
" _ = train_dataset[i]" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "d4cf6818-a90d-435c-b88c-8c6b07f03248", | |
"metadata": { | |
"tags": [] | |
}, | |
"source": [ | |
"## Retinanet Implementation" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 167, | |
"id": "ad44cca1-4e59-4058-9ddf-e9743650fb58", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):\n", | |
" \"\"\"3x3 convolution with padding\"\"\"\n", | |
" return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,\n", | |
" padding=dilation, groups=groups, bias=False, dilation=dilation)\n", | |
"\n", | |
"\n", | |
"def conv1x1(in_planes, out_planes, stride=1):\n", | |
" \"\"\"1x1 convolution\"\"\"\n", | |
" return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 168, | |
"id": "cb82dba8-cc96-49f0-aba1-fbf83775b8e8", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"class Bottleneck(nn.Module):\n", | |
" expansion = 4\n", | |
" def __init__(self, inplanes, planes, stride=1, groups=1,\n", | |
" base_width=64, dilation=1):\n", | |
" super(Bottleneck, self).__init__()\n", | |
" norm_layer = nn.BatchNorm2d\n", | |
" width = int(planes * (base_width / 64.)) * groups\n", | |
" self.conv1 = conv1x1(inplanes, width)\n", | |
" self.bn1 = norm_layer(width)\n", | |
" self.conv2 = conv3x3(width, width, stride, groups, dilation)\n", | |
" self.bn2 = norm_layer(width)\n", | |
" self.conv3 = conv1x1(width, planes * self.expansion)\n", | |
" self.bn3 = norm_layer(planes * self.expansion)\n", | |
" self.relu = nn.ReLU(inplace=True)\n", | |
" \n", | |
" self.downsample = nn.Sequential()\n", | |
" if stride != 1 or inplanes != self.expansion*planes:\n", | |
" self.downsample = nn.Sequential(\n", | |
" nn.Conv2d(inplanes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),\n", | |
" nn.BatchNorm2d(self.expansion*planes)\n", | |
" )\n", | |
" \n", | |
" self.stride = stride\n", | |
"\n", | |
" def forward(self, x):\n", | |
" identity = x\n", | |
"\n", | |
" out = self.conv1(x)\n", | |
" out = self.bn1(out)\n", | |
" out = self.relu(out)\n", | |
"\n", | |
" out = self.conv2(out)\n", | |
" out = self.bn2(out)\n", | |
" out = self.relu(out)\n", | |
"\n", | |
" out = self.conv3(out)\n", | |
" out = self.bn3(out)\n", | |
"\n", | |
" identity = self.downsample(x)\n", | |
"\n", | |
" out += identity\n", | |
" out = self.relu(out)\n", | |
"\n", | |
" return out\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 169, | |
"id": "5ca298d4-9375-41c1-b66f-8d60640b5581", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"class FPN(nn.Module):\n", | |
" def __init__(self, block, num_blocks):\n", | |
" super(FPN, self).__init__()\n", | |
" self.in_planes = 64\n", | |
" \n", | |
" self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)\n", | |
" self.bn1 = nn.BatchNorm2d(64)\n", | |
" \n", | |
" self.conv2 = self._make_layer(block, 64, num_blocks=num_blocks[0], stride=1)\n", | |
" self.conv3 = self._make_layer(block, 128, num_blocks=num_blocks[1], stride=2)\n", | |
" self.conv4 = self._make_layer(block, 256, num_blocks=num_blocks[2], stride=2)\n", | |
" self.conv5 = self._make_layer(block, 512, num_blocks=num_blocks[3], stride=2)\n", | |
" \n", | |
" self.conv6 = nn.Conv2d(2048, 256, kernel_size=3, stride=2, padding=1)\n", | |
" self.conv7 = nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1)\n", | |
" \n", | |
" ## lateral layers ##\n", | |
" self.lat1 = nn.Conv2d(2048, 256, kernel_size=1, stride=1, padding=0)\n", | |
" self.lat2 = nn.Conv2d(1024, 256, kernel_size=1, stride=1, padding=0)\n", | |
" self.lat3 = nn.Conv2d(512, 256, kernel_size=1, stride=1, padding=0)\n", | |
" \n", | |
" ## top-down layers ##\n", | |
" self.topdown1 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)\n", | |
" self.topdown2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)\n", | |
" \n", | |
" self.relu = nn.ReLU()\n", | |
" \n", | |
" def _upsample_and_add(self, x, y):\n", | |
" _,_,H,W = y.size()\n", | |
" return F.upsample(x, size=(H,W), mode='bilinear') + y\n", | |
" \n", | |
" def _make_layer(self, block, planes, num_blocks, stride):\n", | |
" strides = [stride] + [1]*(num_blocks-1)\n", | |
" layers = []\n", | |
" for stride in strides:\n", | |
" layers.append(block(self.in_planes, planes, stride))\n", | |
" self.in_planes = planes * block.expansion\n", | |
" return nn.Sequential(*layers)\n", | |
" \n", | |
" def forward(self, x):\n", | |
" #bottom up\n", | |
" c1 = self.relu(self.bn1(self.conv1(x)))\n", | |
" c1 = F.max_pool2d(c1, kernel_size=3, stride=2, padding=1)\n", | |
" c2 = self.conv2(c1)\n", | |
" c3 = self.conv3(c2)\n", | |
" c4 = self.conv4(c3)\n", | |
" c5 = self.conv5(c4)\n", | |
" p6 = self.conv6(c5)\n", | |
" p7 = self.conv7(p6)\n", | |
" p5 = self.lat1(c5)\n", | |
" p4 = self._upsample_and_add(p5, self.lat2(c4))\n", | |
" p4 = self.topdown1(p4)\n", | |
" p3 = self._upsample_and_add(p4, self.lat3(c3))\n", | |
" p3 = self.topdown2(p3)\n", | |
" return p3, p4, p5, p6, p7" | |
] | |
}, | |
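{
"cell_type": "markdown",
"metadata": {},
"source": [
"A small illustrative shape check of the FPN (a sketch only; the reduced `[2, 2, 2, 2]` block counts are just to keep it fast): all five returned maps should have 256 channels, with the spatial size halving from $P_3$ to $P_7$."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Run a dummy 256x256 image through a small FPN and inspect the P3..P7 shapes.\n",
"with torch.no_grad():\n",
"    tiny_fpn = FPN(Bottleneck, [2, 2, 2, 2])\n",
"    fmaps = tiny_fpn(torch.randn(1, 3, 256, 256))\n",
"[tuple(f.shape) for f in fmaps]  # expect 256 channels and sizes 32, 16, 8, 4, 2"
]
},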
{ | |
"cell_type": "code", | |
"execution_count": 170, | |
"id": "5d228fae-537c-4ac6-a3b5-1e838112eb3f", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"class ClassificationHead(nn.Module):\n", | |
" def __init__(self, n_classes=8):\n", | |
" super(ClassificationHead, self).__init__()\n", | |
" self.n_anchors = 9\n", | |
" self.n_classes = n_classes\n", | |
" \n", | |
" self.convnet = nn.Sequential(*[\n", | |
" nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),\n", | |
" nn.ReLU(True),\n", | |
" nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),\n", | |
" nn.ReLU(True),\n", | |
" nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),\n", | |
" nn.ReLU(True),\n", | |
" nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),\n", | |
" nn.ReLU(True),\n", | |
" nn.Conv2d(256, self.n_classes*self.n_anchors, kernel_size=3, stride=1, padding=1) # KA\n", | |
" ])\n", | |
" def forward(self, x):\n", | |
" return self.convnet(x)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 171, | |
"id": "4c57cc78-6615-45a3-93ce-242157536e0f", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"class RegressionHead(nn.Module):\n", | |
" def __init__(self):\n", | |
" super(RegressionHead, self).__init__()\n", | |
" self.n_anchors = 9\n", | |
" self.convnet = nn.Sequential(*[\n", | |
" nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),\n", | |
" nn.ReLU(True),\n", | |
" nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),\n", | |
" nn.ReLU(True),\n", | |
" nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),\n", | |
" nn.ReLU(True),\n", | |
" nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),\n", | |
" nn.ReLU(True),\n", | |
" nn.Conv2d(256, 4*self.n_anchors, kernel_size=3, stride=1, padding=1) # KA\n", | |
" ])\n", | |
" def forward(self, x):\n", | |
" return self.convnet(x)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 172, | |
"id": "86f363cc-7c90-4be0-8522-89cc96d087a2", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"class RetinaNet(nn.Module):\n", | |
" def __init__(self, n_classes=8):\n", | |
" super(RetinaNet, self).__init__()\n", | |
" \n", | |
" self.fpn = FPN(Bottleneck, [3, 4, 6, 3])\n", | |
" \n", | |
" self.num_classes = n_classes\n", | |
" \n", | |
" self.classification_head = ClassificationHead(n_classes = self.num_classes) # class head\n", | |
" self.regression_head = RegressionHead() # loc head\n", | |
" def forward(self, x):\n", | |
" feature_maps = self.fpn(x) #p3, p4, p5, p6, p7\n", | |
" \n", | |
" loc_preds = []\n", | |
" cls_preds = []\n", | |
" \n", | |
" for fmap in feature_maps:\n", | |
" loc_pred = self.regression_head(fmap)\n", | |
" cls_pred = self.classification_head(fmap)\n", | |
" \n", | |
" loc_pred = loc_pred.permute(0,2,3,1).contiguous().view(x.size(0),-1,4) \n", | |
" cls_pred = cls_pred.permute(0,2,3,1).contiguous().view(x.size(0),-1,self.num_classes) \n", | |
" \n", | |
" loc_preds.append(loc_pred)\n", | |
" cls_preds.append(cls_pred)\n", | |
" \n", | |
" return torch.cat(loc_preds, 1), torch.cat(cls_preds, 1)\n", | |
" \n", | |
" def freeze_bn(self):\n", | |
" '''Freeze BatchNorm layers.'''\n", | |
" for layer in self.modules():\n", | |
" if isinstance(layer, nn.BatchNorm2d):\n", | |
" layer.eval()\n", | |
" " | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 173, | |
"id": "767203b7-fbdb-4708-a942-42022b019bc8", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"net = RetinaNet()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 176, | |
"id": "233b4698-e07b-4430-9c88-e3b1f1106e75", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"(torch.Size([8, 30231, 4]), torch.Size([8, 30231, 8]))" | |
] | |
}, | |
"execution_count": 176, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"batch = next(iter(train_loader))\n", | |
"loc_preds, cls_preds = net(batch['img'])\n", | |
"loc_preds.shape, cls_preds.shape # The 2nd number should be the same as the number of anchors per image." | |
] | |
}, | |
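{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick illustrative cross-check (it reuses `batch` and `loc_preds` from the cell above), the second dimension of the predictions can be compared against the number of anchor boxes the `AnchorBox` helper generates for the padded input size."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# The network predicts one box per anchor, so both numbers should agree.\n",
"_, _, H, W = batch['img'].shape\n",
"n_anchors = AnchorBox().generate_anchor_boxes(torch.Tensor([W, H])).size(0)\n",
"loc_preds.size(1), n_anchors"
]
},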
{ | |
"cell_type": "code", | |
"execution_count": 179, | |
"id": "6abb61c1-03e4-4530-b56b-e401965c5917", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"encoder = Encoder()\n", | |
"# _ = encoder.decode(loc_preds[0], cls_preds[0], tuple(batch['img'].shape[2:]))\n", | |
"## Ensure this cell just runs ##" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "414de16f-fee8-4d7e-87c3-2f9484911786", | |
"metadata": {}, | |
"source": [ | |
"## Focal Loss\n", | |
"\n", | |
"An extension of Cross Entropy\n", | |
"$$\n", | |
"FL(p_t) = -\\alpha(1-p_t)^{\\gamma}log(p_t)\n", | |
"$$" | |
] | |
}, | |
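{
"cell_type": "markdown",
"metadata": {},
"source": [
"For orientation, a direct, unoptimized reading of the formula might look like the sketch below (the `focal_loss_reference` name and its arguments are illustrative only). The `FocalLoss` module further down implements a smoothed alternative of the same idea on top of a one-hot helper."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def focal_loss_reference(logits, targets_onehot, alpha=0.25, gamma=2.0):\n",
"    \"\"\"Illustrative focal loss on sigmoid logits; both inputs are [N, num_classes].\"\"\"\n",
"    p = logits.sigmoid()\n",
"    # p_t: predicted probability assigned to the true label of each entry\n",
"    pt = p * targets_onehot + (1 - p) * (1 - targets_onehot)\n",
"    # alpha_t: balances positive vs. negative entries\n",
"    alpha_t = alpha * targets_onehot + (1 - alpha) * (1 - targets_onehot)\n",
"    ce = F.binary_cross_entropy_with_logits(logits, targets_onehot, reduction='none')\n",
"    return (alpha_t * (1 - pt) ** gamma * ce).sum()"
]
},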
{ | |
"cell_type": "code", | |
"execution_count": 180, | |
"id": "ab44b1d4-fd07-4334-b84e-adb1bd4b9ee0", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def one_hot_embedding(labels, num_classes):\n", | |
" y = torch.eye(num_classes)\n", | |
" return y[labels]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 222, | |
"id": "fc16aa9b-f88e-461b-bdb6-3a39202a4330", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"class FocalLoss(nn.Module):\n", | |
" def __init__(self, n_classes = 8):\n", | |
" super().__init__()\n", | |
" self.n_classes = n_classes\n", | |
" \n", | |
" def focal_loss(self, x, y):\n", | |
" alpha = -0.25\n", | |
" gamma = 2 # Paper recommended values\n", | |
" \n", | |
" t = one_hot_embedding(y.cpu(), 1 + self.n_classes)\n", | |
" t = t[:,1:]\n", | |
" if torch.cuda.is_available():\n", | |
" t = t.cuda()\n", | |
" \n", | |
" xt = x*(2*t-1)\n", | |
" pt = (2*xt+1).sigmoid()\n", | |
" \n", | |
" w = alpha*t + (1-alpha)*(1-t)\n", | |
" loss = -w*pt.log() / 2\n", | |
" return loss.sum()\n", | |
" \n", | |
" def forward(self, loc_preds, loc_true, cls_preds, cls_true):\n", | |
" batch_size, num_boxes = cls_true.size()\n", | |
" pos = cls_true > 0\n", | |
" num_pos = pos.long().sum()\n", | |
" \n", | |
" ## Loc loss\n", | |
" mask = pos.unsqueeze(2).expand_as(loc_preds) # [N,#anchors,4]\n", | |
" masked_loc_preds = loc_preds[mask].view(-1,4) # [#pos,4]\n", | |
" masked_loc_true = loc_true[mask].view(-1,4) # [#pos,4]\n", | |
" loc_loss = F.smooth_l1_loss(masked_loc_preds, masked_loc_true, size_average=False)\n", | |
" ## cls loss\n", | |
" pos_neg = cls_true > -1 # exclude ignored anchors\n", | |
" mask = pos_neg.unsqueeze(2).expand_as(cls_preds)\n", | |
" masked_cls_preds = cls_preds[mask].view(-1,self.n_classes)\n", | |
" cls_loss = self.focal_loss(masked_cls_preds, cls_true[pos_neg])\n", | |
" \n", | |
" loss = (loc_loss + cls_loss) / num_pos\n", | |
" return loss" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "90160ab7-8e15-4306-9280-798d43165794", | |
"metadata": {}, | |
"source": [ | |
"## Initialization" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 223, | |
"id": "2e63b8db-31f2-48c0-9f18-8594e3f82d6e", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"criterion = FocalLoss()\n", | |
"optimizer = optim.SGD(net.parameters(), lr=1e-3, momentum=0.9, weight_decay=1e-4)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 224, | |
"id": "d6b6b57c-68ac-40d9-91f3-a15e7e9f1af7", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"device(type='cpu')" | |
] | |
}, | |
"execution_count": 224, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", | |
"device" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 225, | |
"id": "c8b82edb-389f-471d-a05a-c2ac3890acb8", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"net = net.to(device)\n", | |
"criterion = criterion.to(device)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "b801f336-8c55-4217-8e67-221970c55f00", | |
"metadata": {}, | |
"source": [ | |
"## Training" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 234, | |
"id": "9497a504-1e21-4c3f-8339-7ad1afc086f8", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def train(epoch):\n", | |
" net.train()\n", | |
" net.freeze_bn()\n", | |
" train_loss = 0\n", | |
" for batch in tqdm(train_loader):\n", | |
" imgs = batch['img'].to(device)\n", | |
" loc_targets = batch['loc_targets'].to(device)\n", | |
" cls_targets = batch['cls_targets'].to(device)\n", | |
" cls_targets = cls_targets.long()\n", | |
" \n", | |
" optimizer.zero_grad()\n", | |
" loc_pred, cls_pred = net(imgs)\n", | |
" loss = criterion(loc_pred, loc_targets, cls_pred, cls_targets)\n", | |
" \n", | |
" loss.backward()\n", | |
" optimizer.step()\n", | |
" \n", | |
" train_loss += loss.item()\n", | |
" print('train_loss: %.3f | avg_loss: %.3f' % (loss.data[0], train_loss/(len(train_loader))))\n", | |
" " | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 235, | |
"id": "c5bc39c8-431e-4fe1-ae38-03d484e7f952", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def test(epoch, loader):\n", | |
" with torch.no_grad():\n", | |
" net.eval()\n", | |
" test_loss = 0\n", | |
" for batch in tqdm(loader):\n", | |
" imgs = batch['img'].to(device)\n", | |
" loc_targets = batch['loc_targets'].to(device)\n", | |
" cls_targets = batch['cls_targets'].to(device)\n", | |
"\n", | |
" loc_pred, cls_pred = net(imgs)\n", | |
" loss = criterion(loc_pred, loc_targets, cls_pred, cls_targets)\n", | |
" test_loss += loss[0]\n", | |
" print('test_loss: %.3f | avg_loss: %.3f' % (loss.data[0], train_loss/(len(train_loader))))\n", | |
" " | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 236, | |
"id": "503f99ad-fed4-48af-ae08-7a430b9c8ae5", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"EPOCHS = 50" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 237, | |
"id": "c7534024-d9a2-430a-8ddc-a6db0ad0bd57", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"EPOCH 1\n" | |
] | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "ec194265c36b440383e074ffb997acce", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
" 0%| | 0/56 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"[E thread_pool.cpp:112] Exception in thread pool task: mutex lock failed: Invalid argument\n", | |
"[E thread_pool.cpp:112] Exception in thread pool task: mutex lock failed: Invalid argument\n", | |
"[E thread_pool.cpp:112] Exception in thread pool task: mutex lock failed: Invalid argument\n", | |
"[E thread_pool.cpp:112] Exception in thread pool task: mutex lock failed: Invalid argument\n", | |
"[E thread_pool.cpp:112] Exception in thread pool task: mutex lock failed: Invalid argument\n", | |
"[E thread_pool.cpp:112] Exception in thread pool task: mutex lock failed: Invalid argument\n" | |
] | |
}, | |
{ | |
"ename": "KeyboardInterrupt", | |
"evalue": "", | |
"output_type": "error", | |
"traceback": [ | |
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", | |
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", | |
"\u001b[0;32m/var/folders/vr/x7p4fznn1dv39r_83dmyjkjm0000gn/T/ipykernel_42021/107800544.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf\"EPOCH {epoch}\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;31m##\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m \u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mepoch\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 5\u001b[0m \u001b[0mtest\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mepoch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mval_loader\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0;31m##\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m/var/folders/vr/x7p4fznn1dv39r_83dmyjkjm0000gn/T/ipykernel_42021/3530921501.py\u001b[0m in \u001b[0;36mtrain\u001b[0;34m(epoch)\u001b[0m\n\u001b[1;32m 13\u001b[0m \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcriterion\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mloc_pred\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mloc_targets\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcls_pred\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcls_targets\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 14\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 15\u001b[0;31m \u001b[0mloss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 16\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 17\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m/opt/homebrew/Caskroom/miniforge/base/lib/python3.9/site-packages/torch/_tensor.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(self, gradient, retain_graph, create_graph, inputs)\u001b[0m\n\u001b[1;32m 253\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 254\u001b[0m inputs=inputs)\n\u001b[0;32m--> 255\u001b[0;31m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mautograd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgradient\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 256\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 257\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mregister_hook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m/opt/homebrew/Caskroom/miniforge/base/lib/python3.9/site-packages/torch/autograd/__init__.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001b[0m\n\u001b[1;32m 145\u001b[0m \u001b[0mretain_graph\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 146\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 147\u001b[0;31m Variable._execution_engine.run_backward(\n\u001b[0m\u001b[1;32m 148\u001b[0m \u001b[0mtensors\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgrad_tensors_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 149\u001b[0m allow_unreachable=True, accumulate_grad=True) # allow_unreachable flag\n", | |
"\u001b[0;31mKeyboardInterrupt\u001b[0m: " | |
] | |
} | |
], | |
"source": [ | |
"for epoch in range(1, EPOCHS + 1):\n", | |
" print(f\"EPOCH {epoch}\")\n", | |
" ##\n", | |
" train(epoch)\n", | |
" test(epoch, val_loader)\n", | |
" ##" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "79d32e91-d705-4711-8a71-79d94cd5f620", | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "f7206910-c8cc-4909-84c3-76a7f3ee023e", | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "e33c98cd-3787-45d6-86eb-cf7a65aaf882", | |
"metadata": {}, | |
"source": [ | |
"## Testing" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "e2d3a81a-43d3-4da1-8f7c-db98adeff2d5", | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "8efa5f7e-a5c9-4942-bde2-7e0897f2e036", | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "69670f59-7b7a-46a2-abf7-80180d8f994e", | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "afd407d5-6d0a-46bc-b769-7b7e604accfb", | |
"metadata": {}, | |
"source": [ | |
"## Saving Model Weights" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "5a8ef18e-aa6b-4cdb-8ef5-c91671b76e9e", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"torch.save(model.state_dict(), \"retinanet.pt\")" | |
] | |
}, | |
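{
"cell_type": "markdown",
"metadata": {},
"source": [
"To restore the checkpoint later, the usual pattern is to rebuild the model and load the saved state dict (sketch below; the filename matches the cell above)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch: reload the saved weights into a freshly constructed model.\n",
"restored = RetinaNet(n_classes=8)\n",
"restored.load_state_dict(torch.load(\"retinanet.pt\", map_location=device))\n",
"restored = restored.to(device).eval()"
]
},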
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "90ff717d-7d5c-413d-9ae8-c57b935f26f9", | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3 (ipykernel)", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.9.4" | |
}, | |
"toc-autonumbering": true | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |