Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Select an option

  • Save bearpelican/33828d56f4471ab034ab33114f2e7517 to your computer and use it in GitHub Desktop.

Select an option

Save bearpelican/33828d56f4471ab034ab33114f2e7517 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.functional as F\n",
"torch.backends.cudnn.benchmark = True\n",
"from functools import partial\n",
"import functools"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(1., device='cuda:0')"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"torch.ones([]).cuda() # warmup cuda"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Results\n",
"\n",
"resnet - 2x performance \n",
"unet (upsample+conv) - 40% slower \n",
"unet (conv-transpose) - 1.5x faster "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Unet Arch"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"class UnetGenerator(nn.Module):\n",
" def __init__(self, input_nc, output_nc, num_downs, ngf=64, up_layer='upsample'):\n",
" super(UnetGenerator, self).__init__()\n",
"\n",
" # construct unet structure\n",
" unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, innermost=True, up_layer=up_layer)\n",
" for i in range(num_downs - 5): \n",
" unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, up_layer=up_layer)\n",
" unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, up_layer=up_layer)\n",
" unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, up_layer=up_layer)\n",
" unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, up_layer=up_layer)\n",
" unet_block = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, up_layer=up_layer)\n",
"\n",
" self.model = unet_block\n",
"\n",
" def forward(self, input):\n",
" return self.model(input)\n",
"\n",
"\n",
"# Defines the submodule with skip connection.\n",
"# X -------------------identity---------------------- X\n",
"# |-- downsampling -- |submodule| -- upsampling --|\n",
"class UnetSkipConnectionBlock(nn.Module):\n",
" def __init__(self, outer_nc, inner_nc, input_nc=None,\n",
" submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, up_layer='upsample'):\n",
" super(UnetSkipConnectionBlock, self).__init__()\n",
" self.outermost = outermost\n",
" if input_nc is None:\n",
" input_nc = outer_nc\n",
" \n",
" downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4, stride=2, padding=1)\n",
" downrelu = nn.LeakyReLU(0.2, True)\n",
" downnorm = norm_layer(inner_nc)\n",
"\n",
" up_inner_nc = inner_nc if innermost else inner_nc * 2\n",
" up_conv_layer = nn.Conv2d(up_inner_nc, outer_nc, kernel_size=3, padding=1)\n",
" if up_layer == 'upsample':\n",
" upconv = [nn.UpsamplingNearest2d(scale_factor=2), up_conv_layer]\n",
" elif up_layer == 'transpose':\n",
" upconv = [nn.ConvTranspose2d(up_inner_nc, outer_nc,\n",
" kernel_size=4, stride=2,\n",
" padding=1)]\n",
" uprelu = nn.ReLU(True)\n",
" upnorm = norm_layer(outer_nc)\n",
"\n",
" if outermost:\n",
" down = [downconv]\n",
" up = [uprelu] + upconv + [nn.Tanh()]\n",
" model = down + [submodule] + up\n",
" elif innermost:\n",
" down = [downrelu, downconv]\n",
" up = [uprelu] + upconv + [upnorm]\n",
" model = down + up\n",
" else:\n",
" down = [downrelu, downconv, downnorm]\n",
" up = [uprelu] + upconv + [upnorm]\n",
" model = down + [submodule] + up\n",
"\n",
" self.model = nn.Sequential(*model)\n",
"\n",
" def forward(self, x):\n",
" if self.outermost:\n",
" return self.model(x)\n",
" else:\n",
" return torch.cat([x, self.model(x)], 1)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Train"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"def train(m, it, half, shape):\n",
" x = torch.randn(shape).cuda()\n",
" m = m.cuda()\n",
" \n",
" if half:\n",
" x = x.half()\n",
" m = m.half()\n",
" \n",
" for i in range(it):\n",
" out = m(x)\n",
" loss = out.sum()\n",
" loss.backward()\n",
" if hasattr(m, 'zero_grad'):\n",
" m.zero_grad()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Model - resnet"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"import torchvision"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"shape = [32,3,224,224]"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"m = torchvision.models.resnet50()"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 9.12 s, sys: 4.59 s, total: 13.7 s\n",
"Wall time: 15.2 s\n"
]
}
],
"source": [
"%time train(m, it=100, half=False, shape=shape)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Half"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"m = torchvision.models.resnet50()"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 4.03 s, sys: 1.71 s, total: 5.74 s\n",
"Wall time: 7.14 s\n"
]
}
],
"source": [
"%time train(m, it=100, half=True, shape=shape)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Model - upsample"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"m = UnetGenerator(shape[1],shape[1],7,up_layer='upsample')"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/ubuntu/anaconda3/envs/pose/lib/python3.7/site-packages/torch/nn/modules/upsampling.py:129: UserWarning: nn.UpsamplingNearest2d is deprecated. Use nn.functional.interpolate instead.\n",
" warnings.warn(\"nn.{} is deprecated. Use nn.functional.interpolate instead.\".format(self.name))\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 1.07 s, sys: 1.12 s, total: 2.18 s\n",
"Wall time: 5.63 s\n"
]
}
],
"source": [
"%time train(m, it=20, half=False, shape=shape)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Half"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"m = UnetGenerator(shape[1],shape[1],7,up_layer='upsample')"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 1.7 s, sys: 1.85 s, total: 3.55 s\n",
"Wall time: 8.5 s\n"
]
}
],
"source": [
"%time train(m, it=20, half=True, shape=shape)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Model - transpose"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"shape = [32,3,256,256]"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"m = UnetGenerator(shape[1],shape[1],7,up_layer='transpose')"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 3.52 s, sys: 2.62 s, total: 6.14 s\n",
"Wall time: 6.81 s\n"
]
}
],
"source": [
"%time train(m, it=50, half=False, shape=shape)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Half"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"m = UnetGenerator(shape[1],shape[1],7,up_layer='transpose')"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 2.42 s, sys: 1.65 s, total: 4.07 s\n",
"Wall time: 4.81 s\n"
]
}
],
"source": [
"%time train(m, it=50, half=True, shape=shape)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.1"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment