Created
December 28, 2018 18:47
-
-
Save bearpelican/065ae1b403c3ca9b2a873835f6c2b69d to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "cells": [ | |
| { | |
| "cell_type": "code", | |
| "execution_count": 1, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "import torch\n", | |
| "import torch.nn as nn\n", | |
| "import torch.functional as F\n", | |
| "torch.backends.cudnn.benchmark = True\n", | |
| "from functools import partial\n", | |
| "import functools\n", | |
| "import torchvision" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 2, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "tensor(1., device='cuda:0')" | |
| ] | |
| }, | |
| "execution_count": 2, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
# Run one trivial CUDA op first so lazy CUDA initialization is not charged to the first benchmark below.
torch.ones(()).cuda()
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
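| "### Results: FP16 (`half=True`) vs FP32 wall time\n", | |
| "\n", | |
| "- resnet50: ~2.1x faster (15.2 s -> 7.1 s, 100 iterations)\n", | |
| "- unet (upsample + conv): ~1.6x slower (15.4 s -> 24.7 s, 50 iterations)\n", | |
| "- unet (conv-transpose): ~1.4x faster (6.8 s -> 4.8 s, 50 iterations)\n", | |
| "- unet (pixel shuffle): ~1.2x faster (21.1 s -> 17.9 s, 50 iterations)\n", | |
| "\n", | |
| "Each number is a single `%time` run with no warm-up iteration, so it includes `cudnn.benchmark` autotuning." | |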
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "## Unet Arch" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 3, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
class UnetGenerator(nn.Module):
    """U-Net generator assembled from nested UnetSkipConnectionBlock levels.

    Args:
        input_nc: number of channels of the input image.
        output_nc: number of channels produced by the outermost block.
        num_downs: total number of stride-2 downsamplings. Five levels are always
            built (innermost, three width-changing levels, outermost); every extra
            level adds an ngf*8 -> ngf*8 block. Values below 5 are rejected instead
            of silently building a 5-level network. Input height/width must be
            divisible by 2**num_downs so the skip concatenations line up.
        ngf: filters in the outermost down conv; deeper levels use 2x/4x/8x this.
        up_layer: upsampling scheme used at every level: 'upsample'
            (nearest x2 + 3x3 conv), 'shuffle' (1x1 conv + PixelShuffle + 3x3 conv)
            or 'transpose' (4x4 stride-2 ConvTranspose2d).
        norm_layer: normalization constructor forwarded to every block
            (defaults to nn.BatchNorm2d, as before).

    Raises:
        ValueError: if num_downs < 5 or up_layer is unknown.
    """

    def __init__(self, input_nc, output_nc, num_downs, ngf=64, up_layer='upsample',
                 norm_layer=nn.BatchNorm2d):
        super().__init__()
        if num_downs < 5:
            raise ValueError(f'num_downs must be >= 5, got {num_downs}')
        common = dict(up_layer=up_layer, norm_layer=norm_layer)

        # Build from the bottleneck outwards: each new block wraps the previous one.
        unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None,
                                             innermost=True, **common)
        for _ in range(num_downs - 5):
            unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None,
                                                 submodule=unet_block, **common)
        # Halve the feature width at each remaining level on the way out.
        unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, **common)
        unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, **common)
        unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, **common)
        unet_block = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block,
                                             outermost=True, **common)

        self.model = unet_block

    def forward(self, input):
        """Map an (N, input_nc, H, W) batch to (N, output_nc, H, W) squashed by Tanh into [-1, 1]."""
        return self.model(input)


# Defines the submodule with skip connection.
# X -------------------identity---------------------- X
#   |-- downsampling -- |submodule| -- upsampling --|
class UnetSkipConnectionBlock(nn.Module):
    """One U-Net level: downsample -> (submodule) -> upsample, plus a skip connection.

    Non-outermost blocks return torch.cat([x, y], dim=1), i.e. input_nc + outer_nc
    channels (2 * outer_nc when input_nc defaults to outer_nc). That is why a
    non-innermost block's up path expects inner_nc * 2 input channels.

    Args:
        outer_nc: channels produced by the up path of this level.
        inner_nc: channels produced by the down conv (and consumed by the submodule).
        input_nc: channels of this level's input; defaults to outer_nc.
        submodule: the next-deeper block; required unless innermost=True.
        outermost: if True, no skip concat, no norm on the down path, Tanh at the end.
        innermost: if True, there is no submodule (bottleneck level).
        norm_layer: normalization constructor applied after down/up convs.
        up_layer: one of 'upsample', 'shuffle', 'transpose'.

    Raises:
        ValueError: for an unknown up_layer, or a missing submodule on a
            non-innermost block (previously this only failed at forward time).
    """

    def __init__(self, outer_nc, inner_nc, input_nc=None,
                 submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, up_layer='upsample'):
        super().__init__()
        self.outermost = outermost
        if input_nc is None:
            input_nc = outer_nc
        if submodule is None and not innermost:
            raise ValueError('submodule is required unless innermost=True')

        downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4, stride=2, padding=1)
        downrelu = nn.LeakyReLU(0.2, True)
        downnorm = norm_layer(inner_nc)

        # The innermost level has no submodule, so nothing is concatenated onto its features.
        up_inner_nc = inner_nc if innermost else inner_nc * 2
        upconv = self._make_upconv(up_layer, up_inner_nc, outer_nc)
        uprelu = nn.ReLU(True)
        upnorm = norm_layer(outer_nc)

        if outermost:
            down = [downconv]
            up = [uprelu] + upconv + [nn.Tanh()]
            model = down + [submodule] + up
        elif innermost:
            down = [downrelu, downconv]
            up = [uprelu] + upconv + [upnorm]
            model = down + up
        else:
            down = [downrelu, downconv, downnorm]
            up = [uprelu] + upconv + [upnorm]
            model = down + [submodule] + up

        self.model = nn.Sequential(*model)

    @staticmethod
    def _make_upconv(up_layer, in_nc, out_nc):
        """Return layers that double the spatial size and map in_nc -> out_nc channels."""
        if up_layer == 'upsample':
            return [nn.Upsample(scale_factor=2, mode='nearest'),
                    nn.Conv2d(in_nc, out_nc, kernel_size=3, padding=1)]
        if up_layer == 'shuffle':
            # The 3x3 conv is created before the 1x1 conv to keep the original init order.
            smooth = nn.Conv2d(in_nc, out_nc, kernel_size=3, padding=1)
            # 4x channels so PixelShuffle(2) can trade them for a 2x larger feature map.
            wide = nn.Conv2d(in_nc, in_nc * 4, kernel_size=1, stride=1, padding=0)
            return [wide, nn.LeakyReLU(0.2, True), nn.PixelShuffle(2), smooth]
        if up_layer == 'transpose':
            return [nn.ConvTranspose2d(in_nc, out_nc, kernel_size=4, stride=2, padding=1)]
        raise ValueError(f"unknown up_layer {up_layer!r}; expected 'upsample', 'shuffle' or 'transpose'")

    def forward(self, x):
        """Outermost: return the model output. Otherwise concat input and output on dim 1 (skip)."""
        if self.outermost:
            return self.model(x)
        return torch.cat([x, self.model(x)], 1)
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "### Train" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 4, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
def train(m, it, half, shape, device='cuda', warmup=0):
    """Benchmark helper: run `it` forward/backward passes of `m` on one random batch.

    Args:
        m: module to exercise; it is moved to `device` (and cast to fp16 if `half`).
        it: number of timed iterations.
        half: if True, cast both the model and the input batch to torch.half.
        shape: shape of the random input batch passed to torch.randn.
        device: device to run on (default 'cuda', matching the original behavior).
        warmup: untimed iterations run first, e.g. so cuDNN autotuning
            (torch.backends.cudnn.benchmark) is excluded from the measurement.

    Returns:
        Wall-clock seconds spent in the `it` timed iterations. CUDA kernels are
        launched asynchronously, so the device is synchronized before each clock
        read; otherwise still-queued GPU work would fall outside the measured window.
    """
    import time

    device = torch.device(device)
    x = torch.randn(shape).to(device)
    m = m.to(device)
    if half:
        x = x.half()
        m = m.half()

    def step():
        out = m(x)
        out.sum().backward()
        # Reset gradients so they don't accumulate across iterations.
        if hasattr(m, 'zero_grad'):
            m.zero_grad()

    def sync():
        if device.type == 'cuda':
            torch.cuda.synchronize(device)

    for _ in range(warmup):
        step()
    sync()
    start = time.perf_counter()
    for _ in range(it):
        step()
    sync()
    return time.perf_counter() - start
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "# Model - resnet" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 5, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
# Input batch for the ResNet-50 benchmark: 32 RGB images of 224x224 (N, C, H, W).
shape = [32, 3, 224, 224]
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 6, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
# Fresh, randomly initialized ResNet-50 for the FP32 run.
m = torchvision.models.resnet.resnet50()
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 8, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "CPU times: user 9.12 s, sys: 4.59 s, total: 13.7 s\n", | |
| "Wall time: 15.2 s\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "%time train(m, it=100, half=False, shape=shape)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "### Half" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 9, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
# New ResNet-50 so the FP16 run does not reuse the instance exercised by the FP32 run.
m = torchvision.models.resnet.resnet50()
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 10, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "CPU times: user 4.03 s, sys: 1.71 s, total: 5.74 s\n", | |
| "Wall time: 7.14 s\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "%time train(m, it=100, half=True, shape=shape)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "# Model - upsample" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 10, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
# 256x256 inputs: divisible by 2**7, as the 7-level U-Net below requires.
shape = [32, 3, 256, 256]
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 11, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
m = UnetGenerator(input_nc=shape[1], output_nc=shape[1], num_downs=7, up_layer='upsample')
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 12, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "CPU times: user 6.38 s, sys: 5.25 s, total: 11.6 s\n", | |
| "Wall time: 15.4 s\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "%time train(m, it=50, half=False, shape=shape)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "### Half" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 13, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
m = UnetGenerator(input_nc=shape[1], output_nc=shape[1], num_downs=7, up_layer='upsample')
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 14, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "CPU times: user 10.9 s, sys: 8.6 s, total: 19.5 s\n", | |
| "Wall time: 24.7 s\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "%time train(m, it=50, half=True, shape=shape)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "# Model - transpose" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 5, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
# 256x256 inputs: divisible by 2**7, as the 7-level U-Net below requires.
shape = [32, 3, 256, 256]
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 6, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
m = UnetGenerator(input_nc=shape[1], output_nc=shape[1], num_downs=7, up_layer='transpose')
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 7, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "CPU times: user 3.52 s, sys: 2.62 s, total: 6.14 s\n", | |
| "Wall time: 6.81 s\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "%time train(m, it=50, half=False, shape=shape)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "### Half" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 8, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
m = UnetGenerator(input_nc=shape[1], output_nc=shape[1], num_downs=7, up_layer='transpose')
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 9, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "CPU times: user 2.42 s, sys: 1.65 s, total: 4.07 s\n", | |
| "Wall time: 4.81 s\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "%time train(m, it=50, half=True, shape=shape)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "# Model - shuffle" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 5, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
# 256x256 inputs: divisible by 2**7, as the 7-level U-Net below requires.
shape = [32, 3, 256, 256]
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 6, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
m = UnetGenerator(input_nc=shape[1], output_nc=shape[1], num_downs=7, up_layer='shuffle')
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 7, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "CPU times: user 7.23 s, sys: 5.96 s, total: 13.2 s\n", | |
| "Wall time: 21.1 s\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "%time train(m, it=50, half=False, shape=shape)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "### Half" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 10, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
m = UnetGenerator(input_nc=shape[1], output_nc=shape[1], num_downs=7, up_layer='shuffle')
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 11, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "CPU times: user 4.12 s, sys: 4.16 s, total: 8.28 s\n", | |
| "Wall time: 17.9 s\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "%time train(m, it=50, half=True, shape=shape)" | |
| ] | |
| } | |
| ], | |
| "metadata": { | |
| "kernelspec": { | |
| "display_name": "Python 3", | |
| "language": "python", | |
| "name": "python3" | |
| }, | |
| "language_info": { | |
| "codemirror_mode": { | |
| "name": "ipython", | |
| "version": 3 | |
| }, | |
| "file_extension": ".py", | |
| "mimetype": "text/x-python", | |
| "name": "python", | |
| "nbconvert_exporter": "python", | |
| "pygments_lexer": "ipython3", | |
| "version": "3.7.1" | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 2 | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment