Last active
October 25, 2020 03:57
-
-
Save neoyipeng2018/45e1978cd5d8668966cb8656ae445567 to your computer and use it in GitHub Desktop.
UnderstandingLayerAndBatchNorm.ipynb
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"name": "UnderstandingLayerAndBatchNorm.ipynb", | |
"provenance": [], | |
"collapsed_sections": [], | |
"authorship_tag": "ABX9TyMqIKlbquWJ+jbEDH41MoQG", | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
} | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/neoyipeng2018/45e1978cd5d8668966cb8656ae445567/understandinglayerandbatchnorm.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "b6powbwyVElC" | |
}, | |
"source": [ | |
"# Understanding Layer and Batch Norm by computing from scratch" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "j2SickWgVdLX" | |
}, | |
"source": [ | |
"## Simple example of a mini batch of 8, with 3 features, each with 1 dimension" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "2SgCeUi3fFLd" | |
}, | |
"source": [ | |
"import torch\n", | |
"from torch import nn" | |
], | |
"execution_count": 2, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "b3o0d2_cVGjd", | |
"outputId": "7903c394-78f1-4fad-c485-db25df3a1792", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 153 | |
} | |
}, | |
"source": [ | |
"input = torch.randn(8, 3)\n", | |
"input" | |
], | |
"execution_count": 3, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"tensor([[ 0.3857, 0.6901, -0.5801],\n", | |
" [ 0.4570, 0.6409, 0.4245],\n", | |
" [ 1.8557, -0.9741, 0.6087],\n", | |
" [-1.3705, 0.9062, -1.5359],\n", | |
" [-1.5376, 0.5564, -0.1074],\n", | |
" [-1.6854, -0.1565, -1.2909],\n", | |
" [ 0.3606, -0.4981, 0.3168],\n", | |
" [-0.9445, -1.0355, -1.6735]])" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 3 | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "kJRcWX9CfkRe" | |
}, | |
"source": [ | |
"## LayerNorm" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "IhnjZeD_WSly" | |
}, | |
"source": [ | |
"### Using torch's `LayerNorm` to get 'ground truths'" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "qUS1z90-VTmx", | |
"outputId": "1b6e5c0c-2979-4cab-97cd-7a4f4b8204cb", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 153 | |
} | |
}, | |
"source": [ | |
"m = nn.LayerNorm(input.size()[1:])\n", | |
"output = m(input)\n", | |
"output" | |
], | |
"execution_count": 4, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"tensor([[ 0.4071, 0.9693, -1.3764],\n", | |
" [-0.5298, 1.3997, -0.8700],\n", | |
" [ 1.1736, -1.2702, 0.0966],\n", | |
" [-0.6316, 1.4116, -0.7800],\n", | |
" [-1.3445, 1.0521, 0.2924],\n", | |
" [-0.9893, 1.3698, -0.3805],\n", | |
" [ 0.7617, -1.4127, 0.6510],\n", | |
" [ 0.8426, 0.5622, -1.4049]], grad_fn=<NativeLayerNormBackward>)" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 4 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "L5l8caUqO0kZ", | |
"outputId": "4b5bddd4-c71e-4bfa-e69d-5ab808ff80ed", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 34 | |
} | |
}, | |
"source": [ | |
"m.eps" | |
], | |
"execution_count": 5, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"1e-05" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 5 | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "cQO4ULQJX-e9" | |
}, | |
"source": [ | |
"### For layer norm, we are normalizing **across feature dimensions** for each training example in the batch, i.e. dim 1." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "-Z9J6jJ3bKlh", | |
"outputId": "e87b5ee3-d4ff-4669-e74c-5ac6fa47b4b9", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 153 | |
} | |
}, | |
"source": [ | |
"av=input.mean(1).repeat(3,1).T\n", | |
"av" | |
], | |
"execution_count": 6, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"tensor([[ 0.1652, 0.1652, 0.1652],\n", | |
" [ 0.5075, 0.5075, 0.5075],\n", | |
" [ 0.4968, 0.4968, 0.4968],\n", | |
" [-0.6667, -0.6667, -0.6667],\n", | |
" [-0.3629, -0.3629, -0.3629],\n", | |
" [-1.0443, -1.0443, -1.0443],\n", | |
" [ 0.0598, 0.0598, 0.0598],\n", | |
" [-1.2178, -1.2178, -1.2178]])" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 6 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "KraO6Byqa-py", | |
"outputId": "9b43adc5-358a-466a-ca88-fa9e172fce34", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 153 | |
} | |
}, | |
"source": [ | |
"dev=(torch.sqrt(torch.var(input,1,False)+m.eps)).repeat(3,1).T\n", | |
"dev" | |
], | |
"execution_count": 7, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"tensor([[0.5415, 0.5415, 0.5415],\n", | |
" [0.0953, 0.0953, 0.0953],\n", | |
" [1.1580, 1.1580, 1.1580],\n", | |
" [1.1143, 1.1143, 1.1143],\n", | |
" [0.8737, 0.8737, 0.8737],\n", | |
" [0.6480, 0.6480, 0.6480],\n", | |
" [0.3949, 0.3949, 0.3949],\n", | |
" [0.3244, 0.3244, 0.3244]])" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 7 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "WPB_6WSsYDPQ", | |
"outputId": "b815ff1a-9443-48c1-9315-7429aadcf0e2", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 153 | |
} | |
}, | |
"source": [ | |
"outputManual=(input-av)/dev\n", | |
"outputManual" | |
], | |
"execution_count": 8, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"tensor([[ 0.4071, 0.9693, -1.3764],\n", | |
" [-0.5298, 1.3997, -0.8700],\n", | |
" [ 1.1736, -1.2702, 0.0966],\n", | |
" [-0.6316, 1.4116, -0.7800],\n", | |
" [-1.3445, 1.0521, 0.2924],\n", | |
" [-0.9893, 1.3698, -0.3805],\n", | |
" [ 0.7617, -1.4127, 0.6510],\n", | |
" [ 0.8426, 0.5622, -1.4048]])" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 8 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "AxPUN4o_UDeX", | |
"outputId": "e2deec07-92a7-4e30-da28-3d1ae8c49945", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 153 | |
} | |
}, | |
"source": [ | |
"output-outputManual" | |
], | |
"execution_count": 9, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"tensor([[ 0.0000e+00, 5.9605e-08, 0.0000e+00],\n", | |
" [-4.7684e-07, 1.5497e-06, -8.3447e-07],\n", | |
" [ 0.0000e+00, 1.1921e-07, 7.4506e-09],\n", | |
" [ 5.9605e-08, -2.3842e-07, 1.1921e-07],\n", | |
" [-1.1921e-07, 0.0000e+00, 2.9802e-08],\n", | |
" [-5.9605e-08, 2.3842e-07, 0.0000e+00],\n", | |
" [-5.9605e-08, 1.1921e-07, 0.0000e+00],\n", | |
" [ 1.4305e-06, 1.0729e-06, -1.5497e-06]], grad_fn=<SubBackward0>)" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 9 | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "w7t9-d4TfpBN" | |
}, | |
"source": [ | |
"## BatchNorm" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "vYbO0G2mWVlk" | |
}, | |
"source": [ | |
"### Using torch's `BatchNorm1d` to get 'ground truths'" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "0qYHhVhKSl1o", | |
"outputId": "06b0bbf1-8ba9-4a73-9156-a4c567b628ca", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 153 | |
} | |
}, | |
"source": [ | |
"m = nn.BatchNorm1d(input.size()[1:])\n", | |
"output = m(input)\n", | |
"output" | |
], | |
"execution_count": 14, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"tensor([[ 0.5895, 0.9183, -0.1160],\n", | |
" [ 0.6499, 0.8512, 1.0455],\n", | |
" [ 1.8353, -1.3492, 1.2584],\n", | |
" [-0.8989, 1.2126, -1.2212],\n", | |
" [-1.0405, 0.7360, 0.4305],\n", | |
" [-1.1657, -0.2353, -0.9379],\n", | |
" [ 0.5682, -0.7007, 0.9210],\n", | |
" [-0.5378, -1.4329, -1.3803]], grad_fn=<NativeBatchNormBackward>)" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 14 | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "Tx2gO3l9XvA6" | |
}, | |
"source": [ | |
"### For batch norm, we are normalizing **across batch** for each feature, i.e. dim 0." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "WU7YDpVtW1NW", | |
"outputId": "c68b0520-ac20-409e-ba26-6f3f2ae6ae6b", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 34 | |
} | |
}, | |
"source": [ | |
"av=input.mean(0)\n", | |
"av" | |
], | |
"execution_count": 15, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"tensor([-0.3099, 0.0162, -0.4797])" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 15 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "RNAKxj0FXLHj", | |
"outputId": "a66ea1b6-e621-474f-a172-0d0d17328bd9", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 34 | |
} | |
}, | |
"source": [ | |
"dev=torch.sqrt(torch.var(input,0,False)+m.eps)\n", | |
"dev" | |
], | |
"execution_count": 16, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"tensor([1.1799, 0.7339, 0.8649])" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 16 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "XIXJear0XI1U", | |
"outputId": "9480fbfc-b8ff-4eb9-ff02-8667e87586c4", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 153 | |
} | |
}, | |
"source": [ | |
"outputManual=(input-av)/dev\n", | |
"outputManual" | |
], | |
"execution_count": 17, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"tensor([[ 0.5895, 0.9183, -0.1160],\n", | |
" [ 0.6499, 0.8512, 1.0455],\n", | |
" [ 1.8353, -1.3492, 1.2584],\n", | |
" [-0.8989, 1.2126, -1.2212],\n", | |
" [-1.0405, 0.7360, 0.4305],\n", | |
" [-1.1657, -0.2353, -0.9379],\n", | |
" [ 0.5682, -0.7007, 0.9210],\n", | |
" [-0.5378, -1.4329, -1.3803]])" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 17 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "KY9MF48TZPhc", | |
"outputId": "0063b2df-9d08-4bfb-aef4-436e9294e400", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 153 | |
} | |
}, | |
"source": [ | |
"output-outputManual" | |
], | |
"execution_count": 18, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"tensor([[ 5.9605e-08, -5.9605e-08, 0.0000e+00],\n", | |
" [ 5.9605e-08, -5.9605e-08, 0.0000e+00],\n", | |
" [ 1.1921e-07, 1.1921e-07, 0.0000e+00],\n", | |
" [-1.1921e-07, 0.0000e+00, 0.0000e+00],\n", | |
" [ 0.0000e+00, -5.9605e-08, 0.0000e+00],\n", | |
" [-1.1921e-07, 1.4901e-08, 0.0000e+00],\n", | |
" [ 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", | |
" [-5.9605e-08, 1.1921e-07, 1.1921e-07]], grad_fn=<SubBackward0>)" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 18 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "UAzKPONa4lfF" | |
}, | |
"source": [ | |
"" | |
], | |
"execution_count": null, | |
"outputs": [] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment