Last active
July 23, 2024 07:21
{
"cells": [
{
"cell_type": "markdown",
"id": "2e62db91-c5aa-4515-8fd9-60fe9c62d7b4",
"metadata": {},
"source": [
"## Classifying the cifar10 color-image dataset\n",
"\n",
"Here we work with [cifar10](https://www.cs.toronto.edu/~kriz/cifar.html), a dataset of small color images. As in the MNIST example, we download it from a repository hosted by AWS ([Registry of Open Data on AWS](https://registry.opendata.aws/fast-ai-imageclas/))."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bc1f648f-f1a0-44bc-8037-02a5b0545eb7",
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"!aws s3 cp s3://fast-ai-imageclas/cifar10.tgz . --no-sign-request\n",
"!tar -xvzf cifar10.tgz"
]
},
{
"cell_type": "markdown",
"id": "5bb66b27-b484-427f-9f86-4ae2cc4043b4",
"metadata": {},
"source": [
"When using data for deep learning, it can be impossible to hold the entire dataset in memory. In that case we use a dataset and a data loader: we only specify where the data is to be read from, and batches are then loaded on demand, one batch at a time.\n",
"\n"
]
},
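The batch-wise loading idea above can be sketched in plain Python. This is a toy stand-in for what `DataLoader` does, not a PyTorch API; the `make_batches` helper is purely illustrative:

```python
def make_batches(samples, batch_size):
    """Yield successive fixed-size batches instead of holding everything at once."""
    for start in range(0, len(samples), batch_size):
        yield samples[start:start + batch_size]

data = list(range(10))  # stand-in for a dataset of 10 samples
batches = list(make_batches(data, batch_size=4))
print([len(b) for b in batches])  # → [4, 4, 2]
```

A real `DataLoader` adds shuffling, parallel workers, and tensor collation on top of this same idea.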
{
"cell_type": "code",
"execution_count": null,
"id": "52aabb4b-9039-4a4d-9987-a16f9060a88d",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torchvision import datasets, transforms\n",
"\n",
"# Preprocessing\n",
"transform = transforms.Compose([\n",
"    transforms.ToTensor(),\n",
"    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))\n",
"])\n",
"\n",
"# Create the datasets\n",
"train_dataset = datasets.ImageFolder(\"cifar10/train\", transform=transform)\n",
"test_dataset = datasets.ImageFolder(\"cifar10/test\", transform=transform)\n",
"\n",
"# Create the data loaders (the test loader must wrap the test dataset)\n",
"batch_size = 32\n",
"train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)\n",
"test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=5, shuffle=True)"
]
},
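`Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))` maps each channel from [0, 1] to [-1, 1] via (x - mean) / std. A quick sanity check of that mapping and its inverse (the inverse is what we use later to display images); the helper names here are illustrative, not library functions:

```python
def normalize(x, mean=0.5, std=0.5):
    # same arithmetic transforms.Normalize applies per channel
    return (x - mean) / std

def denormalize(x, mean=0.5, std=0.5):
    # inverse mapping, equivalent to (x + 1) / 2 when mean = std = 0.5
    return x * std + mean

print(normalize(0.0), normalize(1.0))  # → -1.0 1.0
print(denormalize(normalize(0.25)))    # round-trips to 0.25
```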
{
"cell_type": "markdown",
"id": "06f2bf6d-3c01-412e-b287-7671558d1113",
"metadata": {},
"source": [
"Let's actually load and display an image."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "53495c1f-5f0d-4fda-a703-c8c596d982cc",
"metadata": {},
"outputs": [],
"source": [
"%matplotlib inline\n",
"import matplotlib.pyplot as plt\n",
"\n",
"# Take a single image out of a batch\n",
"images, labels = next(iter(train_dataloader))\n",
"image = images[0]\n",
"label = labels[0]\n",
"\n",
"image = (image + 1) / 2  # map [-1, 1] back to [0, 1]\n",
"image = image.clamp(0, 1)  # clamp values into [0, 1]\n",
"\n",
"# Convert the tensor to a numpy array, (C, H, W) -> (H, W, C)\n",
"image = image.numpy().transpose((1, 2, 0))\n",
"\n",
"# Display the image\n",
"plt.imshow(image)\n",
"plt.title(f'Label: {label}')\n",
"plt.axis('off')\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "2cc8e163-aae0-49d3-83ae-8dd10df11db3",
"metadata": {},
"source": [
"## Defining the neural network\n",
"\n",
"Here we use a convolutional neural network: two convolutional layers followed by three fully connected layers. The layers are defined in __init__, and the actual computation that uses them is written in forward.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f9959696-4b83-440a-b7a1-bbb959dcec16",
"metadata": {},
"outputs": [],
"source": [
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"\n",
"class Net(nn.Module):\n",
"    def __init__(self):\n",
"        super().__init__()\n",
"        self.conv1 = nn.Conv2d(3, 6, 5)\n",
"        self.pool = nn.MaxPool2d(2, 2)\n",
"        self.conv2 = nn.Conv2d(6, 16, 5)\n",
"        self.fc1 = nn.Linear(16 * 5 * 5, 120)\n",
"        self.fc2 = nn.Linear(120, 84)\n",
"        self.fc3 = nn.Linear(84, 10)\n",
"\n",
"    def forward(self, x):\n",
"        x = self.pool(F.relu(self.conv1(x)))\n",
"        x = self.pool(F.relu(self.conv2(x)))\n",
"        x = torch.flatten(x, 1)  # flatten all dimensions except batch\n",
"        x = F.relu(self.fc1(x))\n",
"        x = F.relu(self.fc2(x))\n",
"        x = self.fc3(x)\n",
"        return x\n",
"\n",
"\n",
"net = Net()"
]
},
{
"cell_type": "markdown",
"id": "aa734672-20bc-42bf-a7af-2e06805e1dd7",
"metadata": {},
"source": [
"We specify cross-entropy loss for classification and SGD as the optimization method."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d5fea12e-7122-46e1-b0f1-ecd76ea6ab7f",
"metadata": {},
"outputs": [],
"source": [
"import torch.optim as optim\n",
"\n",
"criterion = nn.CrossEntropyLoss()\n",
"optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)"
]
},
{
"cell_type": "markdown",
"id": "578d51a5-d720-40f6-822c-4d34d1489ecb",
"metadata": {},
"source": [
"Now we start training."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4f1517e8-1542-4e41-953d-07ee274d65d0",
"metadata": {},
"outputs": [],
"source": [
"for epoch in range(5):  # loop over the dataset multiple times\n",
"\n",
"    running_loss = 0.0\n",
"    for i, data in enumerate(train_dataloader, 0):\n",
"        # get the inputs; data is a list of [inputs, labels]\n",
"        inputs, labels = data\n",
"\n",
"        # zero the parameter gradients\n",
"        optimizer.zero_grad()\n",
"\n",
"        # forward + backward + optimize\n",
"        outputs = net(inputs)\n",
"        loss = criterion(outputs, labels.to(torch.long))\n",
"        loss.backward()\n",
"        optimizer.step()\n",
"\n",
"        # print statistics\n",
"        running_loss += loss.item()\n",
"        if i % 200 == 199:  # print every 200 mini-batches\n",
"            print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 200:.3f}')\n",
"            running_loss = 0.0\n",
"\n",
"print('Finished Training')"
]
},
{
"cell_type": "markdown",
"id": "2b40d1d9-81f6-4dc3-bade-d736f055046d",
"metadata": {},
"source": [
"Let's feed in the test data and see how many predictions come out right."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "aed11b62-7c88-4fd3-98a9-69d13cf8d9ba",
"metadata": {},
"outputs": [],
"source": [
"import torchvision\n",
"dataiter = iter(test_dataloader)\n",
"images, labels = next(dataiter)\n",
"outputs = net(images)\n",
"_, predicted = torch.max(outputs.data, 1)\n",
"\n",
"# display the images\n",
"images = (images + 1) / 2  # map [-1, 1] back to [0, 1]\n",
"images = images.clamp(0, 1)  # clamp values into [0, 1]\n",
"plt.imshow(torchvision.utils.make_grid(images).numpy().transpose([1, 2, 0]))\n",
"classes = ('plane', 'car', 'bird', 'cat',\n",
"           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')\n",
"print('Truth:     ', ' '.join(f'{classes[labels[j]]:12s}' for j in range(5)))\n",
"print('Predicted: ', ' '.join(f'{classes[predicted[j]]:12s}' for j in range(5)))"
]
},
{
"cell_type": "markdown",
"id": "219d8f7d-1450-41ef-b3ce-37e12abad9dc",
"metadata": {},
"source": [
"## Fine-tuning\n",
"\n",
"We customize an existing pre-trained model and train it on our own data. Because the pre-trained model has already been trained on a large image corpus, it is assumed to have basic image-understanding capabilities."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "aa7cc23a-7974-4988-b065-6d33342e9790",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torchvision.models as models\n",
"import torchvision.transforms as transforms\n",
"\n",
"# Hyperparameters\n",
"num_epochs = 1\n",
"batch_size = 16\n",
"learning_rate = 0.001\n",
"\n",
"num_classes = 10\n",
"\n",
"# Prepare the datasets\n",
"mobilenet_transform = transforms.Compose([\n",
"    transforms.Resize(256),\n",
"    transforms.CenterCrop(224),\n",
"    transforms.ToTensor(),\n",
"    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])\n",
"])\n",
"mobile_train_dataset = datasets.ImageFolder('cifar10/train', transform=mobilenet_transform)\n",
"mobile_train_dataloader = torch.utils.data.DataLoader(mobile_train_dataset, batch_size=batch_size, shuffle=True)\n",
"\n",
"mobile_test_dataset = datasets.ImageFolder('cifar10/test', transform=mobilenet_transform)\n",
"mobile_test_dataloader = torch.utils.data.DataLoader(mobile_test_dataset, batch_size=5, shuffle=True)\n",
"\n",
"# Prepare the model (pretrained=True is deprecated; use the weights enum)\n",
"model = models.mobilenet_v3_large(weights=models.MobileNet_V3_Large_Weights.DEFAULT)\n",
"num_ftrs = model.classifier[-1].in_features\n",
"model.classifier[-1] = nn.Linear(num_ftrs, num_classes)\n",
"\n",
"# Loss function and optimizer\n",
"criterion = nn.CrossEntropyLoss()\n",
"optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)\n",
"\n",
"# Training loop\n",
"running_loss = 0.0\n",
"for epoch in range(num_epochs):\n",
"    for i, data in enumerate(mobile_train_dataloader, 0):\n",
"        inputs, labels = data\n",
"        # forward pass\n",
"        outputs = model(inputs)\n",
"        loss = criterion(outputs, labels)\n",
"\n",
"        # backward pass\n",
"        optimizer.zero_grad()\n",
"        loss.backward()\n",
"        optimizer.step()\n",
"\n",
"        # print statistics\n",
"        running_loss += loss.item()\n",
"        if i % 20 == 19:  # print every 20 mini-batches\n",
"            print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 20:.3f}')\n",
"            running_loss = 0.0\n",
"\n",
"# Save the model\n",
"torch.save(model.state_dict(), 'mobilenet_model.pth')"
]
},
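Note that the fine-tuning pipeline normalizes with the ImageNet statistics rather than 0.5, so displaying those images requires inverting with the same per-channel mean and std (x * std + mean), not (x + 1) / 2. A minimal arithmetic check, using an illustrative `denorm_channel` helper:

```python
means = [0.485, 0.456, 0.406]  # ImageNet channel means (R, G, B)
stds = [0.229, 0.224, 0.225]   # ImageNet channel stds

def denorm_channel(x, mean, std):
    # inverse of (x - mean) / std for one channel
    return x * std + mean

# a normalized red-channel value round-trips back to the original pixel value
orig = 0.7
norm = (orig - means[0]) / stds[0]
print(round(denorm_channel(norm, means[0], stds[0]), 6))  # → 0.7
```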
{
"cell_type": "markdown",
"id": "f2616b00-8653-4e75-a2a5-a3ff77efad21",
"metadata": {},
"source": [
"Let's feed in the test data and see how many predictions come out right."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5c99d2c8-e7bf-4cd2-8bf0-2c3a0a11dfc7",
"metadata": {},
"outputs": [],
"source": [
"import torchvision\n",
"dataiter = iter(mobile_test_dataloader)\n",
"images, labels = next(dataiter)\n",
"outputs = model(images)\n",
"_, predicted = torch.max(outputs.data, 1)\n",
"\n",
"# display the images: undo the ImageNet normalization (x * std + mean)\n",
"mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)\n",
"std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)\n",
"images = images * std + mean\n",
"images = images.clamp(0, 1)  # clamp values into [0, 1]\n",
"plt.imshow(torchvision.utils.make_grid(images).numpy().transpose([1, 2, 0]))\n",
"classes = ('plane', 'car', 'bird', 'cat',\n",
"           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')\n",
"print('Truth:     ', ' '.join(f'{classes[labels[j]]:12s}' for j in range(5)))\n",
"print('Predicted: ', ' '.join(f'{classes[predicted[j]]:12s}' for j in range(5)))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "488b04cf-0a06-4122-9e64-53a8b22cda64",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
}
},
"nbformat": 4,
"nbformat_minor": 5
}