-
-
Save git-hub-tig/fabb97f65d5bd0ab43313dd0ecda5d3c to your computer and use it in GitHub Desktop.
pytorch_tips_yt_follow.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"name": "pytorch_tips_yt_follow.ipynb", | |
"provenance": [], | |
"collapsed_sections": [], | |
"authorship_tag": "ABX9TyP1YWlW87wmbehl+0r0EmNb", | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
}, | |
"language_info": { | |
"name": "python" | |
}, | |
"accelerator": "GPU" | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/ejmejm/1baeddbbe48f58dbced9c019c25ebf71/pytorch_tips_yt_follow.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "EYpb4rz9GDFx" | |
}, | |
"source": [ | |
"# 7 PyTorch Tips You Should Know" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "xdIxsiZf-mAS" | |
}, | |
"source": [ | |
"import time\n", | |
"\n", | |
"import torch\n", | |
"from torch import nn" | |
], | |
"execution_count": 1, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "WEXFLOIkAlrF" | |
}, | |
"source": [ | |
"# 1. Create Tensors Directly on the Target Device" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "3i0i1gk4-RkL", | |
"outputId": "81fe4095-592d-4025-88fc-c630a8534617" | |
}, | |
"source": [ | |
"# CUDA work is queued asynchronously, so synchronize before starting the clock\n", | |
"# (this also keeps one-time CUDA start-up out of the measurement) and again\n", | |
"# before stopping it.\n", | |
"torch.cuda.synchronize()\n", | |
"start_time = time.perf_counter()\n", | |
"\n", | |
"for _ in range(100):\n", | |
"    # Creating on the CPU, then transferring to the GPU\n", | |
"    cpu_tensor = torch.ones((1000, 64, 64))\n", | |
"    gpu_tensor = cpu_tensor.cuda()\n", | |
"\n", | |
"torch.cuda.synchronize()\n", | |
"print('Total time: {:.3f}s'.format(time.perf_counter() - start_time))" | |
], | |
"execution_count": 3, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Total time: 0.584s\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "69pMZsQs-Sto", | |
"outputId": "5d791f8d-2659-4876-87e8-ddf15528fc2d" | |
}, | |
"source": [ | |
"torch.cuda.synchronize()\n", | |
"start_time = time.perf_counter()\n", | |
"\n", | |
"for _ in range(100):\n", | |
"    # Creating on GPU directly\n", | |
"    gpu_tensor = torch.ones((1000, 64, 64), device='cuda')\n", | |
"\n", | |
"# Wait for the queued GPU work to finish; otherwise only the launch overhead\n", | |
"# would be timed.\n", | |
"torch.cuda.synchronize()\n", | |
"print('Total time: {:.3f}s'.format(time.perf_counter() - start_time))" | |
], | |
"execution_count": 4, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Total time: 0.009s\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "ta-sxuoPAyJQ" | |
}, | |
"source": [ | |
"# 2. Use `Sequential` Layers When Possible" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "f2r4DXtL-RPP" | |
}, | |
"source": [ | |
"class ExampleModel(nn.Module):\n", | |
"    \"\"\"Small MLP (2 -> 16 -> 16 -> 3) with every layer stored as its own attribute.\"\"\"\n", | |
"\n", | |
"    def __init__(self):\n", | |
"        super().__init__()\n", | |
"\n", | |
"        input_size = 2\n", | |
"        output_size = 3\n", | |
"        hidden_size = 16\n", | |
"\n", | |
"        self.input_layer = nn.Linear(input_size, hidden_size)\n", | |
"        self.input_activation = nn.ReLU()\n", | |
"\n", | |
"        self.mid_layer = nn.Linear(hidden_size, hidden_size)\n", | |
"        self.mid_activation = nn.ReLU()\n", | |
"\n", | |
"        self.output_layer = nn.Linear(hidden_size, output_size)\n", | |
"\n", | |
"    def forward(self, x):\n", | |
"        \"\"\"Apply the layers in order and return the output of the final linear layer.\"\"\"\n", | |
"        z = self.input_layer(x)\n", | |
"        z = self.input_activation(z)\n", | |
"\n", | |
"        z = self.mid_layer(z)\n", | |
"        z = self.mid_activation(z)\n", | |
"\n", | |
"        out = self.output_layer(z)\n", | |
"\n", | |
"        return out" | |
], | |
"execution_count": 6, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "dOf_q_JT-V82", | |
"outputId": "9bc666a9-d040-4ddf-9fff-e32d6f24387c" | |
}, | |
"source": [ | |
"example_model = ExampleModel()\n", | |
"print(example_model)\n", | |
"print('Output shape:', example_model(torch.ones([100, 2])).shape)" | |
], | |
"execution_count": 7, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"ExampleModel(\n", | |
" (input_layer): Linear(in_features=2, out_features=16, bias=True)\n", | |
" (input_activation): ReLU()\n", | |
" (mid_layer): Linear(in_features=16, out_features=16, bias=True)\n", | |
" (mid_activation): ReLU()\n", | |
" (output_layer): Linear(in_features=16, out_features=3, bias=True)\n", | |
")\n", | |
"Output shape: torch.Size([100, 3])\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "pCHY0KtT-WGj" | |
}, | |
"source": [ | |
"class ExampleSequentialModel(nn.Module):\n", | |
"    \"\"\"The same 2 -> 16 -> 16 -> 3 MLP, with all layers held in one `nn.Sequential`.\"\"\"\n", | |
"\n", | |
"    def __init__(self):\n", | |
"        super().__init__()\n", | |
"\n", | |
"        input_size = 2\n", | |
"        output_size = 3\n", | |
"        hidden_size = 16\n", | |
"\n", | |
"        self.layers = nn.Sequential(\n", | |
"            nn.Linear(input_size, hidden_size),\n", | |
"            nn.ReLU(),\n", | |
"            nn.Linear(hidden_size, hidden_size),\n", | |
"            nn.ReLU(),\n", | |
"            nn.Linear(hidden_size, output_size))\n", | |
"\n", | |
"    def forward(self, x):\n", | |
"        \"\"\"Run `x` through the sequential stack and return its output.\"\"\"\n", | |
"        out = self.layers(x)\n", | |
"        return out" | |
], | |
"execution_count": 9, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "m1PVgMcT-WNT", | |
"outputId": "c1c09ad7-50fe-46e9-9c7a-6d05654c1e66" | |
}, | |
"source": [ | |
"example_seq_model = ExampleSequentialModel()\n", | |
"print(example_seq_model)\n", | |
"print('Output shape:', example_seq_model(torch.ones([100, 2])).shape)" | |
], | |
"execution_count": 10, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"ExampleSequentialModel(\n", | |
" (layers): Sequential(\n", | |
" (0): Linear(in_features=2, out_features=16, bias=True)\n", | |
" (1): ReLU()\n", | |
" (2): Linear(in_features=16, out_features=16, bias=True)\n", | |
" (3): ReLU()\n", | |
" (4): Linear(in_features=16, out_features=3, bias=True)\n", | |
" )\n", | |
")\n", | |
"Output shape: torch.Size([100, 3])\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "OLGRA4CyAztx" | |
}, | |
"source": [ | |
"# 3. Don't Make Lists of Layers" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "uI2xZ3EP-Xkp" | |
}, | |
"source": [ | |
"class BadListModel(nn.Module):\n", | |
"    \"\"\"Deliberately flawed MLP: the middle layers live in a plain Python list.\n", | |
"\n", | |
"    Kept as the anti-pattern for this tip. The list is an ordinary attribute, so\n", | |
"    the layers inside it are not moved by `.cuda()`; the GPU cell for this model\n", | |
"    fails with a CPU/GPU mismatch raised from the `mid_layers` loop.\n", | |
"    \"\"\"\n", | |
"\n", | |
"    def __init__(self):\n", | |
"        super().__init__()\n", | |
"\n", | |
"        input_size = 2\n", | |
"        output_size = 3\n", | |
"        hidden_size = 16\n", | |
"\n", | |
"        self.input_layer = nn.Linear(input_size, hidden_size)\n", | |
"        self.input_activation = nn.ReLU()\n", | |
"\n", | |
"        # Fairly common when using residual layers\n", | |
"        self.mid_layers = []\n", | |
"        for _ in range(5):\n", | |
"            self.mid_layers.append(nn.Linear(hidden_size, hidden_size))\n", | |
"            self.mid_layers.append(nn.ReLU())\n", | |
"\n", | |
"        self.output_layer = nn.Linear(hidden_size, output_size)\n", | |
"\n", | |
"    def forward(self, x):\n", | |
"        \"\"\"Apply the input layer, each entry of `mid_layers`, then the output layer.\"\"\"\n", | |
"        z = self.input_layer(x)\n", | |
"        z = self.input_activation(z)\n", | |
"\n", | |
"        for layer in self.mid_layers:\n", | |
"            z = layer(z)\n", | |
"\n", | |
"        out = self.output_layer(z)\n", | |
"\n", | |
"        return out" | |
], | |
"execution_count": 11, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "48GAPovU-Xe1", | |
"outputId": "fe2978cd-8f8e-4af0-9cee-8de88764f0cb" | |
}, | |
"source": [ | |
"bad_list_model = BadListModel()\n", | |
"print('Output shape:', bad_list_model(torch.ones([100, 2])).shape)" | |
], | |
"execution_count": 12, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Output shape: torch.Size([100, 3])\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 358 | |
}, | |
"id": "t6n32oG_-XXp", | |
"outputId": "4fd28d4c-0283-4922-dd4f-242088241632" | |
}, | |
"source": [ | |
"gpu_input = torch.ones([100, 2], device='cuda')\n", | |
"gpu_bad_list_model = bad_list_model.cuda()\n", | |
"\n", | |
"# The failure is the point of this cell: the layers kept in the plain list are\n", | |
"# still on the CPU. Catch the error so Restart & Run All reaches the later cells.\n", | |
"try:\n", | |
"    print('Output shape:', gpu_bad_list_model(gpu_input).shape)\n", | |
"except RuntimeError as err:\n", | |
"    print('RuntimeError:', err)" | |
], | |
"execution_count": 13, | |
"outputs": [ | |
{ | |
"output_type": "error", | |
"ename": "RuntimeError", | |
"evalue": "ignored", | |
"traceback": [ | |
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", | |
"\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", | |
"\u001b[0;32m<ipython-input-13-e523900f19d5>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0mgpu_input\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mones\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m100\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m2\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdevice\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'cuda'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0mgpu_bad_list_model\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mbad_list_model\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcuda\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'Output shape:'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbad_list_model\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mgpu_input\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", | |
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 887\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 888\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 889\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 890\u001b[0m for hook in itertools.chain(\n\u001b[1;32m 891\u001b[0m \u001b[0m_global_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m<ipython-input-11-2df20007fc89>\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m 23\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 24\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mlayer\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmid_layers\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 25\u001b[0;31m \u001b[0mz\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlayer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mz\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 26\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 27\u001b[0m \u001b[0mout\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moutput_layer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mz\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 887\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 888\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 889\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 890\u001b[0m for hook in itertools.chain(\n\u001b[1;32m 891\u001b[0m \u001b[0m_global_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/torch/nn/modules/linear.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 92\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 93\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 94\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlinear\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mweight\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbias\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 95\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 96\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mextra_repr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mstr\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/torch/nn/functional.py\u001b[0m in \u001b[0;36mlinear\u001b[0;34m(input, weight, bias)\u001b[0m\n\u001b[1;32m 1751\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mhas_torch_function_variadic\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mweight\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1752\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mhandle_torch_function\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlinear\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mweight\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mweight\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbias\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mbias\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1753\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_C\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_nn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlinear\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mweight\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbias\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1754\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1755\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;31mRuntimeError\u001b[0m: Tensor for 'out' is on CPU, Tensor for argument #1 'self' is on CPU, but expected them to be on GPU (while checking arguments for addmm)" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "n9z0YF7oGZXf" | |
}, | |
"source": [ | |
"## Better Way to Do This" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "EnnyjZp3-Y0a" | |
}, | |
"source": [ | |
"class CorrectListModel(nn.Module):\n", | |
"    \"\"\"MLP whose repeated middle layers are wrapped in a single `nn.Sequential`.\"\"\"\n", | |
"\n", | |
"    def __init__(self):\n", | |
"        super().__init__()\n", | |
"\n", | |
"        input_size = 2\n", | |
"        output_size = 3\n", | |
"        hidden_size = 16\n", | |
"\n", | |
"        self.input_layer = nn.Linear(input_size, hidden_size)\n", | |
"        self.input_activation = nn.ReLU()\n", | |
"\n", | |
"        # Fairly common when using residual layers: collect the layers in a local\n", | |
"        # list, then hand them to nn.Sequential so the module tracks them.\n", | |
"        mid_layers = []\n", | |
"        for _ in range(5):\n", | |
"            mid_layers.append(nn.Linear(hidden_size, hidden_size))\n", | |
"            mid_layers.append(nn.ReLU())\n", | |
"        self.mid_layers = nn.Sequential(*mid_layers)\n", | |
"\n", | |
"        self.output_layer = nn.Linear(hidden_size, output_size)\n", | |
"\n", | |
"    def forward(self, x):\n", | |
"        \"\"\"Apply the input layer, the sequential middle stack, then the output layer.\"\"\"\n", | |
"        z = self.input_layer(x)\n", | |
"        z = self.input_activation(z)\n", | |
"        z = self.mid_layers(z)\n", | |
"        out = self.output_layer(z)\n", | |
"\n", | |
"        return out" | |
], | |
"execution_count": 14, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "zJt9-A5a-Ys5", | |
"outputId": "062bf1c9-5ca1-4b50-8d9c-0761b70997b6" | |
}, | |
"source": [ | |
"correct_list_model = CorrectListModel()\n", | |
"gpu_input = torch.ones([100, 2], device='cuda')\n", | |
"gpu_correct_list_model = correct_list_model.cuda()\n", | |
"print('Output shape:', correct_list_model(gpu_input).shape)" | |
], | |
"execution_count": 15, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Output shape: torch.Size([100, 3])\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "lTlaVIitAzxJ" | |
}, | |
"source": [ | |
"# 4. Make Use of Distributions" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "wNgAZKGh-bgX", | |
"outputId": "05065134-02a3-4763-9d06-3d6cf70b1921" | |
}, | |
"source": [ | |
"# Setup\n", | |
"example_model = ExampleModel()\n", | |
"input_tensor = torch.rand(5, 2)\n", | |
"output = example_model(input_tensor)\n", | |
"print(output)" | |
], | |
"execution_count": 16, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"tensor([[ 0.1965, 0.0558, -0.2112],\n", | |
" [ 0.2035, 0.0650, -0.2077],\n", | |
" [ 0.2150, 0.0577, -0.2096],\n", | |
" [ 0.1957, 0.0540, -0.2117],\n", | |
" [ 0.2045, 0.0566, -0.2085]], grad_fn=<AddmmBackward>)\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "hgyPL1Gn-baX" | |
}, | |
"source": [ | |
"from torch.distributions import Categorical\n", | |
"from torch.distributions.kl import kl_divergence" | |
], | |
"execution_count": 17, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "rCU8D-mO-bTO", | |
"outputId": "f43f4d62-0323-43cb-f3a9-75f8520118e3" | |
}, | |
"source": [ | |
"dist = Categorical(logits=output)\n", | |
"dist" | |
], | |
"execution_count": 18, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"Categorical(logits: torch.Size([5, 3]))" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 18 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "0dCkCQyY-bKq", | |
"outputId": "5ef69a62-f9d3-43f4-c94e-a034298190bf" | |
}, | |
"source": [ | |
"# Get probabilities\n", | |
"dist.probs" | |
], | |
"execution_count": 19, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"tensor([[0.3946, 0.3428, 0.2625],\n", | |
" [0.3947, 0.3437, 0.2616],\n", | |
" [0.3986, 0.3406, 0.2607],\n", | |
" [0.3947, 0.3426, 0.2627],\n", | |
" [0.3962, 0.3417, 0.2621]], grad_fn=<SoftmaxBackward>)" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 19 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "jAnfI0Dt-bEo", | |
"outputId": "141ba422-7b6b-47d4-da6a-0c131bf736ba" | |
}, | |
"source": [ | |
"# Take samples\n", | |
"dist.sample()" | |
], | |
"execution_count": 24, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"tensor([0, 1, 0, 0, 2])" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 24 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "f-BFWjGo-a9M", | |
"outputId": "318612ba-c2e2-4fc5-d25a-e35863b1cc05" | |
}, | |
"source": [ | |
"# Calculate the KL-Divergence\n", | |
"dist_1 = Categorical(logits=output[0])\n", | |
"dist_2 = Categorical(logits=output[1])\n", | |
"kl_divergence(dist_1, dist_2)" | |
], | |
"execution_count": 25, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"tensor(2.5076e-06, grad_fn=<SumBackward1>)" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 25 | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "2qZwgkkjAzz-" | |
}, | |
"source": [ | |
"# 5. Use `detach()` On Long-Term Metrics" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "MN-ZgkpX-dCG", | |
"outputId": "361ac300-15c1-48fb-d8d5-a12b837b5e5f" | |
}, | |
"source": [ | |
"# Setup\n", | |
"example_model = ExampleModel()\n", | |
"data_batches = [torch.rand((10, 2)) for _ in range(5)]\n", | |
"# `reduction` is the supported keyword; the legacy `reduce` argument only\n", | |
"# triggers a deprecation warning.\n", | |
"criterion = nn.MSELoss(reduction='mean')" | |
], | |
"execution_count": 27, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"/usr/local/lib/python3.7/dist-packages/torch/nn/_reduction.py:42: UserWarning: size_average and reduce args will be deprecated, please use reduction='mean' instead.\n", | |
" warnings.warn(warning.format(ret))\n" | |
], | |
"name": "stderr" | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "j5tRrcMXKb4i" | |
}, | |
"source": [ | |
"## Bad Example" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "y0WYtsqx-cwX", | |
"outputId": "297d6d2f-4002-44e4-fa60-254a936709f0" | |
}, | |
"source": [ | |
"losses = []\n", | |
"\n", | |
"# Training loop\n", | |
"for batch in data_batches:\n", | |
" output = example_model(batch)\n", | |
"\n", | |
" target = torch.rand((10, 3))\n", | |
" loss = criterion(output, target)\n", | |
" # Anti-pattern: this stores the loss tensor itself; the printed list shows\n", | |
" # that every entry still carries its grad_fn.\n", | |
" losses.append(loss)\n", | |
"\n", | |
" # Optimization happens here\n", | |
"\n", | |
"print(losses)" | |
], | |
"execution_count": 28, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"[tensor(0.4718, grad_fn=<MseLossBackward>), tensor(0.5156, grad_fn=<MseLossBackward>), tensor(0.6583, grad_fn=<MseLossBackward>), tensor(0.4429, grad_fn=<MseLossBackward>), tensor(0.4133, grad_fn=<MseLossBackward>)]\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "1P685bJgJvzn" | |
}, | |
"source": [ | |
"## Better Example" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "f4zbLSpT-dbu", | |
"outputId": "864b398e-034a-4aff-a4b0-9abad42bc090" | |
}, | |
"source": [ | |
"losses = []\n", | |
"\n", | |
"# Training loop\n", | |
"for batch in data_batches:\n", | |
" output = example_model(batch)\n", | |
"\n", | |
" target = torch.rand((10, 3))\n", | |
" loss = criterion(output, target)\n", | |
" losses.append(loss.item()) # Or `loss.detach()` to keep a tensor instead of a float\n", | |
"\n", | |
" # Optimization happens here\n", | |
"\n", | |
"print(losses)" | |
], | |
"execution_count": 31, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"[0.5439911484718323, 0.5461570620536804, 0.6738904118537903, 0.5780249834060669, 0.5130327939987183]\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "QkSVHK38Az3F" | |
}, | |
"source": [ | |
"# 6. Trick to Delete a Model from GPU" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "CocJHuhl-e_V" | |
}, | |
"source": [ | |
"import gc" | |
], | |
"execution_count": 32, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "mKlBHS-D-e4F" | |
}, | |
"source": [ | |
"example_model = ExampleModel().cuda()\n", | |
"\n", | |
"del example_model\n", | |
"\n", | |
"gc.collect()\n", | |
"# The model will normally stay in the cache until something takes its place\n", | |
"torch.cuda.empty_cache()" | |
], | |
"execution_count": 33, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "uagavyWeAz6K" | |
}, | |
"source": [ | |
"# 7. Call `eval()` Before Testing" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "vgkGWVcU-fgp", | |
"outputId": "a9727c7d-194d-4109-fcee-dd664ff83670" | |
}, | |
"source": [ | |
"example_model = ExampleModel()\n", | |
"\n", | |
"# Do training\n", | |
"\n", | |
"example_model.eval()\n", | |
"\n", | |
"# Do testing\n", | |
"\n", | |
"example_model.train()\n", | |
"\n", | |
"# Do training again" | |
], | |
"execution_count": 34, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"ExampleModel(\n", | |
" (input_layer): Linear(in_features=2, out_features=16, bias=True)\n", | |
" (input_activation): ReLU()\n", | |
" (mid_layer): Linear(in_features=16, out_features=16, bias=True)\n", | |
" (mid_activation): ReLU()\n", | |
" (output_layer): Linear(in_features=16, out_features=3, bias=True)\n", | |
")" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 34 | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "XnvexQJY-gyH" | |
}, | |
"source": [ | |
"### Affects\n", | |
" - Dropout\n", | |
" - Batch Normalization\n", | |
" - RNNs\n", | |
" - Lazy Variants\n", | |
"\n", | |
"source: https://stackoverflow.com/questions/66534762/which-pytorch-modules-are-affected-by-model-eval-and-model-train" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "c28JegUPEBdl" | |
}, | |
"source": [ | |
"" | |
], | |
"execution_count": null, | |
"outputs": [] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment