Created
April 12, 2020 17:41
-
-
Save n-taku/862d3aec242af8f53a6e93a0472c5666 to your computer and use it in GitHub Desktop.
CIFAR10でGAP（Global Average Pooling）を使ったSample
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"name": "CIFAR10GAP.ipynb", | |
"provenance": [], | |
"collapsed_sections": [] | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
}, | |
"accelerator": "GPU" | |
}, | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "3En3Msax-XDq", | |
"colab_type": "code", | |
"outputId": "35164ced-a305-4d9f-895f-3f9c0ecefdcc", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 1000 | |
} | |
}, | |
"source": [ | |
"import torch\n", | |
"import torchvision\n", | |
"import torch.nn as nn\n", | |
"import torch.optim as optim\n", | |
"import numpy as np\n", | |
"import matplotlib.pyplot as plt\n", | |
"import pickle\n", | |
"from torchsummary import summary\n", | |
"import torch.nn.functional as F\n", | |
"\n", | |
"BATCH_SIZE = 100\n", | |
"EPOCH = 20\n", | |
"PATH = \"Dataset\"\n", | |
"\n", | |
"transform = torchvision.transforms.Compose([torchvision.transforms.Resize((32,32)),torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.5,), (0.5,))])\n", | |
"\n", | |
"trainset = torchvision.datasets.CIFAR10(root = PATH, train = True, download = True, transform = transform)\n", | |
"trainloader = torch.utils.data.DataLoader(trainset, batch_size = BATCH_SIZE,\n", | |
" shuffle = True, num_workers = 2)\n", | |
"\n", | |
"testset = torchvision.datasets.CIFAR10(root = PATH, train = False, download = True, transform = transform)\n", | |
"testloader = torch.utils.data.DataLoader(testset, batch_size = BATCH_SIZE,\n", | |
" shuffle = False, num_workers = 2)\n", | |
"\n", | |
"class Net(nn.Module):\n", | |
" def __init__(self):\n", | |
" super(Net, self).__init__()\n", | |
" self.conv1 = nn.Conv2d(3, 16, 5)\n", | |
" self.conv2 = nn.Conv2d(16, 32, 5)\n", | |
" self.conv3 = nn.Conv2d(32, 64, 5)\n", | |
" self.conv4 = nn.Conv2d(64, 10, 5)\n", | |
" self.avgpool = torch.nn.AdaptiveAvgPool2d((1,1))\n", | |
" def forward(self, x):\n", | |
" x = F.relu(self.conv1(x))\n", | |
" x = F.relu(self.conv2(x))\n", | |
" x = F.relu(self.conv3(x))\n", | |
" x = F.relu(self.conv4(x))\n", | |
" x = self.avgpool(x)\n", | |
" x = torch.flatten(x, 1)\n", | |
" return x\n", | |
"\n", | |
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", | |
"net = Net()\n", | |
"net = net.to(device)\n", | |
"criterion = nn.CrossEntropyLoss()\n", | |
"optimizer = optim.Adam(net.parameters())\n", | |
"\n", | |
"train_loss_value=[] #trainingのlossを保持するlist\n", | |
"train_acc_value=[] #trainingのaccuracyを保持するlist\n", | |
"test_loss_value=[] #testのlossを保持するlist\n", | |
"test_acc_value=[] #testのaccuracyを保持するlist \n", | |
"\n", | |
"summary(net, (3, 32, 32))\n", | |
"\n", | |
"for epoch in range(EPOCH):\n", | |
" print('epoch', epoch+1) #epoch数の出力\n", | |
" for (inputs, labels) in trainloader:\n", | |
" inputs, labels = inputs.to(device), labels.to(device)\n", | |
" optimizer.zero_grad()\n", | |
" outputs = net(inputs)\n", | |
" loss = criterion(outputs, labels)\n", | |
" loss.backward()\n", | |
" optimizer.step()\n", | |
"\n", | |
" sum_loss = 0.0 #lossの合計\n", | |
" sum_correct = 0 #正解率の合計\n", | |
" sum_total = 0 #dataの数の合計\n", | |
"\n", | |
" #train dataを使ってテストをする(パラメータ更新がないようになっている)\n", | |
" for (inputs, labels) in trainloader:\n", | |
" inputs, labels = inputs.to(device), labels.to(device)\n", | |
" optimizer.zero_grad()\n", | |
" outputs = net(inputs)\n", | |
" loss = criterion(outputs, labels)\n", | |
" #lossを足していく\n", | |
" sum_loss += loss.item()\n", | |
" #出力の最大値の添字(予想位置)を取得\n", | |
" _, predicted = outputs.max(1)\n", | |
" #labelの数を足していくことでデータの総和を取る \n", | |
" sum_total += labels.size(0)\n", | |
" #予想位置と実際の正解を比べ,正解している数だけ足す\n", | |
" sum_correct += (predicted == labels).sum().item()\n", | |
" \n", | |
" #lossとaccuracy出力\n", | |
" print(\"train mean loss={}, accuracy={}\"\n", | |
" .format(sum_loss*BATCH_SIZE/len(trainloader.dataset), float(sum_correct/sum_total)))\n", | |
" #traindataのlossをグラフ描画のためにlistに保持\n", | |
" train_loss_value.append(sum_loss*BATCH_SIZE/len(trainloader.dataset))\n", | |
" #traindataのaccuracyをグラフ描画のためにlistに保持\n", | |
" train_acc_value.append(float(sum_correct/sum_total))\n", | |
"\n", | |
" sum_loss = 0.0\n", | |
" sum_correct = 0\n", | |
" sum_total = 0\n", | |
"\n", | |
" #test dataを使ってテストをする\n", | |
" for (inputs, labels) in testloader:\n", | |
" inputs, labels = inputs.to(device), labels.to(device)\n", | |
" optimizer.zero_grad()\n", | |
" outputs = net(inputs)\n", | |
" loss = criterion(outputs, labels)\n", | |
" sum_loss += loss.item()\n", | |
" _, predicted = outputs.max(1)\n", | |
" sum_total += labels.size(0)\n", | |
" sum_correct += (predicted == labels).sum().item()\n", | |
" print(\"test mean loss={}, accuracy={}\"\n", | |
" .format(sum_loss*BATCH_SIZE/len(testloader.dataset), float(sum_correct/sum_total)))\n", | |
" test_loss_value.append(sum_loss*BATCH_SIZE/len(testloader.dataset))\n", | |
" test_acc_value.append(float(sum_correct/sum_total))\n", | |
"\n", | |
"#グラフ\n", | |
"fig, (axL, axR) = plt.subplots(ncols=2, figsize=(12,6))\n", | |
"\n", | |
"#損失グラフ描画\n", | |
"axL.plot(range(EPOCH), train_loss_value)\n", | |
"axL.plot(range(EPOCH), test_loss_value, c='#00ff00')\n", | |
"axL.set_xlabel('EPOCH')\n", | |
"axL.set_ylabel('LOSS')\n", | |
"axL.legend(['train loss', 'test loss'])\n", | |
"axL.set_title('loss')\n", | |
"\n", | |
"#正答率グラフ描画\n", | |
"axR.plot(range(EPOCH), train_acc_value)\n", | |
"axR.plot(range(EPOCH), test_acc_value, c='#00ff00')\n", | |
"axR.set_xlabel('EPOCH')\n", | |
"axR.set_ylabel('ACCURACY')\n", | |
"axR.legend(['train acc', 'test acc'])\n", | |
"axR.set_title('accuracy')\n", | |
"\n", | |
"fig.savefig(\"loss_accuracy_image.png\")\n", | |
"fig.show()" | |
], | |
"execution_count": 9, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Files already downloaded and verified\n", | |
"Files already downloaded and verified\n", | |
"----------------------------------------------------------------\n", | |
" Layer (type) Output Shape Param #\n", | |
"================================================================\n", | |
" Conv2d-1 [-1, 16, 28, 28] 1,216\n", | |
" Conv2d-2 [-1, 32, 24, 24] 12,832\n", | |
" Conv2d-3 [-1, 64, 20, 20] 51,264\n", | |
" Conv2d-4 [-1, 10, 16, 16] 16,010\n", | |
" AdaptiveAvgPool2d-5 [-1, 10, 1, 1] 0\n", | |
"================================================================\n", | |
"Total params: 81,322\n", | |
"Trainable params: 81,322\n", | |
"Non-trainable params: 0\n", | |
"----------------------------------------------------------------\n", | |
"Input size (MB): 0.01\n", | |
"Forward/backward pass size (MB): 0.45\n", | |
"Params size (MB): 0.31\n", | |
"Estimated Total Size (MB): 0.77\n", | |
"----------------------------------------------------------------\n", | |
"epoch 1\n", | |
"train mean loss=1.8753954305648803, accuracy=0.34898\n", | |
"test mean loss=1.876963770389557, accuracy=0.351\n", | |
"epoch 2\n", | |
"train mean loss=1.7881581859588622, accuracy=0.39088\n", | |
"test mean loss=1.8026284503936767, accuracy=0.3917\n", | |
"epoch 3\n", | |
"train mean loss=1.6759743094444275, accuracy=0.4397\n", | |
"test mean loss=1.6939589202404022, accuracy=0.4349\n", | |
"epoch 4\n", | |
"train mean loss=1.6510359325408936, accuracy=0.44444\n", | |
"test mean loss=1.6731702125072478, accuracy=0.4396\n", | |
"epoch 5\n", | |
"train mean loss=1.6047766225337983, accuracy=0.46894\n", | |
"test mean loss=1.6281603837013245, accuracy=0.4618\n", | |
"epoch 6\n", | |
"train mean loss=1.5690029315948486, accuracy=0.48236\n", | |
"test mean loss=1.5994542968273162, accuracy=0.4708\n", | |
"epoch 7\n", | |
"train mean loss=1.5360382354259492, accuracy=0.49398\n", | |
"test mean loss=1.574644582271576, accuracy=0.4839\n", | |
"epoch 8\n", | |
"train mean loss=1.5160074324607848, accuracy=0.49464\n", | |
"test mean loss=1.5557704830169679, accuracy=0.4859\n", | |
"epoch 9\n", | |
"train mean loss=1.4975026261806488, accuracy=0.50384\n", | |
"test mean loss=1.5442944002151489, accuracy=0.4925\n", | |
"epoch 10\n", | |
"train mean loss=1.5049279425144195, accuracy=0.4997\n", | |
"test mean loss=1.555120862722397, accuracy=0.4868\n", | |
"epoch 11\n", | |
"train mean loss=1.4826647455692292, accuracy=0.50918\n", | |
"test mean loss=1.5392323923110962, accuracy=0.4938\n", | |
"epoch 12\n", | |
"train mean loss=1.460705751657486, accuracy=0.51118\n", | |
"test mean loss=1.5212853717803956, accuracy=0.4949\n", | |
"epoch 13\n", | |
"train mean loss=1.4437483689785005, accuracy=0.51802\n", | |
"test mean loss=1.512642605304718, accuracy=0.5038\n", | |
"epoch 14\n", | |
"train mean loss=1.4193825724124909, accuracy=0.52722\n", | |
"test mean loss=1.482859982252121, accuracy=0.5088\n", | |
"epoch 15\n", | |
"train mean loss=1.4094636778831482, accuracy=0.5278\n", | |
"test mean loss=1.4785270190238953, accuracy=0.5053\n", | |
"epoch 16\n", | |
"train mean loss=1.3860473586320876, accuracy=0.5379\n", | |
"test mean loss=1.4659891474246978, accuracy=0.5155\n", | |
"epoch 17\n", | |
"train mean loss=1.3771601893901826, accuracy=0.54046\n", | |
"test mean loss=1.462211995124817, accuracy=0.5146\n", | |
"epoch 18\n", | |
"train mean loss=1.3792241594791412, accuracy=0.53896\n", | |
"test mean loss=1.4691423952579499, accuracy=0.5149\n", | |
"epoch 19\n", | |
"train mean loss=1.3409191426038742, accuracy=0.5542\n", | |
"test mean loss=1.436096396446228, accuracy=0.5248\n", | |
"epoch 20\n", | |
"train mean loss=1.3407690443992615, accuracy=0.55284\n", | |
"test mean loss=1.4437957155704497, accuracy=0.5233\n" | |
], | |
"name": "stdout" | |
}, | |
{ | |
"output_type": "display_data", | |
"data": { | |
"image/png": "\n", | |
"text/plain": [ | |
"<Figure size 864x432 with 2 Axes>" | |
] | |
}, | |
"metadata": { | |
"tags": [], | |
"needs_background": "light" | |
} | |
} | |
] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment