Skip to content

Instantly share code, notes, and snippets.

@bearpelican
Created December 28, 2018 00:48
Show Gist options
  • Select an option

  • Save bearpelican/bbd6f2f027e78c7888f9ff44031eb0ea to your computer and use it in GitHub Desktop.

Select an option

Save bearpelican/bbd6f2f027e78c7888f9ff44031eb0ea to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.functional as F\n",
"torch.backends.cudnn.benchmark = True\n",
"from functools import partial\n",
"import functools\n",
"\n",
"x = torch.randn(1).cuda()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Train"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"def train(m, it, half, shape):\n",
" x = torch.randn(shape).cuda()\n",
" m = m.cuda()\n",
" \n",
" if half:\n",
" x = x.half()\n",
" m = m.half()\n",
" \n",
" for i in range(it):\n",
" out = m(x)\n",
" loss = out.sum()\n",
" loss.backward()\n",
" if hasattr(m, 'zero_grad'):\n",
" m.zero_grad()"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"def upsample(shape, num_blocks=10):\n",
" block = lambda: [nn.Conv2d(shape[1], shape[1], kernel_size=3, stride=1, padding=1)]\n",
" layers = []\n",
" for i in range(num_blocks): layers.extend(block())\n",
" return nn.Sequential(*layers)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 512 Filters"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"shape = [64,512,64,64]"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"m_upsample = upsample(shape)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 7.72 s, sys: 4.97 s, total: 12.7 s\n",
"Wall time: 14.5 s\n"
]
}
],
"source": [
"%time train(m_upsample, it=20, shape=shape, half=False)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Half"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"m_upsample = upsample(shape)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 7 s, sys: 4.19 s, total: 11.2 s\n",
"Wall time: 13.2 s\n"
]
}
],
"source": [
"%time train(m_upsample, it=20, shape=shape, half=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 8 Filters"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"shape = [16,8,512,512]"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"m_upsample = upsample(shape)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 4.08 s, sys: 3.47 s, total: 7.55 s\n",
"Wall time: 7.67 s\n"
]
}
],
"source": [
"%time train(m_upsample, it=50, shape=shape, half=False)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Half - slower performance"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"m_upsample = upsample(shape)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 4.99 s, sys: 4.02 s, total: 9.01 s\n",
"Wall time: 9.37 s\n"
]
}
],
"source": [
"%time train(m_upsample, it=50, shape=shape, half=True)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.1"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment