ResNet.ipynb

import torch
import torch.nn as nn


class Block(nn.Module):
    def __init__(self, num_layers, in_channels, out_channels, identity_downsample=None, stride=1):
        assert num_layers in [18, 34, 50, 101, 152], "should be a valid architecture"
        super(Block, self).__init__()
        self.num_layers = num_layers
        if self.num_layers > 34:
            self.expansion = 4
        else:
            self.expansion = 1
        # ResNet50, 101, and 152 include an additional layer of 1x1 kernels
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        self.bn1 = nn.BatchNorm2d(out_channels)
        if self.num_layers > 34:
            self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        else:
            # for ResNet18 and 34, connect the input directly to the 3x3 kernel (skip the first 1x1)
            self.conv2 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.conv3 = nn.Conv2d(out_channels, out_channels * self.expansion, kernel_size=1, stride=1, padding=0)
        self.bn3 = nn.BatchNorm2d(out_channels * self.expansion)
        self.relu = nn.ReLU()
        self.identity_downsample = identity_downsample

    def forward(self, x):
        identity = x
        # the 1x1 reduction (conv1/bn1) is only used by the bottleneck variants (ResNet50/101/152)
        if self.num_layers > 34:
            x = self.conv1(x)
            x = self.bn1(x)
            x = self.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)
        x = self.conv3(x)
        x = self.bn3(x)

        if self.identity_downsample is not None:
            identity = self.identity_downsample(identity)

        x += identity
        x = self.relu(x)
        return x


class ResNet(nn.Module):
    def __init__(self, num_layers, block, image_channels, num_classes):
        assert num_layers in [18, 34, 50, 101, 152], f'ResNet{num_layers}: Unknown architecture! Number of layers has ' \
                                                     f'to be 18, 34, 50, 101, or 152'
        super(ResNet, self).__init__()
        if num_layers < 50:
            self.expansion = 1
        else:
            self.expansion = 4
        if num_layers == 18:
            layers = [2, 2, 2, 2]
        elif num_layers == 34 or num_layers == 50:
            layers = [3, 4, 6, 3]
        elif num_layers == 101:
            layers = [3, 4, 23, 3]
        else:
            layers = [3, 8, 36, 3]
        self.in_channels = 64
        self.conv1 = nn.Conv2d(image_channels, 64, kernel_size=7, stride=2, padding=3)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # ResNet layers (stages)
        self.layer1 = self.make_layers(num_layers, block, layers[0], intermediate_channels=64, stride=1)
        self.layer2 = self.make_layers(num_layers, block, layers[1], intermediate_channels=128, stride=2)
        self.layer3 = self.make_layers(num_layers, block, layers[2], intermediate_channels=256, stride=2)
        self.layer4 = self.make_layers(num_layers, block, layers[3], intermediate_channels=512, stride=2)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * self.expansion, num_classes)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.reshape(x.shape[0], -1)
        x = self.fc(x)
        return x

    def make_layers(self, num_layers, block, num_residual_blocks, intermediate_channels, stride):
        layers = []

        # 1x1 projection so the identity matches the output shape of the first block in the stage
        identity_downsample = nn.Sequential(
            nn.Conv2d(self.in_channels, intermediate_channels * self.expansion, kernel_size=1, stride=stride),
            nn.BatchNorm2d(intermediate_channels * self.expansion))
        layers.append(block(num_layers, self.in_channels, intermediate_channels, identity_downsample, stride))
        self.in_channels = intermediate_channels * self.expansion  # e.g. 64 -> 256 for the bottleneck variants
        for i in range(num_residual_blocks - 1):
            layers.append(block(num_layers, self.in_channels, intermediate_channels))  # e.g. 256 -> 64, then 64 * 4 (256) again
        return nn.Sequential(*layers)


def ResNet18(img_channels=3, num_classes=1000):
    return ResNet(18, Block, img_channels, num_classes)


def ResNet34(img_channels=3, num_classes=1000):
    return ResNet(34, Block, img_channels, num_classes)


def ResNet50(img_channels=3, num_classes=1000):
    return ResNet(50, Block, img_channels, num_classes)


def ResNet101(img_channels=3, num_classes=1000):
    return ResNet(101, Block, img_channels, num_classes)


def ResNet152(img_channels=3, num_classes=1000):
    return ResNet(152, Block, img_channels, num_classes)


def test():
    net = ResNet18(img_channels=3, num_classes=1000)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    net = net.to(device)
    y = net(torch.randn(4, 3, 224, 224).to(device))
    print(y.size())


test()

Output:
torch.Size([4, 1000])
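A quick sanity check (a minimal sketch, assuming the cell above has already been run in the same session): every variant should map a batch of 224x224 RGB images to a [batch, num_classes] tensor.

import torch

with torch.no_grad():
    for builder in (ResNet18, ResNet34, ResNet50, ResNet101, ResNet152):
        model = builder(img_channels=3, num_classes=1000)
        out = model(torch.randn(2, 3, 224, 224))
        assert out.shape == (2, 1000), out.shape
        print(builder.__name__, tuple(out.shape))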
Hi, thanks for the code. In the Block class, for conv3, won't the filter size need to be 3 for ResNet18 and ResNet34?
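For reference, the basic block used by ResNet18/34 in the original paper stacks two 3x3 convolutions and has no 1x1 convolutions at all (expansion factor 1), which is what the question above is contrasting with the gist's conv3. A minimal sketch of that standard formulation; the name BasicBlock and its argument names are illustrative and not part of the gist:

import torch.nn as nn


class BasicBlock(nn.Module):
    # Reference basic block for ResNet18/34: two 3x3 convolutions, no 1x1 bottleneck convs.
    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x):
        identity = x
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        if self.downsample is not None:
            identity = self.downsample(identity)
        return self.relu(out + identity)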