Skip to content

Instantly share code, notes, and snippets.

@daveluo
Last active May 13, 2019 16:35
Show Gist options
  • Save daveluo/723eb24e15814435fdd42e7d62f72458 to your computer and use it in GitHub Desktop.
Save daveluo/723eb24e15814435fdd42e7d62f72458 to your computer and use it in GitHub Desktop.
Demo of CPU-only Predictions and Pytorch Model Saving/Loading, 5/9/2018
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Demo of CPU-only Predictions and Pytorch Model Saving/Loading, 5/9/2018"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Adapted from CIFAR 10 Darknet notebook:\n",
"https://github.com/fastai/fastai/blob/master/courses/dl2/cifar10-darknet.ipynb"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 1. Train on GPU"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"%matplotlib inline\n",
"%reload_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from fastai.conv_learner import *\n",
"PATH = Path(\"data/cifar10/\")\n",
"os.makedirs(PATH,exist_ok=True)\n",
"torch.cuda.set_device(0)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')\n",
"stats = (np.array([ 0.4914 , 0.48216, 0.44653]), np.array([ 0.24703, 0.24349, 0.26159]))\n",
"\n",
"num_workers = num_cpus()//2\n",
"bs=256\n",
"sz=32"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"tfms = tfms_from_stats(stats, sz, aug_tfms=[RandomFlip()], pad=sz//8)\n",
"data = ImageClassifierData.from_paths(PATH, val_name='test', tfms=tfms, bs=bs)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def conv_layer(ni, nf, ks=3, stride=1):\n",
" return nn.Sequential(\n",
" nn.Conv2d(ni, nf, kernel_size=ks, bias=False, stride=stride, padding=ks//2),\n",
" nn.BatchNorm2d(nf, momentum=0.01),\n",
" nn.LeakyReLU(negative_slope=0.1, inplace=True))"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"class ResLayer(nn.Module):\n",
" def __init__(self, ni):\n",
" super().__init__()\n",
" self.conv1=conv_layer(ni, ni//2, ks=1)\n",
" self.conv2=conv_layer(ni//2, ni, ks=3)\n",
" \n",
" def forward(self, x): return x.add(self.conv2(self.conv1(x)))"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"class Darknet(nn.Module):\n",
" def make_group_layer(self, ch_in, num_blocks, stride=1):\n",
" return [conv_layer(ch_in, ch_in*2,stride=stride)\n",
" ] + [(ResLayer(ch_in*2)) for i in range(num_blocks)]\n",
"\n",
" def __init__(self, num_blocks, num_classes, nf=32):\n",
" super().__init__()\n",
" layers = [conv_layer(3, nf, ks=3, stride=1)]\n",
" for i,nb in enumerate(num_blocks):\n",
" layers += self.make_group_layer(nf, nb, stride=2-(i==1))\n",
" nf *= 2\n",
" layers += [nn.AdaptiveAvgPool2d(1), Flatten(), nn.Linear(nf, num_classes)]\n",
" self.layers = nn.Sequential(*layers)\n",
" \n",
" def forward(self, x): return self.layers(x)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"m = Darknet([1, 2, 4, 6, 3], num_classes=10, nf=32)\n",
"# m = nn.DataParallel(m, [1,2,3])"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"lr = 1.3"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"learn = ConvLearner.from_model_data(m, data)\n",
"learn.crit = nn.CrossEntropyLoss()\n",
"learn.metrics = [accuracy]\n",
"wd=1e-4"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "9deecacff9a143dfa8c9800a1a7c0342",
"version_major": 2,
"version_minor": 0
},
"text/html": [
"<p>Failed to display Jupyter Widget of type <code>HBox</code>.</p>\n",
"<p>\n",
" If you're reading this message in the Jupyter Notebook or JupyterLab Notebook, it may mean\n",
" that the widgets JavaScript is still loading. If this message persists, it\n",
" likely means that the widgets JavaScript library is either not installed or\n",
" not enabled. See the <a href=\"https://ipywidgets.readthedocs.io/en/stable/user_install.html\">Jupyter\n",
" Widgets Documentation</a> for setup instructions.\n",
"</p>\n",
"<p>\n",
" If you're reading this message in another frontend (for example, a static\n",
" rendering on GitHub or <a href=\"https://nbviewer.jupyter.org/\">NBViewer</a>),\n",
" it may mean that your frontend doesn't currently support widgets.\n",
"</p>\n"
],
"text/plain": [
"HBox(children=(IntProgress(value=0, description='Epoch', max=5), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch trn_loss val_loss accuracy \n",
" 0 2.191378 8.647105 0.1572 \n",
" 1 1.884872 2.134262 0.1963 \n",
" 2 1.657069 2.075285 0.2077 \n",
" 3 1.41754 1.609488 0.4278 \n",
" 4 1.301417 1.286784 0.5346 \n",
"\n",
"CPU times: user 4min 6s, sys: 1min 30s, total: 5min 37s\n",
"Wall time: 4min 31s\n"
]
},
{
"data": {
"text/plain": [
"[array([1.28678]), 0.5346]"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%time learn.fit(lr, 1, wds=wd, cycle_len=5, use_clr_beta=(20, 20, 0.95, 0.85))"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'test/airplane/6633_airplane.png'"
]
},
"execution_count": 32,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data.val_ds.fnames[0]"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(3, 32, 32)"
]
},
"execution_count": 33,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"test_img = data.val_ds[0][0]\n",
"test_img.shape"
]
},
{
"cell_type": "code",
"execution_count": 48,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.image.AxesImage at 0x7f934c54f3c8>"
]
},
"execution_count": 48,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4wLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvFvnyVgAAGzhJREFUeJztnW2MnFd1x/9nXnZ2vbv22uuXOLaDE+OShLySVRKaigYoNA1ICSogkBoFKcKoIlKR6IcorUoq9QNUBcSHiso0EWlFCSmEkqKoJY1AES1NsgnBMXEgieME2+tdO971vs3Ozsvph5lUm839352d3Z2xuf+ftNrZe+Y+98x9njPP7P3POdfcHUKI9Mh02gEhRGdQ8AuRKAp+IRJFwS9Eoij4hUgUBb8QiaLgFyJRFPxCJIqCX4hEya2ks5ndBOBrALIA/tHdvxgdrKffC/2DYWPsm4a2bAMAfrxz/UuNZrHXHSY6vbHDtTAWAGRJv3yW329K1Rq11SIvIPYtVWZq7VXF75aZyDUH56+tuyschuvWddM++Xy4z9jIMUxOnG7q5bUc/GaWBfD3AD4A4CiAp8zsYXd/nvUp9A/iko/+ZdhY45PDzlQsCNyr1Fat8pNkkcuipfeMFoMnRo6ceIBf7LXYa474mO3Kc0ci/TYUwmGyta+H9nllcpbaZsr8fM5XKtTmxJSJnOds5M2kJ/IG1YMytaHCX9vvXLAt2H7NVW+nfbbv2Bps//ynbuU+LGIlH/uvBfCSux9293kADwC4ZQXHE0K0kZUE/w4Av1nw99FGmxDiHGAlwR/63PSWz0tmts/Mhs1suFKcWsFwQojVZCXBfxTArgV/7wRwfPGT3H2/uw+5+1Cup38FwwkhVpOVBP9TAPaa2YVm1gXgEwAeXh23hBBrTcur/e5eMbM7Afwn6lLffe7+y5Y9Mf4+ZGTx1SPr73UxgtmWrywAcUmJkcnEDshNuRw/Nbk8P2alMh/2A3w+MpG5z0RUmFpEtpsntlO1yKp9lvvYU+OvuWeOr/b3z0wH2zfNTtI+g5OnqK1QfI3aql38mJNZ7mP5RPicHbX30j6btvxJsH051+iKdH53fwTAIys5hhCiM+gbfkIkioJfiERR8AuRKAp+IRJFwS9Eoqxotb8VWC5ITKJgYlMsCSeWBRaT82J5OF3ZsNFrrWWcZTKR996I+/MlLhtViSRWy0Skz4gc2Zvn/bqrXLZbNx7+NmffLE9w2VB6ndoGJ05S26Yit/VUw8csOe8zA+7HXBc1YY5k5wFAtosnNGW61gXb1w1s4INF5Nlm0Z1fiERR8AuRKAp+IRJFwS9Eoij4hUiUtq72O4AKSbjxFpJtWqllBwC52EpppNZaNl8ItndF/DDnY5Wr4YQOALBqidp6KvyYfR7ul5sNJ7gAQF/xNLVtKE9QW2FujNt8NNi+bYCvpPdt4n7M9XOVINPHz1l+fVg2GRvnq+/ZGX4+N/fwUl1bwpW1AAAT2c3Udir/qWD7BVdcyw9I1ZvmY0J3fiESRcEvRKIo+IVIFAW/EImi4BciURT8QiRKW6W+Lq9hZ2kmaKuS2nMAkCdJLllW3A9ALqIcZgpcrslmwv4BQGGqN9je6/x4AzZHbfkJXvNt0+w4tVXmuOxVrYYltskil+y8m/vYvSkmo0V28+kKy4DbLuddCuf1Udupme3UVrVd1JbZcFmwPXeMz+HV6w9R29t2PUdt/RuK1HZ0hl8jP3w17P9cMZa4Fh6rFtv5ahG68wuRKAp+IRJFwS9Eoij4hUgUBb8QiaLgFyJRViT1mdkRAFMAqgAq7j4Ue36+UsaO02/Zy7NOKSKJVcOyhme4+5lIDby+/rAcBgDo5tl0s6fCaVvVuRHaZzpSAy/jPNNuPMfr9M2ex9PHstYdbJ/x3bRPqfc8atu6iRetW7/xAu5Hf7gu3dF1/Hjj02HfAeDZI/ycXbGb+7+t7+pge7
HwH7TP1q3/TW2bt0W2Gytz20CBXyPrLLw92OnJjbRPfiY8ViVSV3Exq6Hzv9fd+eZmQoizEn3sFyJRVhr8DuBHZva0me1bDYeEEO1hpR/7b3D342a2FcCjZvaCuz++8AmNN4V9ANDXs36FwwkhVosV3fnd/Xjj9xiA7wN4S90hd9/v7kPuPtRdCC8CCSHaT8vBb2a9Ztb/xmMAHwRwcLUcE0KsLSv52L8NwPcbRTRzAP7F3bl+AmA2l8fwlnB2Vqx+Z6+TjLQcL8JoFi62CQCXd3FZsZLrp7ZDfXuC7d1lvq3SrG+itmKFZ20VwCXHvl4uiW0tPR1sH8+8nfZZV+A+9hf4WLUpLit1FcMnNDfAC1nOejhrEgCqlUhh1RyXD6dmwufaZsPbiQFAIc+zLS3Hs/Myzs9nD8nCAwCvhYuani5eSPsULDz31WpE415Ey8Hv7ocBXNlqfyFEZ5HUJ0SiKPiFSBQFvxCJouAXIlEU/EIkSlsLeJpl0GVhec6yXKIwDATbezORvd1yXH7rMf6y1xci8tWmsDyUqfEvL03P80KcnuWyUW+JZ4FNnuDHnDz2fLC9NPUi7XPRBduo7Zrrd1PbyRNcfht56lfh471zkPYZGuCS4+H1XCKcnuGy7lQlfD63r3uV9ikUeCZmFlyOnJnlxV/LFS71lSbDxU5fOH2U9tncF5Y358s8G3QxuvMLkSgKfiESRcEvRKIo+IVIFAW/EInS1tV+h6OM8ApxI0EoSIWssHZleXJGNhupZZbj20xdt+dn1Lbn8nBNtVKZ+24FXuGsO89Xh1EltQ4B/OTH/LX9nCTUjOd5olC1wP2wDeGVaAB4z2U8oeb0NeHt17Zs46ve/RtOUNveLF/RHz3JV7iL+N9gex58fk8c4/PxzJNc4di4masEg1v4OTszHk4kKs/xGo+ZAlHNvPnEHt35hUgUBb8QiaLgFyJRFPxCJIqCX4hEUfALkShtlfoAAzJEFosoFGZh45zzZI+NzpNf+nCS2l58ldv2XBmWVwYi8krVuRzZu4Fvx1Sb4xLb3t1cbho9EZabBkth6Q0Augq8puGRw7zf4HoubV10QVienSnyZKYzZ7h0mIvUNCxNcGmuNEdq+PVxyfHkGB/rp8P8fnnxpfy89J3Pk5aOWXjbs5kCvz5OZ8LyZqQs5FvQnV+IRFHwC5EoCn4hEkXBL0SiKPiFSBQFvxCJsqTUZ2b3AfgwgDF3v6zRtgnAdwDsBnAEwMfdI9raAjJG3m88sl8X0QHnnUtDluHy2/n9PAvsief4dl17rw77eOl5XM6bK/NpyUxxHyd+xusCdk/w9+zeybD8NjHF56rI5FcAucg8Ph/xP78zLBHmjMuDc2xbNgDzkQTIuYh8OFkLS3obN3Hfi9P8+vAq9//kzMXU9l+vvJ/aRqfCdQ1jIfWbufA1MB/Z1mwxzdz5vwngpkVtdwF4zN33Anis8bcQ4hxiyeB398cBLC6TewuA+xuP7wdw6yr7JYRYY1r9n3+bu48AQOP31tVzSQjRDtZ8wc/M9pnZsJkNV4r8f2MhRHtpNfhHzWw7ADR+0y+iu/t+dx9y96FcD19ME0K0l1aD/2EAtzce3w7gB6vjjhCiXTQj9X0bwI0ANpvZUQBfAPBFAA+a2R0AXgPwsaZHXEaBwSUPZVzWKGf4dl1jRV4M8tQ0l71OngwX45wwnhV3usiz4jZuf53aJkuR9+VwciEAoHcwLHt1zUfkq0hKZQ288ORpnhiHCXLMjf0RKSpiqmS4xFYEP2eTlfB8vH6GZ9nNZC+ktm1XvoPaNu/9ELUdPM3Hy1j4g3MG/DWzqbJl3M+XDH53/yQxceFSCHHWo2/4CZEoCn4hEkXBL0SiKPiFSBQFvxCJ0uYCnoATqS+6xxjZxy9b5TLUuhLf5+yFIy9S2+FIAc+pybCkdywiOU6VeTHIM7WIthWRxPLGT1uhFp7HngqfKycFUgGg4lxGc65EoUS+8F0LJ7ABAKYn+euarH
H/T0z1UdvrxYuC7fN9XKzq695ObTj5a2qaPcivq/NL3bzfqfBefcUqPy/TfWHpMFPjGY5veW7TzxRC/Fah4BciURT8QiSKgl+IRFHwC5EoCn4hEqXNUp9TqS+GEakvV+RpZfnpEWobH+dySG8/lwhfX1zMrEFxhL+mQl+kYOVRnmnXv55LbAPrI8UnZ4lEGCk8WSvy45XGIhLhPL93HDmeD7af6OaS16lxnok5WgkfDwDKo+F9AQGgMr8+2G65l2if09VD1Fac4dcV8nw+8v3cx3IhPMe1HJ+rru7Lg+0Zl9QnhFgCBb8QiaLgFyJRFPxCJIqCX4hEafNqv8FJvTWL1JFjdf/K3bwW32jubdTWk+d12PYMvkZtI6+ElYBSiU9jqcZt3Tm+kr5zJzXhnZfyIn79A+H389en+H5Xx17dTG0nj/J6dtWuAWo7NBO2VbsiBQi7+Io4NvCV70w3X+HOkdtbPs/Vg0yBX1eZyDVXy/FjTte42jJbDdd5zEZW+9eTxJ5s10O0z2J05xciURT8QiSKgl+IRFHwC5EoCn4hEkXBL0SiNLNd130APgxgzN0va7TdA+DTAN4oeHe3uz/S3JAsYaWFbbyy3P1ansskpe5rqO34RDgRBADG50aD7ZUuLvH46YPU1j3zArUdO7OD2i559/nUdt1lx4Ptr43wmoBPDX6A2sa2XEdt05HLp5YN31cyPF8JPRGpbHAdlwhzkVqCOeKHZbkj+Ty3ZbN8sEqkTmKtxOd/IBP2MRdJgrNseCyz5u/nzTzzmwBuCrR/1d2vavw0GfhCiLOFJYPf3R8HQJJZhRDnKiv5n/9OMztgZveZ2cZV80gI0RZaDf6vA9gD4CoAIwC+zJ5oZvvMbNjMhivFqRaHE0KsNi0Fv7uPunvV3WsAvgHg2shz97v7kLsP5Xr6W/VTCLHKtBT8ZrZwS5OPAOBL2kKIs5JmpL5vA7gRwGYzOwrgCwBuNLOrUNfnjgD4zIo9IXX64nAppAtcWikZ395pqo9LW+N94WMWamO0z/kz/H2xht3UNnfBbdR28MQJarvxynBW4rrBXbTPmYj0WazyuaqWw9uXAUCmFq5P6JFzVovYKnzXM3ikdl65Qq4r59fHPJHeACCb5yGTyXH/ewuR7E5im5rhL7pCtqpbjmC+ZPC7+ycDzfcuYwwhxFmIvuEnRKIo+IVIFAW/EImi4BciURT8QiRKmwt4ckWvJaEvkvXUbXwrrHnnL3t9PlxMEQC2Vf8n2N4z/SztM1sLZwICAM77fWrybeHtmADgyDiX38bOhAs7XvmOSdrnj3M/pLZHD1xGbS9NXExtTtL3olJURGIr13hPr0aOSg7pPAEPsQ2vchGJsBC5ii0i9XWRYqLdZBsvAJguRl5Ak+jOL0SiKPiFSBQFvxCJouAXIlEU/EIkioJfiERpu9TH1JCYbOdE5slGpJVKLpK1VVtHbTsKL1PbxblwqcIz2TPcD++ittwcL+A5NXWM2mY2bqC2UiVc3HNj9Qna57o+PtbU+sPUdiJS7HTWWRYhF9JKZS7PxjTC7qhQTK6DWHFMcD/ykZDJZritOMclZFZ0M6Zglpf/st6C7vxCJIqCX4hEUfALkSgKfiESRcEvRKK0f7WfrEZGVymJsRKp03eyxLfr8gzvVynzlfsSqZtWjUxjJc8rFk/3RuoFOl/R35rlK/DlE68E2185UqB9ZiJ7aJ2e5Sv68+DHtEx4rrKR7aQ8smpfi+SxzEeuA29htT8f2f4rE+lX4W5gtsQVhNL8NO9ImCOD1Zax3K87vxCJouAXIlEU/EIkioJfiERR8AuRKAp+IRKlme26dgH4JwDnoZ4lsd/dv2ZmmwB8B8Bu1Lfs+ri7jy89ZAu13ag8xHtVnb+v1ZxvgzRWvpDaiqND4faRJ2kf3/671Jbp57ZqRLI5PjZAbf82FU4k2tU3S/uMZy+gtp9P/SG1zWfD9QIBoJAJ+2/GdbRYolY2sp1bNnILY0kzHqnFl4
mNFavTF7mIKxVunK+GZcBMpKYhlSpXWeqrAPi8u18C4HoAnzWzSwHcBeAxd98L4LHG30KIc4Qlg9/dR9z9mcbjKQCHAOwAcAuA+xtPux/ArWvlpBBi9VnW//xmthvA1QCeALDN3UeA+hsEgK2r7ZwQYu1oOvjNrA/A9wB8zt15Efi39ttnZsNmNlwpTrXioxBiDWgq+M0sj3rgf8vdH2o0j5rZ9oZ9O4DgJvXuvt/dh9x9KNfDv+cuhGgvSwa/mRmAewEccvevLDA9DOD2xuPbAfxg9d0TQqwVzWT13QDgNgDPmdkb+1LdDeCLAB40szsAvAbgYytzpZUNuyJ9IrpLtsblpmKGfzqZ2fChYPt85UraZ2Azlw5zETkyU52jtpM+SG3//vLNwfbaaz+hfbp3vZvasG0PNRUi2ZFUmovIV5Fkumg2YC5yzAyTHCNzH9OdyeEAANUKz9yziHxYJddxNVLELyYrNsuSwe/uPwWPsvev3AUhRCfQN/yESBQFvxCJouAXIlEU/EIkioJfiERpfwHPs4BYocgM2RoMAHLreoPt6zdcwftEZKhymWcXuvFtvuB866fc4NuD7dXenbSPrePbl+Uj257F5SYyx5H5jSm3HtHfaqSwKgAYcT+WMZdpSXYGapHXZpEiqfNEISxHXhfL+oxlgy5Gd34hEkXBL0SiKPiFSBQFvxCJouAXIlEU/EIkSppSH9N/AFQ9sl8cKcI4F5HeYpJSTJXJZHiOW1ekXy0bNub7+2ifbGSsWKZdFFpfMjL3seNFblOZLD9nTrLpYnvaxWTFVvaUBIBaZCO/eSIRlpzPCFMVY0rqYnTnFyJRFPxCJIqCX4hEUfALkSgKfiESJcnVfous6MdWehm1Gl/JjdlixFbZo9s4kaSUSAm5llf7q5HEE+ph1HdOrAaex1buWQJM5HjRM9bCWEA86adaI75EzlmG1DRcTkqS7vxCJIqCX4hEUfALkSgKfiESRcEvRKIo+IVIlCWlPjPbBeCfAJyHugqy392/Zmb3APg0gJONp97t7o+slaOrSkwPaWEbpFiSSGvV4IBKREbLRuQyJolZTGKL7XoWNUa2oCJzQrfxApDLcmExJpnGat3RPKJITb2oLhohllQTkw+zJDEpdilWyHws5/JtRuevAPi8uz9jZv0AnjazRxu2r7r73y1jPCHEWUIze/WNABhpPJ4ys0MAdqy1Y0KItWVZ//Ob2W4AVwN4otF0p5kdMLP7zGzjKvsmhFhDmg5+M+sD8D0An3P3SQBfB7AHwFWofzL4Mum3z8yGzWy4UpxaBZeFEKtBU8FvZnnUA/9b7v4QALj7qLtXvV6a5RsArg31dff97j7k7kO5nv7V8lsIsUKWDH6rLx/fC+CQu39lQfv2BU/7CICDq++eEGKtaGa1/wYAtwF4zsyebbTdDeCTZnYV6urCEQCfWRMP20wseyzSi1riWYKt2WKyV5bIZZkWs+LmIzX3ahEbO2Zsq7RWMxmdbHcFAGXy0mJbjcVU0ZhkV4u8tphkms+FB8zE5M1KmViaF/uaWe3/KcJX97mh6QshgugbfkIkioJfiERR8AuRKAp+IRJFwS9EoiRZwPO3mVYLhp4NxyuDa3bZmIwWK8ZJxotlYjJ5EIjLojE/crmIkMkyILP83sxsy5GqdecXIlEU/EIkioJfiERR8AuRKAp+IRJFwS9EokjqW3NaLeHJieYCkiqStRovctky0RqYYWMsE7Bc4bZI4l40448VLq05nw9WfBRYQupje+4t0c9Jv+jrWoXrSnd+IRJFwS9Eoij4hUgUBb8QiaLgFyJRFPxCJIqkvo7SmlwTLdG4+spii5ACnmuQMZeN9gu3R1S5qJOxAqSx81KLbORXJaZslvfJMD+WsVmf7vxCJIqCX4hEUfALkSgKfiESRcEvRKIsudpvZt0AHgdQaDz/u+7+BTO7EMADADYBeAbAbe4+v5bOinMHvrq9jOXoZseKHDOyyRc/YCxhyVvdmo
3DVI5aNVI/ceWL/U3d+UsA3ufuV6K+HfdNZnY9gC8B+Kq77wUwDuCOZYwrhOgwSwa/15lu/Jlv/DiA9wH4bqP9fgC3romHQog1oan/+c0s29ihdwzAowBeBjDh/v/7ox4FsGNtXBRCrAVNBb+7V939KgA7AVwL4JLQ00J9zWyfmQ2b2XClONW6p0KIVWVZq/3uPgHgJwCuBzBgZm8sGO4EcJz02e/uQ+4+lOvpX4mvQohVZMngN7MtZjbQeNwD4A8AHALwYwAfbTztdgA/WCsnhRCrTzOJPdsB3G9mWdTfLB509x+a2fMAHjCzvwHwcwD3rqGf4hyDJ+K0WOculhAUEbiY4hhR7FpOPlptYmNxJbV5/5YMfnc/AODqQPth1P//F0Kcg+gbfkIkioJfiERR8AuRKAp+IRJFwS9Eolg7pQszOwng1cafmwGcatvgHPnxZuTHmznX/Hibu29p5oBtDf43DWw27O5DHRlcfsgP+aGP/UKkioJfiETpZPDv7+DYC5Efb0Z+vJnfWj869j+/EKKz6GO/EInSkeA3s5vM7Fdm9pKZ3dUJHxp+HDGz58zsWTMbbuO495nZmJkdXNC2ycweNbMXG783dsiPe8zsWGNOnjWzm9vgxy4z+7GZHTKzX5rZnzXa2zonET/aOidm1m1mT5rZLxp+/HWj/UIze6IxH98xs64VDeTubf0BkEW9DNhFALoA/ALApe32o+HLEQCbOzDuewC8C8DBBW1/C+CuxuO7AHypQ37cA+DP2zwf2wG8q/G4H8CvAVza7jmJ+NHWOUE977mv8TgP4AnUC+g8COATjfZ/APCnKxmnE3f+awG85O6HvV7q+wEAt3TAj47h7o8DOL2o+RbUC6ECbSqISvxoO+4+4u7PNB5PoV4sZgfaPCcRP9qK11nzormdCP4dAH6z4O9OFv90AD8ys6fNbF+HfHiDbe4+AtQvQgBbO+jLnWZ2oPFvwZr/+7EQM9uNev2IJ9DBOVnkB9DmOWlH0dxOBH+ohkqnJIcb3P1dAP4IwGfN7D0d8uNs4usA9qC+R8MIgC+3a2Az6wPwPQCfc/fJdo3bhB9tnxNfQdHcZulE8B8FsGvB37T451rj7scbv8cAfB+drUw0ambbAaDxe6wTTrj7aOPCqwH4Bto0J2aWRz3gvuXuDzWa2z4nIT86NSeNsZddNLdZOhH8TwHY21i57ALwCQAPt9sJM+s1s/43HgP4IICD8V5rysOoF0IFOlgQ9Y1ga/ARtGFOrF7w714Ah9z9KwtMbZ0T5ke756RtRXPbtYK5aDXzZtRXUl8G8Bcd8uEi1JWGXwD4ZTv9APBt1D8+llH/JHQHgEEAjwF4sfF7U4f8+GcAzwE4gHrwbW+DH7+H+kfYAwCebfzc3O45ifjR1jkBcAXqRXEPoP5G81cLrtknAbwE4F8BFFYyjr7hJ0Si6Bt+QiSKgl+IRFHwC5EoCn4hEkXBL0SiKPiFSBQFvxCJouAXIlH+DzW4QNASn624AAAAAElFTkSuQmCC\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plt.imshow(data.val_ds.denorm(test_img)[0])"
]
},
{
"cell_type": "code",
"execution_count": 49,
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"text/plain": [
"Variable containing:\n",
" 2.4417 -1.4941 1.2131 0.1524 0.1210 -0.3215 -0.9287 -1.0541 1.2611 -1.3499\n",
"[torch.cuda.FloatTensor of size 1x10 (GPU 0)]"
]
},
"execution_count": 49,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"learn.model(V(test_img).unsqueeze_(0))"
]
},
{
"cell_type": "code",
"execution_count": 50,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"learn.save('cf10dn')"
]
},
{
"cell_type": "code",
"execution_count": 52,
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"text/plain": [
"Darknet(\n",
" (layers): Sequential(\n",
" (0): Sequential(\n",
" (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (1): BatchNorm2d(32, eps=1e-05, momentum=0.01, affine=True)\n",
" (2): LeakyReLU(0.1, inplace)\n",
" )\n",
" (1): Sequential(\n",
" (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
" (1): BatchNorm2d(64, eps=1e-05, momentum=0.01, affine=True)\n",
" (2): LeakyReLU(0.1, inplace)\n",
" )\n",
" (2): ResLayer(\n",
" (conv1): Sequential(\n",
" (0): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (1): BatchNorm2d(32, eps=1e-05, momentum=0.01, affine=True)\n",
" (2): LeakyReLU(0.1, inplace)\n",
" )\n",
" (conv2): Sequential(\n",
" (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (1): BatchNorm2d(64, eps=1e-05, momentum=0.01, affine=True)\n",
" (2): LeakyReLU(0.1, inplace)\n",
" )\n",
" )\n",
" (3): Sequential(\n",
" (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (1): BatchNorm2d(128, eps=1e-05, momentum=0.01, affine=True)\n",
" (2): LeakyReLU(0.1, inplace)\n",
" )\n",
" (4): ResLayer(\n",
" (conv1): Sequential(\n",
" (0): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (1): BatchNorm2d(64, eps=1e-05, momentum=0.01, affine=True)\n",
" (2): LeakyReLU(0.1, inplace)\n",
" )\n",
" (conv2): Sequential(\n",
" (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (1): BatchNorm2d(128, eps=1e-05, momentum=0.01, affine=True)\n",
" (2): LeakyReLU(0.1, inplace)\n",
" )\n",
" )\n",
" (5): ResLayer(\n",
" (conv1): Sequential(\n",
" (0): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (1): BatchNorm2d(64, eps=1e-05, momentum=0.01, affine=True)\n",
" (2): LeakyReLU(0.1, inplace)\n",
" )\n",
" (conv2): Sequential(\n",
" (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (1): BatchNorm2d(128, eps=1e-05, momentum=0.01, affine=True)\n",
" (2): LeakyReLU(0.1, inplace)\n",
" )\n",
" )\n",
" (6): Sequential(\n",
" (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
" (1): BatchNorm2d(256, eps=1e-05, momentum=0.01, affine=True)\n",
" (2): LeakyReLU(0.1, inplace)\n",
" )\n",
" (7): ResLayer(\n",
" (conv1): Sequential(\n",
" (0): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (1): BatchNorm2d(128, eps=1e-05, momentum=0.01, affine=True)\n",
" (2): LeakyReLU(0.1, inplace)\n",
" )\n",
" (conv2): Sequential(\n",
" (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (1): BatchNorm2d(256, eps=1e-05, momentum=0.01, affine=True)\n",
" (2): LeakyReLU(0.1, inplace)\n",
" )\n",
" )\n",
" (8): ResLayer(\n",
" (conv1): Sequential(\n",
" (0): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (1): BatchNorm2d(128, eps=1e-05, momentum=0.01, affine=True)\n",
" (2): LeakyReLU(0.1, inplace)\n",
" )\n",
" (conv2): Sequential(\n",
" (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (1): BatchNorm2d(256, eps=1e-05, momentum=0.01, affine=True)\n",
" (2): LeakyReLU(0.1, inplace)\n",
" )\n",
" )\n",
" (9): ResLayer(\n",
" (conv1): Sequential(\n",
" (0): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (1): BatchNorm2d(128, eps=1e-05, momentum=0.01, affine=True)\n",
" (2): LeakyReLU(0.1, inplace)\n",
" )\n",
" (conv2): Sequential(\n",
" (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (1): BatchNorm2d(256, eps=1e-05, momentum=0.01, affine=True)\n",
" (2): LeakyReLU(0.1, inplace)\n",
" )\n",
" )\n",
" (10): ResLayer(\n",
" (conv1): Sequential(\n",
" (0): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (1): BatchNorm2d(128, eps=1e-05, momentum=0.01, affine=True)\n",
" (2): LeakyReLU(0.1, inplace)\n",
" )\n",
" (conv2): Sequential(\n",
" (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (1): BatchNorm2d(256, eps=1e-05, momentum=0.01, affine=True)\n",
" (2): LeakyReLU(0.1, inplace)\n",
" )\n",
" )\n",
" (11): Sequential(\n",
" (0): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
" (1): BatchNorm2d(512, eps=1e-05, momentum=0.01, affine=True)\n",
" (2): LeakyReLU(0.1, inplace)\n",
" )\n",
" (12): ResLayer(\n",
" (conv1): Sequential(\n",
" (0): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (1): BatchNorm2d(256, eps=1e-05, momentum=0.01, affine=True)\n",
" (2): LeakyReLU(0.1, inplace)\n",
" )\n",
" (conv2): Sequential(\n",
" (0): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (1): BatchNorm2d(512, eps=1e-05, momentum=0.01, affine=True)\n",
" (2): LeakyReLU(0.1, inplace)\n",
" )\n",
" )\n",
" (13): ResLayer(\n",
" (conv1): Sequential(\n",
" (0): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (1): BatchNorm2d(256, eps=1e-05, momentum=0.01, affine=True)\n",
" (2): LeakyReLU(0.1, inplace)\n",
" )\n",
" (conv2): Sequential(\n",
" (0): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (1): BatchNorm2d(512, eps=1e-05, momentum=0.01, affine=True)\n",
" (2): LeakyReLU(0.1, inplace)\n",
" )\n",
" )\n",
" (14): ResLayer(\n",
" (conv1): Sequential(\n",
" (0): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (1): BatchNorm2d(256, eps=1e-05, momentum=0.01, affine=True)\n",
" (2): LeakyReLU(0.1, inplace)\n",
" )\n",
" (conv2): Sequential(\n",
" (0): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (1): BatchNorm2d(512, eps=1e-05, momentum=0.01, affine=True)\n",
" (2): LeakyReLU(0.1, inplace)\n",
" )\n",
" )\n",
" (15): ResLayer(\n",
" (conv1): Sequential(\n",
" (0): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (1): BatchNorm2d(256, eps=1e-05, momentum=0.01, affine=True)\n",
" (2): LeakyReLU(0.1, inplace)\n",
" )\n",
" (conv2): Sequential(\n",
" (0): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (1): BatchNorm2d(512, eps=1e-05, momentum=0.01, affine=True)\n",
" (2): LeakyReLU(0.1, inplace)\n",
" )\n",
" )\n",
" (16): ResLayer(\n",
" (conv1): Sequential(\n",
" (0): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (1): BatchNorm2d(256, eps=1e-05, momentum=0.01, affine=True)\n",
" (2): LeakyReLU(0.1, inplace)\n",
" )\n",
" (conv2): Sequential(\n",
" (0): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (1): BatchNorm2d(512, eps=1e-05, momentum=0.01, affine=True)\n",
" (2): LeakyReLU(0.1, inplace)\n",
" )\n",
" )\n",
" (17): ResLayer(\n",
" (conv1): Sequential(\n",
" (0): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (1): BatchNorm2d(256, eps=1e-05, momentum=0.01, affine=True)\n",
" (2): LeakyReLU(0.1, inplace)\n",
" )\n",
" (conv2): Sequential(\n",
" (0): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (1): BatchNorm2d(512, eps=1e-05, momentum=0.01, affine=True)\n",
" (2): LeakyReLU(0.1, inplace)\n",
" )\n",
" )\n",
" (18): Sequential(\n",
" (0): Conv2d(512, 1024, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
" (1): BatchNorm2d(1024, eps=1e-05, momentum=0.01, affine=True)\n",
" (2): LeakyReLU(0.1, inplace)\n",
" )\n",
" (19): ResLayer(\n",
" (conv1): Sequential(\n",
" (0): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (1): BatchNorm2d(512, eps=1e-05, momentum=0.01, affine=True)\n",
" (2): LeakyReLU(0.1, inplace)\n",
" )\n",
" (conv2): Sequential(\n",
" (0): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (1): BatchNorm2d(1024, eps=1e-05, momentum=0.01, affine=True)\n",
" (2): LeakyReLU(0.1, inplace)\n",
" )\n",
" )\n",
" (20): ResLayer(\n",
" (conv1): Sequential(\n",
" (0): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (1): BatchNorm2d(512, eps=1e-05, momentum=0.01, affine=True)\n",
" (2): LeakyReLU(0.1, inplace)\n",
" )\n",
" (conv2): Sequential(\n",
" (0): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (1): BatchNorm2d(1024, eps=1e-05, momentum=0.01, affine=True)\n",
" (2): LeakyReLU(0.1, inplace)\n",
" )\n",
" )\n",
" (21): ResLayer(\n",
" (conv1): Sequential(\n",
" (0): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (1): BatchNorm2d(512, eps=1e-05, momentum=0.01, affine=True)\n",
" (2): LeakyReLU(0.1, inplace)\n",
" )\n",
" (conv2): Sequential(\n",
" (0): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (1): BatchNorm2d(1024, eps=1e-05, momentum=0.01, affine=True)\n",
" (2): LeakyReLU(0.1, inplace)\n",
" )\n",
" )\n",
" (22): AdaptiveAvgPool2d(output_size=1)\n",
" (23): Flatten(\n",
" )\n",
" (24): Linear(in_features=1024, out_features=10, bias=True)\n",
" )\n",
")"
]
},
"execution_count": 52,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"learn.model"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2. Load model on CPU, test prediction, use torch.save options\n",
"#### After saving model on GPU in step 1, download this notebook and the saved .h5 file (found in data/cifar10/models/). Put notebook and .h5 file into same relative locations on your CPU machine (assuming you have your data organized in the same paths) and start notebook from here:"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"%matplotlib inline\n",
"%reload_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"from fastai.conv_learner import *\n",
"PATH = Path(\"data/cifar10/\")\n",
"os.makedirs(PATH,exist_ok=True)\n",
"# torch.cuda.set_device(0)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')\n",
"stats = (np.array([ 0.4914 , 0.48216, 0.44653]), np.array([ 0.24703, 0.24349, 0.26159]))\n",
"\n",
"num_workers = num_cpus()//2\n",
"bs=256\n",
"sz=32"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"tfms = tfms_from_stats(stats, sz, aug_tfms=[RandomFlip()], pad=sz//8)\n",
"data = ImageClassifierData.from_paths(PATH, val_name='test', tfms=tfms, bs=bs)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def conv_layer(ni, nf, ks=3, stride=1):\n",
" return nn.Sequential(\n",
" nn.Conv2d(ni, nf, kernel_size=ks, bias=False, stride=stride, padding=ks//2),\n",
" nn.BatchNorm2d(nf, momentum=0.01),\n",
" nn.LeakyReLU(negative_slope=0.1, inplace=True))"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"class ResLayer(nn.Module):\n",
" def __init__(self, ni):\n",
" super().__init__()\n",
" self.conv1=conv_layer(ni, ni//2, ks=1)\n",
" self.conv2=conv_layer(ni//2, ni, ks=3)\n",
" \n",
" def forward(self, x): return x.add(self.conv2(self.conv1(x)))"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"class Darknet(nn.Module):\n",
" def make_group_layer(self, ch_in, num_blocks, stride=1):\n",
" return [conv_layer(ch_in, ch_in*2,stride=stride)\n",
" ] + [(ResLayer(ch_in*2)) for i in range(num_blocks)]\n",
"\n",
" def __init__(self, num_blocks, num_classes, nf=32):\n",
" super().__init__()\n",
" layers = [conv_layer(3, nf, ks=3, stride=1)]\n",
" for i,nb in enumerate(num_blocks):\n",
" layers += self.make_group_layer(nf, nb, stride=2-(i==1))\n",
" nf *= 2\n",
" layers += [nn.AdaptiveAvgPool2d(1), Flatten(), nn.Linear(nf, num_classes)]\n",
" self.layers = nn.Sequential(*layers)\n",
" \n",
" def forward(self, x): return self.layers(x)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"m = Darknet([1, 2, 4, 6, 3], num_classes=10, nf=32)\n",
"# m = nn.DataParallel(m, [1,2,3])"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"lr = 1.3"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"learn = ConvLearner.from_model_data(m, data)\n",
"learn.crit = nn.CrossEntropyLoss()\n",
"learn.metrics = [accuracy]\n",
"wd=1e-4"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"learn.load('cf10dn')\n",
"learn.model.eval()"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[523]"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"[i for i,o in enumerate(data.val_ds.fnames) if o == 'test/airplane/6633_airplane.png']"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'test/airplane/6633_airplane.png'"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data.val_ds.fnames[523]"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(3, 32, 32)"
]
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"test_img = data.val_ds[523][0]\n",
"test_img.shape"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.image.AxesImage at 0x1c2462fd68>"
]
},
"execution_count": 25,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAGzhJREFUeJztnW2MnFd1x/9nXnZ2vbv22uuXOLaDE+OShLySVRKaigYoNA1ICSogkBoFKcKoIlKR6IcorUoq9QNUBcSHiso0EWlFCSmEkqKoJY1AES1NsgnBMXEgieME2+tdO971vs3Ozsvph5lUm839352d3Z2xuf+ftNrZe+Y+98x9njPP7P3POdfcHUKI9Mh02gEhRGdQ8AuRKAp+IRJFwS9Eoij4hUgUBb8QiaLgFyJRFPxCJIqCX4hEya2ks5ndBOBrALIA/tHdvxgdrKffC/2DYWPsm4a2bAMAfrxz/UuNZrHXHSY6vbHDtTAWAGRJv3yW329K1Rq11SIvIPYtVWZq7VXF75aZyDUH56+tuyschuvWddM++Xy4z9jIMUxOnG7q5bUc/GaWBfD3AD4A4CiAp8zsYXd/nvUp9A/iko/+ZdhY45PDzlQsCNyr1Fat8pNkkcuipfeMFoMnRo6ceIBf7LXYa474mO3Kc0ci/TYUwmGyta+H9nllcpbaZsr8fM5XKtTmxJSJnOds5M2kJ/IG1YMytaHCX9vvXLAt2H7NVW+nfbbv2Bps//ynbuU+LGIlH/uvBfCSux9293kADwC4ZQXHE0K0kZUE/w4Av1nw99FGmxDiHGAlwR/63PSWz0tmts/Mhs1suFKcWsFwQojVZCXBfxTArgV/7wRwfPGT3H2/uw+5+1Cup38FwwkhVpOVBP9TAPaa2YVm1gXgEwAeXh23hBBrTcur/e5eMbM7Afwn6lLffe7+y5Y9Mf4+ZGTx1SPr73UxgtmWrywAcUmJkcnEDshNuRw/Nbk8P2alMh/2A3w+MpG5z0RUmFpEtpsntlO1yKp9lvvYU+OvuWeOr/b3z0wH2zfNTtI+g5OnqK1QfI3aql38mJNZ7mP5RPicHbX30j6btvxJsH051+iKdH53fwTAIys5hhCiM+gbfkIkioJfiERR8AuRKAp+IRJFwS9Eoqxotb8VWC5ITKJgYlMsCSeWBRaT82J5OF3ZsNFrrWWcZTKR996I+/MlLhtViSRWy0Skz4gc2Zvn/bqrXLZbNx7+NmffLE9w2VB6ndoGJ05S26Yit/VUw8csOe8zA+7HXBc1YY5k5wFAtosnNGW61gXb1w1s4INF5Nlm0Z1fiERR8AuRKAp+IRJFwS9Eoij4hUiUtq72O4AKSbjxFpJtWqllBwC52EpppNZaNl8ItndF/DDnY5Wr4YQOALBqidp6KvyYfR7ul5sNJ7gAQF/xNLVtKE9QW2FujNt8NNi+bYCvpPdt4n7M9XOVINPHz1l+fVg2GRvnq+/ZGX4+N/fwUl1bwpW1AAAT2c3Udir/qWD7BVdcyw9I1ZvmY0J3fiESRcEvRKIo+IVIFAW/EImi4BciURT8QiRKW6W+Lq9hZ2kmaKuS2nMAkCdJLllW3A9ALqIcZgpcrslmwv4BQGGqN9je6/x4AzZHbfkJXvNt0+w4tVXmuOxVrYYltskil+y8m/vYvSkmo0V28+kKy4DbLuddCuf1Udupme3UVrVd1JbZcFmwPXeMz+HV6w9R29t2PUdt/RuK1HZ0hl8jP3w17P9cMZa4Fh6rFtv5ahG68wuRKAp+IRJFwS9Eoij4hUgUBb8QiaLgFyJRViT1mdkRAFMAqgAq7j4Ue36+UsaO02/Zy7NOKSKJVcOyhme4+5lIDby+/rAcBgDo5tl0s6fCaVvVuRHaZzpSAy/jPNNuPMfr9M2ex9PHstYdbJ/x3bRPqfc8atu6iRetW7/xAu5Hf7gu3dF1/Hjj02HfAeDZI/ycXbGb+7+t7+pge7
HwH7TP1q3/TW2bt0W2Gytz20CBXyPrLLw92OnJjbRPfiY8ViVSV3Exq6Hzv9fd+eZmQoizEn3sFyJRVhr8DuBHZva0me1bDYeEEO1hpR/7b3D342a2FcCjZvaCuz++8AmNN4V9ANDXs36FwwkhVosV3fnd/Xjj9xiA7wN4S90hd9/v7kPuPtRdCC8CCSHaT8vBb2a9Ztb/xmMAHwRwcLUcE0KsLSv52L8NwPcbRTRzAP7F3bl+AmA2l8fwlnB2Vqx+Z6+TjLQcL8JoFi62CQCXd3FZsZLrp7ZDfXuC7d1lvq3SrG+itmKFZ20VwCXHvl4uiW0tPR1sH8+8nfZZV+A+9hf4WLUpLit1FcMnNDfAC1nOejhrEgCqlUhh1RyXD6dmwufaZsPbiQFAIc+zLS3Hs/Myzs9nD8nCAwCvhYuani5eSPsULDz31WpE415Ey8Hv7ocBXNlqfyFEZ5HUJ0SiKPiFSBQFvxCJouAXIlEU/EIkSlsLeJpl0GVhec6yXKIwDATbezORvd1yXH7rMf6y1xci8tWmsDyUqfEvL03P80KcnuWyUW+JZ4FNnuDHnDz2fLC9NPUi7XPRBduo7Zrrd1PbyRNcfht56lfh471zkPYZGuCS4+H1XCKcnuGy7lQlfD63r3uV9ikUeCZmFlyOnJnlxV/LFS71lSbDxU5fOH2U9tncF5Y358s8G3QxuvMLkSgKfiESRcEvRKIo+IVIFAW/EInS1tV+h6OM8ApxI0EoSIWssHZleXJGNhupZZbj20xdt+dn1Lbn8nBNtVKZ+24FXuGsO89Xh1EltQ4B/OTH/LX9nCTUjOd5olC1wP2wDeGVaAB4z2U8oeb0NeHt17Zs46ve/RtOUNveLF/RHz3JV7iL+N9gex58fk8c4/PxzJNc4di4masEg1v4OTszHk4kKs/xGo+ZAlHNvPnEHt35hUgUBb8QiaLgFyJRFPxCJIqCX4hEUfALkShtlfoAAzJEFosoFGZh45zzZI+NzpNf+nCS2l58ldv2XBmWVwYi8krVuRzZu4Fvx1Sb4xLb3t1cbho9EZabBkth6Q0Augq8puGRw7zf4HoubV10QVienSnyZKYzZ7h0mIvUNCxNcGmuNEdq+PVxyfHkGB/rp8P8fnnxpfy89J3Pk5aOWXjbs5kCvz5OZ8LyZqQs5FvQnV+IRFHwC5EoCn4hEkXBL0SiKPiFSBQFvxCJsqTUZ2b3AfgwgDF3v6zRtgnAdwDsBnAEwMfdI9raAjJG3m88sl8X0QHnnUtDluHy2/n9PAvsief4dl17rw77eOl5XM6bK/NpyUxxHyd+xusCdk/w9+zeybD8NjHF56rI5FcAucg8Ph/xP78zLBHmjMuDc2xbNgDzkQTIuYh8OFkLS3obN3Hfi9P8+vAq9//kzMXU9l+vvJ/aRqfCdQ1jIfWbufA1MB/Z1mwxzdz5vwngpkVtdwF4zN33Anis8bcQ4hxiyeB398cBLC6TewuA+xuP7wdw6yr7JYRYY1r9n3+bu48AQOP31tVzSQjRDtZ8wc/M9pnZsJkNV4r8f2MhRHtpNfhHzWw7ADR+0y+iu/t+dx9y96FcD19ME0K0l1aD/2EAtzce3w7gB6vjjhCiXTQj9X0bwI0ANpvZUQBfAPBFAA+a2R0AXgPwsaZHXEaBwSUPZVzWKGf4dl1jRV4M8tQ0l71OngwX45wwnhV3usiz4jZuf53aJkuR9+VwciEAoHcwLHt1zUfkq0hKZQ288ORpnhiHCXLMjf0RKSpiqmS4xFYEP2eTlfB8vH6GZ9nNZC+ktm1XvoPaNu/9ELUdPM3Hy1j4g3MG/DWzqbJl3M+XDH53/yQxceFSCHHWo2/4CZEoCn4hEkXBL0SiKPiFSBQFvxCJ0uYCnoATqS+6xxjZxy9b5TLUuhLf5+yFIy9S2+FIAc+pybCkdywiOU6VeTHIM7WIthWRxPLGT1uhFp7HngqfKycFUgGg4lxGc65EoUS+8F0LJ7ABAKYn+euarH
H/T0z1UdvrxYuC7fN9XKzq695ObTj5a2qaPcivq/NL3bzfqfBefcUqPy/TfWHpMFPjGY5veW7TzxRC/Fah4BciURT8QiSKgl+IRFHwC5EoCn4hEqXNUp9TqS+GEakvV+RpZfnpEWobH+dySG8/lwhfX1zMrEFxhL+mQl+kYOVRnmnXv55LbAPrI8UnZ4lEGCk8WSvy45XGIhLhPL93HDmeD7af6OaS16lxnok5WgkfDwDKo+F9AQGgMr8+2G65l2if09VD1Fac4dcV8nw+8v3cx3IhPMe1HJ+rru7Lg+0Zl9QnhFgCBb8QiaLgFyJRFPxCJIqCX4hEafNqv8FJvTWL1JFjdf/K3bwW32jubdTWk+d12PYMvkZtI6+ElYBSiU9jqcZt3Tm+kr5zJzXhnZfyIn79A+H389en+H5Xx17dTG0nj/J6dtWuAWo7NBO2VbsiBQi7+Io4NvCV70w3X+HOkdtbPs/Vg0yBX1eZyDVXy/FjTte42jJbDdd5zEZW+9eTxJ5s10O0z2J05xciURT8QiSKgl+IRFHwC5EoCn4hEkXBL0SiNLNd130APgxgzN0va7TdA+DTAN4oeHe3uz/S3JAsYaWFbbyy3P1ansskpe5rqO34RDgRBADG50aD7ZUuLvH46YPU1j3zArUdO7OD2i559/nUdt1lx4Ptr43wmoBPDX6A2sa2XEdt05HLp5YN31cyPF8JPRGpbHAdlwhzkVqCOeKHZbkj+Ty3ZbN8sEqkTmKtxOd/IBP2MRdJgrNseCyz5u/nzTzzmwBuCrR/1d2vavw0GfhCiLOFJYPf3R8HQJJZhRDnKiv5n/9OMztgZveZ2cZV80gI0RZaDf6vA9gD4CoAIwC+zJ5oZvvMbNjMhivFqRaHE0KsNi0Fv7uPunvV3WsAvgHg2shz97v7kLsP5Xr6W/VTCLHKtBT8ZrZwS5OPAOBL2kKIs5JmpL5vA7gRwGYzOwrgCwBuNLOrUNfnjgD4zIo9IXX64nAppAtcWikZ395pqo9LW+N94WMWamO0z/kz/H2xht3UNnfBbdR28MQJarvxynBW4rrBXbTPmYj0WazyuaqWw9uXAUCmFq5P6JFzVovYKnzXM3ikdl65Qq4r59fHPJHeACCb5yGTyXH/ewuR7E5im5rhL7pCtqpbjmC+ZPC7+ycDzfcuYwwhxFmIvuEnRKIo+IVIFAW/EImi4BciURT8QiRKmwt4ckWvJaEvkvXUbXwrrHnnL3t9PlxMEQC2Vf8n2N4z/SztM1sLZwICAM77fWrybeHtmADgyDiX38bOhAs7XvmOSdrnj3M/pLZHD1xGbS9NXExtTtL3olJURGIr13hPr0aOSg7pPAEPsQ2vchGJsBC5ii0i9XWRYqLdZBsvAJguRl5Ak+jOL0SiKPiFSBQFvxCJouAXIlEU/EIkioJfiERpu9TH1JCYbOdE5slGpJVKLpK1VVtHbTsKL1PbxblwqcIz2TPcD++ittwcL+A5NXWM2mY2bqC2UiVc3HNj9Qna57o+PtbU+sPUdiJS7HTWWRYhF9JKZS7PxjTC7qhQTK6DWHFMcD/ykZDJZritOMclZFZ0M6Zglpf/st6C7vxCJIqCX4hEUfALkSgKfiESRcEvRKK0f7WfrEZGVymJsRKp03eyxLfr8gzvVynzlfsSqZtWjUxjJc8rFk/3RuoFOl/R35rlK/DlE68E2185UqB9ZiJ7aJ2e5Sv68+DHtEx4rrKR7aQ8smpfi+SxzEeuA29htT8f2f4rE+lX4W5gtsQVhNL8NO9ImCOD1Zax3K87vxCJouAXIlEU/EIkioJfiERR8AuRKAp+IRKlme26dgH4JwDnoZ4lsd/dv2ZmmwB8B8Bu1Lfs+ri7jy89ZAu13ag8xHtVnb+v1ZxvgzRWvpDaiqND4faRJ2kf3/671Jbp57ZqRLI5PjZAbf82FU4k2tU3S/uMZy+gtp9P/SG1zWfD9QIBoJAJ+2/GdbRYolY2sp1bNnILY0kzHqnFl4
mNFavTF7mIKxVunK+GZcBMpKYhlSpXWeqrAPi8u18C4HoAnzWzSwHcBeAxd98L4LHG30KIc4Qlg9/dR9z9mcbjKQCHAOwAcAuA+xtPux/ArWvlpBBi9VnW//xmthvA1QCeALDN3UeA+hsEgK2r7ZwQYu1oOvjNrA/A9wB8zt15Efi39ttnZsNmNlwpTrXioxBiDWgq+M0sj3rgf8vdH2o0j5rZ9oZ9O4DgJvXuvt/dh9x9KNfDv+cuhGgvSwa/mRmAewEccvevLDA9DOD2xuPbAfxg9d0TQqwVzWT13QDgNgDPmdkb+1LdDeCLAB40szsAvAbgYytzpZUNuyJ9IrpLtsblpmKGfzqZ2fChYPt85UraZ2Azlw5zETkyU52jtpM+SG3//vLNwfbaaz+hfbp3vZvasG0PNRUi2ZFUmovIV5Fkumg2YC5yzAyTHCNzH9OdyeEAANUKz9yziHxYJddxNVLELyYrNsuSwe/uPwWPsvev3AUhRCfQN/yESBQFvxCJouAXIlEU/EIkioJfiERpfwHPs4BYocgM2RoMAHLreoPt6zdcwftEZKhymWcXuvFtvuB866fc4NuD7dXenbSPrePbl+Uj257F5SYyx5H5jSm3HtHfaqSwKgAYcT+WMZdpSXYGapHXZpEiqfNEISxHXhfL+oxlgy5Gd34hEkXBL0SiKPiFSBQFvxCJouAXIlEU/EIkSppSH9N/AFQ9sl8cKcI4F5HeYpJSTJXJZHiOW1ekXy0bNub7+2ifbGSsWKZdFFpfMjL3seNFblOZLD9nTrLpYnvaxWTFVvaUBIBaZCO/eSIRlpzPCFMVY0rqYnTnFyJRFPxCJIqCX4hEUfALkSgKfiESJcnVfous6MdWehm1Gl/JjdlixFbZo9s4kaSUSAm5llf7q5HEE+ph1HdOrAaex1buWQJM5HjRM9bCWEA86adaI75EzlmG1DRcTkqS7vxCJIqCX4hEUfALkSgKfiESRcEvRKIo+IVIlCWlPjPbBeCfAJyHugqy392/Zmb3APg0gJONp97t7o+slaOrSkwPaWEbpFiSSGvV4IBKREbLRuQyJolZTGKL7XoWNUa2oCJzQrfxApDLcmExJpnGat3RPKJITb2oLhohllQTkw+zJDEpdilWyHws5/JtRuevAPi8uz9jZv0AnjazRxu2r7r73y1jPCHEWUIze/WNABhpPJ4ys0MAdqy1Y0KItWVZ//Ob2W4AVwN4otF0p5kdMLP7zGzjKvsmhFhDmg5+M+sD8D0An3P3SQBfB7AHwFWofzL4Mum3z8yGzWy4UpxaBZeFEKtBU8FvZnnUA/9b7v4QALj7qLtXvV6a5RsArg31dff97j7k7kO5nv7V8lsIsUKWDH6rLx/fC+CQu39lQfv2BU/7CICDq++eEGKtaGa1/wYAtwF4zsyebbTdDeCTZnYV6urCEQCfWRMP20wseyzSi1riWYKt2WKyV5bIZZkWs+LmIzX3ahEbO2Zsq7RWMxmdbHcFAGXy0mJbjcVU0ZhkV4u8tphkms+FB8zE5M1KmViaF/uaWe3/KcJX97mh6QshgugbfkIkioJfiERR8AuRKAp+IRJFwS9EoiRZwPO3mVYLhp4NxyuDa3bZmIwWK8ZJxotlYjJ5EIjLojE/crmIkMkyILP83sxsy5GqdecXIlEU/EIkioJfiERR8AuRKAp+IRJFwS9EokjqW3NaLeHJieYCkiqStRovctky0RqYYWMsE7Bc4bZI4l40448VLq05nw9WfBRYQupje+4t0c9Jv+jrWoXrSnd+IRJFwS9Eoij4hUgUBb8QiaLgFyJRFPxCJIqkvo7SmlwTLdG4+spii5ACnmuQMZeN9gu3R1S5qJOxAqSx81KLbORXJaZslvfJMD+WsVmf7vxCJIqCX4hEUfALkSgKfiESRcEvRKIsudpvZt0AHgdQaDz/u+7+BTO7EMADADYBeAbAbe4+v5bOinMHvrq9jOXoZseKHDOyyRc/YCxhyVvdmo
3DVI5aNVI/ceWL/U3d+UsA3ufuV6K+HfdNZnY9gC8B+Kq77wUwDuCOZYwrhOgwSwa/15lu/Jlv/DiA9wH4bqP9fgC3romHQog1oan/+c0s29ihdwzAowBeBjDh/v/7ox4FsGNtXBRCrAVNBb+7V939KgA7AVwL4JLQ00J9zWyfmQ2b2XClONW6p0KIVWVZq/3uPgHgJwCuBzBgZm8sGO4EcJz02e/uQ+4+lOvpX4mvQohVZMngN7MtZjbQeNwD4A8AHALwYwAfbTztdgA/WCsnhRCrTzOJPdsB3G9mWdTfLB509x+a2fMAHjCzvwHwcwD3rqGf4hyDJ+K0WOculhAUEbiY4hhR7FpOPlptYmNxJbV5/5YMfnc/AODqQPth1P//F0Kcg+gbfkIkioJfiERR8AuRKAp+IRJFwS9Eolg7pQszOwng1cafmwGcatvgHPnxZuTHmznX/Hibu29p5oBtDf43DWw27O5DHRlcfsgP+aGP/UKkioJfiETpZPDv7+DYC5Efb0Z+vJnfWj869j+/EKKz6GO/EInSkeA3s5vM7Fdm9pKZ3dUJHxp+HDGz58zsWTMbbuO495nZmJkdXNC2ycweNbMXG783dsiPe8zsWGNOnjWzm9vgxy4z+7GZHTKzX5rZnzXa2zonET/aOidm1m1mT5rZLxp+/HWj/UIze6IxH98xs64VDeTubf0BkEW9DNhFALoA/ALApe32o+HLEQCbOzDuewC8C8DBBW1/C+CuxuO7AHypQ37cA+DP2zwf2wG8q/G4H8CvAVza7jmJ+NHWOUE977mv8TgP4AnUC+g8COATjfZ/APCnKxmnE3f+awG85O6HvV7q+wEAt3TAj47h7o8DOL2o+RbUC6ECbSqISvxoO+4+4u7PNB5PoV4sZgfaPCcRP9qK11nzormdCP4dAH6z4O9OFv90AD8ys6fNbF+HfHiDbe4+AtQvQgBbO+jLnWZ2oPFvwZr/+7EQM9uNev2IJ9DBOVnkB9DmOWlH0dxOBH+ohkqnJIcb3P1dAP4IwGfN7D0d8uNs4usA9qC+R8MIgC+3a2Az6wPwPQCfc/fJdo3bhB9tnxNfQdHcZulE8B8FsGvB37T451rj7scbv8cAfB+drUw0ambbAaDxe6wTTrj7aOPCqwH4Bto0J2aWRz3gvuXuDzWa2z4nIT86NSeNsZddNLdZOhH8TwHY21i57ALwCQAPt9sJM+s1s/43HgP4IICD8V5rysOoF0IFOlgQ9Y1ga/ARtGFOrF7w714Ah9z9KwtMbZ0T5ke756RtRXPbtYK5aDXzZtRXUl8G8Bcd8uEi1JWGXwD4ZTv9APBt1D8+llH/JHQHgEEAjwF4sfF7U4f8+GcAzwE4gHrwbW+DH7+H+kfYAwCebfzc3O45ifjR1jkBcAXqRXEPoP5G81cLrtknAbwE4F8BFFYyjr7hJ0Si6Bt+QiSKgl+IRFHwC5EoCn4hEkXBL0SiKPiFSBQFvxCJouAXIlH+DzW4QNASn624AAAAAElFTkSuQmCC\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plt.imshow(data.val_ds.denorm(test_img)[0])"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Variable containing:\n",
" 2.4417 -1.4941 1.2131 0.1524 0.1210 -0.3216 -0.9287 -1.0541 1.2611 -1.3499\n",
"[torch.FloatTensor of size 1x10]"
]
},
"execution_count": 26,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"learn.model(V(test_img).unsqueeze_(0))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Confirmed the CPU prediction above matches the earlier GPU prediction (identical up to floating-point rounding, e.g. -0.3216 vs -0.3215):\n",
"```\n",
"Variable containing:\n",
" 2.4417 -1.4941 1.2131 0.1524 0.1210 -0.3215 -0.9287 -1.0541 1.2611 -1.3499\n",
"[torch.cuda.FloatTensor of size 1x10 (GPU 0)]\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Option 1 - Save full model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"torch.save(learn.model, 'cf10dn_cpufullmodel.pt')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Option 2 - Save weights"
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"torch.save(learn.model.state_dict(), 'cf10dn_cpuweights.pt')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 3. Restart the kernel and test loading the saved model (after restarting, re-run the import and data-setup cells first, since `data` is used below)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Option 1 continued - Load full model (torch.load())\n",
"\n",
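"#### Note: torch.save() saved the PyTorch model (learn.model), not the fastai ConvLearner object, so learn2 below is the model itself (equivalent to learn.model) rather than a learner. Because the full model is pickled, its class definition (fastai's Darknet) must be importable when it is loaded."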
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"learn2 = torch.load('cf10dn_cpufullmodel.pt')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"learn2.eval()"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[523]"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"[i for i,o in enumerate(data.val_ds.fnames) if o == 'test/airplane/6633_airplane.png']"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Variable containing:\n",
" 2.4417 -1.4941 1.2131 0.1524 0.1210 -0.3216 -0.9287 -1.0541 1.2611 -1.3499\n",
"[torch.FloatTensor of size 1x10]"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"test_img = data.val_ds[523][0]\n",
"learn2(V(test_img).unsqueeze_(0))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Confirmed the CPU prediction above matches the earlier GPU prediction (identical up to floating-point rounding, e.g. -0.3216 vs -0.3215):\n",
"```\n",
"Variable containing:\n",
" 2.4417 -1.4941 1.2131 0.1524 0.1210 -0.3215 -0.9287 -1.0541 1.2611 -1.3499\n",
"[torch.cuda.FloatTensor of size 1x10 (GPU 0)]\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Option 2 continued - First define model and then load weights (model.load_state_dict(torch.load()))"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"m = Darknet([1, 2, 4, 6, 3], num_classes=10, nf=32)\n",
"learn3 = ConvLearner.from_model_data(m, data)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"learn3.model.load_state_dict(torch.load('cf10dn_cpuweights.pt'))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"learn3.model.eval()"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Variable containing:\n",
" 2.4417 -1.4941 1.2131 0.1524 0.1210 -0.3216 -0.9287 -1.0541 1.2611 -1.3499\n",
"[torch.FloatTensor of size 1x10]"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"test_img = data.val_ds[523][0]\n",
"learn3.model(V(test_img).unsqueeze_(0))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Confirmed the CPU prediction above matches the earlier GPU prediction (identical up to floating-point rounding, e.g. -0.3216 vs -0.3215):\n",
"```\n",
"Variable containing:\n",
" 2.4417 -1.4941 1.2131 0.1524 0.1210 -0.3215 -0.9287 -1.0541 1.2611 -1.3499\n",
"[torch.cuda.FloatTensor of size 1x10 (GPU 0)]\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python [conda env:fastai-cpu]",
"language": "python",
"name": "conda-env-fastai-cpu-py"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.4"
},
"toc": {
"colors": {
"hover_highlight": "#DAA520",
"navigate_num": "#000000",
"navigate_text": "#333333",
"running_highlight": "#FF0000",
"selected_highlight": "#FFD700",
"sidebar_border": "#EEEEEE",
"wrapper_background": "#FFFFFF"
},
"moveMenuLeft": true,
"nav_menu": {
"height": "266px",
"width": "252px"
},
"navigate_menu": true,
"number_sections": true,
"sideBar": true,
"threshold": 4,
"toc_cell": false,
"toc_section_display": "block",
"toc_window_display": false,
"widenNotebook": false
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment