Skip to content

Instantly share code, notes, and snippets.

@n-taku
Created March 29, 2020 04:23
Show Gist options
  • Save n-taku/111bbca07651f125e6ef8f9131ac4e47 to your computer and use it in GitHub Desktop.
Save n-taku/111bbca07651f125e6ef8f9131ac4e47 to your computer and use it in GitHub Desktop.
BatchNormalizationのモデルのサンプル
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "CIFAR10BN_model1.ipynb",
"provenance": []
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
}
},
"cells": [
{
"cell_type": "code",
"metadata": {
"id": "2uFikVAygWbr",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 340
},
"outputId": "502f9d5f-6351-4203-f45e-6d543c25f213"
},
"source": [
"import torch\n",
"from torchsummary import summary\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"\n",
"\n",
"class Net(nn.Module):\n",
"    \"\"\"Small CIFAR-10 CNN: two conv/ReLU/max-pool stages, then two linear layers.\n",
"\n",
"    Expects input of shape (N, 3, 32, 32): fc1's input size (128 * 5 * 5)\n",
"    is fixed for that spatial size. Returns raw (N, 10) class scores;\n",
"    no softmax is applied.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self):\n",
"        super().__init__()\n",
"        self.conv1 = nn.Conv2d(3, 64, 5)     # (N, 3, 32, 32) -> (N, 64, 28, 28)\n",
"        self.pool = nn.MaxPool2d(2, 2)       # halves height and width; reused after both convs\n",
"        self.conv2 = nn.Conv2d(64, 128, 5)   # (N, 64, 14, 14) -> (N, 128, 10, 10)\n",
"        self.fc1 = nn.Linear(128 * 5 * 5, 120)\n",
"        self.fc2 = nn.Linear(120, 10)\n",
"\n",
"    def forward(self, x):\n",
"        x = self.conv1(x)\n",
"        x = self.pool(F.relu(x))  # -> (N, 64, 14, 14)\n",
"        x = self.conv2(x)\n",
"        x = self.pool(F.relu(x))  # -> (N, 128, 5, 5)\n",
"        # Keep the batch dimension explicit: view(-1, 128 * 5 * 5) could silently\n",
"        # regroup elements into a different batch size for a wrongly sized input.\n",
"        x = torch.flatten(x, 1)\n",
"        x = F.relu(self.fc1(x))\n",
"        x = self.fc2(x)\n",
"        return x\n",
"\n",
"\n",
"summary(Net(), (3, 32, 32))"
],
"execution_count": 1,
"outputs": [
{
"output_type": "stream",
"text": [
"----------------------------------------------------------------\n",
" Layer (type) Output Shape Param #\n",
"================================================================\n",
" Conv2d-1 [-1, 64, 28, 28] 4,864\n",
" MaxPool2d-2 [-1, 64, 14, 14] 0\n",
" Conv2d-3 [-1, 128, 10, 10] 204,928\n",
" MaxPool2d-4 [-1, 128, 5, 5] 0\n",
" Linear-5 [-1, 120] 384,120\n",
" Linear-6 [-1, 10] 1,210\n",
"================================================================\n",
"Total params: 595,122\n",
"Trainable params: 595,122\n",
"Non-trainable params: 0\n",
"----------------------------------------------------------------\n",
"Input size (MB): 0.01\n",
"Forward/backward pass size (MB): 0.60\n",
"Params size (MB): 2.27\n",
"Estimated Total Size (MB): 2.88\n",
"----------------------------------------------------------------\n"
],
"name": "stdout"
}
]
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment