{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "UnderstandingLayerAndBatchNorm.ipynb",
"provenance": [],
"collapsed_sections": [],
"authorship_tag": "ABX9TyMqIKlbquWJ+jbEDH41MoQG",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/neoyipeng2018/45e1978cd5d8668966cb8656ae445567/understandinglayerandbatchnorm.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "b6powbwyVElC"
},
"source": [
"# Understanding Layer and Batch Norm by computing from scratch"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "j2SickWgVdLX"
},
"source": [
"## Simple example of a mini batch of 8, with 3 features, each with 1 dimension"
]
},
{
"cell_type": "code",
"metadata": {
"id": "2SgCeUi3fFLd"
},
"source": [
"import torch\n",
"from torch import nn"
],
"execution_count": 2,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "b3o0d2_cVGjd",
"outputId": "7903c394-78f1-4fad-c485-db25df3a1792",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 153
}
},
"source": [
"input = torch.randn(8, 3)\n",
"input"
],
"execution_count": 3,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"tensor([[ 0.3857, 0.6901, -0.5801],\n",
" [ 0.4570, 0.6409, 0.4245],\n",
" [ 1.8557, -0.9741, 0.6087],\n",
" [-1.3705, 0.9062, -1.5359],\n",
" [-1.5376, 0.5564, -0.1074],\n",
" [-1.6854, -0.1565, -1.2909],\n",
" [ 0.3606, -0.4981, 0.3168],\n",
" [-0.9445, -1.0355, -1.6735]])"
]
},
"metadata": {
"tags": []
},
"execution_count": 3
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "kJRcWX9CfkRe"
},
"source": [
"## LayerNorm"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "IhnjZeD_WSly"
},
"source": [
"### Using torch's `LayerNorm` to get 'ground truths'"
]
},
{
"cell_type": "code",
"metadata": {
"id": "qUS1z90-VTmx",
"outputId": "1b6e5c0c-2979-4cab-97cd-7a4f4b8204cb",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 153
}
},
"source": [
"m = nn.LayerNorm(input.size()[1:])\n",
"output = m(input)\n",
"output"
],
"execution_count": 4,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"tensor([[ 0.4071, 0.9693, -1.3764],\n",
" [-0.5298, 1.3997, -0.8700],\n",
" [ 1.1736, -1.2702, 0.0966],\n",
" [-0.6316, 1.4116, -0.7800],\n",
" [-1.3445, 1.0521, 0.2924],\n",
" [-0.9893, 1.3698, -0.3805],\n",
" [ 0.7617, -1.4127, 0.6510],\n",
" [ 0.8426, 0.5622, -1.4049]], grad_fn=<NativeLayerNormBackward>)"
]
},
"metadata": {
"tags": []
},
"execution_count": 4
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "L5l8caUqO0kZ",
"outputId": "4b5bddd4-c71e-4bfa-e69d-5ab808ff80ed",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
}
},
"source": [
"m.eps"
],
"execution_count": 5,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"1e-05"
]
},
"metadata": {
"tags": []
},
"execution_count": 5
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "cQO4ULQJX-e9"
},
"source": [
"### For layer norm, we are normalizing **across feature dimensions** for each training example in the batch, i.e. dim 1."
]
},
{
"cell_type": "code",
"metadata": {
"id": "-Z9J6jJ3bKlh",
"outputId": "e87b5ee3-d4ff-4669-e74c-5ac6fa47b4b9",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 153
}
},
"source": [
"av=input.mean(1).repeat(3,1).T\n",
"av"
],
"execution_count": 6,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"tensor([[ 0.1652, 0.1652, 0.1652],\n",
" [ 0.5075, 0.5075, 0.5075],\n",
" [ 0.4968, 0.4968, 0.4968],\n",
" [-0.6667, -0.6667, -0.6667],\n",
" [-0.3629, -0.3629, -0.3629],\n",
" [-1.0443, -1.0443, -1.0443],\n",
" [ 0.0598, 0.0598, 0.0598],\n",
" [-1.2178, -1.2178, -1.2178]])"
]
},
"metadata": {
"tags": []
},
"execution_count": 6
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "KraO6Byqa-py",
"outputId": "9b43adc5-358a-466a-ca88-fa9e172fce34",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 153
}
},
"source": [
"dev=(torch.sqrt(torch.var(input,1,False)+m.eps)).repeat(3,1).T\n",
"dev"
],
"execution_count": 7,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"tensor([[0.5415, 0.5415, 0.5415],\n",
" [0.0953, 0.0953, 0.0953],\n",
" [1.1580, 1.1580, 1.1580],\n",
" [1.1143, 1.1143, 1.1143],\n",
" [0.8737, 0.8737, 0.8737],\n",
" [0.6480, 0.6480, 0.6480],\n",
" [0.3949, 0.3949, 0.3949],\n",
" [0.3244, 0.3244, 0.3244]])"
]
},
"metadata": {
"tags": []
},
"execution_count": 7
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "WPB_6WSsYDPQ",
"outputId": "b815ff1a-9443-48c1-9315-7429aadcf0e2",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 153
}
},
"source": [
"outputManual=(input-av)/dev\n",
"outputManual"
],
"execution_count": 8,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"tensor([[ 0.4071, 0.9693, -1.3764],\n",
" [-0.5298, 1.3997, -0.8700],\n",
" [ 1.1736, -1.2702, 0.0966],\n",
" [-0.6316, 1.4116, -0.7800],\n",
" [-1.3445, 1.0521, 0.2924],\n",
" [-0.9893, 1.3698, -0.3805],\n",
" [ 0.7617, -1.4127, 0.6510],\n",
" [ 0.8426, 0.5622, -1.4048]])"
]
},
"metadata": {
"tags": []
},
"execution_count": 8
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "AxPUN4o_UDeX",
"outputId": "e2deec07-92a7-4e30-da28-3d1ae8c49945",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 153
}
},
"source": [
"output-outputManual"
],
"execution_count": 9,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"tensor([[ 0.0000e+00, 5.9605e-08, 0.0000e+00],\n",
" [-4.7684e-07, 1.5497e-06, -8.3447e-07],\n",
" [ 0.0000e+00, 1.1921e-07, 7.4506e-09],\n",
" [ 5.9605e-08, -2.3842e-07, 1.1921e-07],\n",
" [-1.1921e-07, 0.0000e+00, 2.9802e-08],\n",
" [-5.9605e-08, 2.3842e-07, 0.0000e+00],\n",
" [-5.9605e-08, 1.1921e-07, 0.0000e+00],\n",
" [ 1.4305e-06, 1.0729e-06, -1.5497e-06]], grad_fn=<SubBackward0>)"
]
},
"metadata": {
"tags": []
},
"execution_count": 9
}
]
},
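{
"cell_type": "markdown",
"metadata": {},
"source": [
"The manual result matches `nn.LayerNorm` up to float32 rounding because the layer's affine parameters are still at their defaults (weight initialized to ones, bias to zeros), so they have no effect here. The next cell is a compact sketch of the same computation that uses `keepdim=True` broadcasting instead of `repeat`; the names `mu`, `var` and `outputCompact` are just illustrative."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Compact sketch: normalize each example across its features (dim 1) using broadcasting\n",
"mu = input.mean(1, keepdim=True)                  # per-example mean, shape (8, 1)\n",
"var = input.var(1, unbiased=False, keepdim=True)  # biased variance, as LayerNorm uses\n",
"outputCompact = (input - mu) / torch.sqrt(var + m.eps)\n",
"torch.allclose(output, outputCompact, atol=1e-6)  # should be True up to float32 rounding"
],
"execution_count": null,
"outputs": []
},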
{
"cell_type": "markdown",
"metadata": {
"id": "w7t9-d4TfpBN"
},
"source": [
"## BatchNorm"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vYbO0G2mWVlk"
},
"source": [
"### Using torch's `BatchNorm1d` get 'ground truths'"
]
},
{
"cell_type": "code",
"metadata": {
"id": "0qYHhVhKSl1o",
"outputId": "06b0bbf1-8ba9-4a73-9156-a4c567b628ca",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 153
}
},
"source": [
"m = nn.BatchNorm1d(input.size()[1:])\n",
"output = m(input)\n",
"output"
],
"execution_count": 14,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"tensor([[ 0.5895, 0.9183, -0.1160],\n",
" [ 0.6499, 0.8512, 1.0455],\n",
" [ 1.8353, -1.3492, 1.2584],\n",
" [-0.8989, 1.2126, -1.2212],\n",
" [-1.0405, 0.7360, 0.4305],\n",
" [-1.1657, -0.2353, -0.9379],\n",
" [ 0.5682, -0.7007, 0.9210],\n",
" [-0.5378, -1.4329, -1.3803]], grad_fn=<NativeBatchNormBackward>)"
]
},
"metadata": {
"tags": []
},
"execution_count": 14
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Tx2gO3l9XvA6"
},
"source": [
"### For batch norm, we are normalizing **across batch** for each feature, i.e. dim 0."
]
},
{
"cell_type": "code",
"metadata": {
"id": "WU7YDpVtW1NW",
"outputId": "c68b0520-ac20-409e-ba26-6f3f2ae6ae6b",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
}
},
"source": [
"av=input.mean(0)\n",
"av"
],
"execution_count": 15,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"tensor([-0.3099, 0.0162, -0.4797])"
]
},
"metadata": {
"tags": []
},
"execution_count": 15
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "RNAKxj0FXLHj",
"outputId": "a66ea1b6-e621-474f-a172-0d0d17328bd9",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
}
},
"source": [
"dev=torch.sqrt(torch.var(input,0,False)+m.eps)\n",
"dev"
],
"execution_count": 16,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"tensor([1.1799, 0.7339, 0.8649])"
]
},
"metadata": {
"tags": []
},
"execution_count": 16
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "XIXJear0XI1U",
"outputId": "9480fbfc-b8ff-4eb9-ff02-8667e87586c4",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 153
}
},
"source": [
"outputManual=(input-av)/dev\n",
"outputManual"
],
"execution_count": 17,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"tensor([[ 0.5895, 0.9183, -0.1160],\n",
" [ 0.6499, 0.8512, 1.0455],\n",
" [ 1.8353, -1.3492, 1.2584],\n",
" [-0.8989, 1.2126, -1.2212],\n",
" [-1.0405, 0.7360, 0.4305],\n",
" [-1.1657, -0.2353, -0.9379],\n",
" [ 0.5682, -0.7007, 0.9210],\n",
" [-0.5378, -1.4329, -1.3803]])"
]
},
"metadata": {
"tags": []
},
"execution_count": 17
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "KY9MF48TZPhc",
"outputId": "0063b2df-9d08-4bfb-aef4-436e9294e400",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 153
}
},
"source": [
"output-outputManual"
],
"execution_count": 18,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"tensor([[ 5.9605e-08, -5.9605e-08, 0.0000e+00],\n",
" [ 5.9605e-08, -5.9605e-08, 0.0000e+00],\n",
" [ 1.1921e-07, 1.1921e-07, 0.0000e+00],\n",
" [-1.1921e-07, 0.0000e+00, 0.0000e+00],\n",
" [ 0.0000e+00, -5.9605e-08, 0.0000e+00],\n",
" [-1.1921e-07, 1.4901e-08, 0.0000e+00],\n",
" [ 0.0000e+00, 0.0000e+00, 0.0000e+00],\n",
" [-5.9605e-08, 1.1921e-07, 1.1921e-07]], grad_fn=<SubBackward0>)"
]
},
"metadata": {
"tags": []
},
"execution_count": 18
}
]
},
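{
"cell_type": "markdown",
"metadata": {},
"source": [
"As with layer norm, the match holds because `BatchNorm1d`'s affine parameters start at weight = 1 and bias = 0, and in training mode the layer normalizes with the current batch statistics rather than the running estimates. The next cell is an analogous compact sketch of the per-feature normalization over dim 0; again, the variable names are just illustrative."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Compact sketch: normalize each feature across the batch (dim 0) using broadcasting\n",
"mu = input.mean(0, keepdim=True)                  # per-feature mean, shape (1, 3)\n",
"var = input.var(0, unbiased=False, keepdim=True)  # biased variance, as BatchNorm uses\n",
"outputCompact = (input - mu) / torch.sqrt(var + m.eps)\n",
"torch.allclose(output, outputCompact, atol=1e-6)  # should be True up to float32 rounding"
],
"execution_count": null,
"outputs": []
},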
{
"cell_type": "code",
"metadata": {
"id": "UAzKPONa4lfF"
},
"source": [
""
],
"execution_count": null,
"outputs": []
}
]
}