{
"cells": [
{
"cell_type": "markdown",
"id": "2e62db91-c5aa-4515-8fd9-60fe9c62d7b4",
"metadata": {},
"source": [
"## Classifying the CIFAR-10 color image dataset\n",
"\n",
"Here we work with [cifar10](https://www.cs.toronto.edu/~kriz/cifar.html), a dataset of color images. As with MNIST, we download it from a repository hosted by AWS ([Registry of Open Data on AWS](https://registry.opendata.aws/fast-ai-imageclas/))."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bc1f648f-f1a0-44bc-8037-02a5b0545eb7",
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"!aws s3 cp s3://fast-ai-imageclas/cifar10.tgz . --no-sign-request\n",
"!tar -xvzf cifar10.tgz"
]
},
{
"cell_type": "markdown",
"id": "5bb66b27-b484-427f-9f86-4ae2cc4043b4",
"metadata": {},
"source": [
"When using data for deep learning, it is sometimes impossible to fit the entire dataset into memory. In that case, we use a Dataset and a DataLoader: we only specify where the data should be read from, and then load it incrementally, one batch at a time.\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "52aabb4b-9039-4a4d-9987-a16f9060a88d",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torchvision import datasets, transforms\n",
"\n",
"# Preprocessing: convert to tensor and normalize each channel to [-1, 1]\n",
"transform = transforms.Compose([\n",
" transforms.ToTensor(),\n",
" transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))\n",
"])\n",
"\n",
"# Create the datasets\n",
"train_dataset = datasets.ImageFolder(\"cifar10/train\", transform=transform)\n",
"test_dataset = datasets.ImageFolder(\"cifar10/test\", transform=transform)\n",
"\n",
"# Create the data loaders\n",
"batch_size = 32\n",
"train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)\n",
"test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=5, shuffle=True)"
]
},
{
"cell_type": "markdown",
"id": "06f2bf6d-3c01-412e-b287-7671558d1113",
"metadata": {},
"source": [
"Let's actually load an image and display it."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "53495c1f-5f0d-4fda-a703-c8c596d982cc",
"metadata": {},
"outputs": [],
"source": [
"%matplotlib inline\n",
"import matplotlib.pyplot as plt\n",
"\n",
"# Take one image out of a batch\n",
"images, labels = next(iter(train_dataloader))\n",
"image = images[0]\n",
"label = labels[0]\n",
"\n",
"image = (image + 1) / 2 # map [-1, 1] back to [0, 1]\n",
"image = image.clamp(0, 1) # clip values into [0, 1]\n",
"\n",
"# Convert the tensor to a numpy array: (C, H, W) -> (H, W, C)\n",
"image = image.numpy().transpose((1, 2, 0))\n",
"\n",
"# Display the image\n",
"plt.imshow(image)\n",
"plt.title(f'Label: {label.item()}')\n",
"plt.axis('off')\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "2cc8e163-aae0-49d3-83ae-8dd10df11db3",
"metadata": {},
"source": [
"## Defining the neural network\n",
"\n",
"Here we use a convolutional neural network: two convolutional layers followed by three fully connected layers. The layers are defined in __init__, and the actual computation that uses them is written in forward.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f9959696-4b83-440a-b7a1-bbb959dcec16",
"metadata": {},
"outputs": [],
"source": [
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"\n",
"class Net(nn.Module):\n",
" def __init__(self):\n",
" super().__init__()\n",
" self.conv1 = nn.Conv2d(3, 6, 5)\n",
" self.pool = nn.MaxPool2d(2, 2)\n",
" self.conv2 = nn.Conv2d(6, 16, 5)\n",
" self.fc1 = nn.Linear(16 * 5 * 5, 120)\n",
" self.fc2 = nn.Linear(120, 84)\n",
" self.fc3 = nn.Linear(84, 10)\n",
"\n",
" def forward(self, x):\n",
" x = self.pool(F.relu(self.conv1(x)))\n",
" x = self.pool(F.relu(self.conv2(x)))\n",
" x = torch.flatten(x, 1) # flatten all dimensions except batch\n",
" x = F.relu(self.fc1(x))\n",
" x = F.relu(self.fc2(x))\n",
" x = self.fc3(x)\n",
" return x\n",
"\n",
"\n",
"net = Net()"
]
},
{
"cell_type": "markdown",
"id": "aa734672-20bc-42bf-a7af-2e06805e1dd7",
"metadata": {},
"source": [
"We specify the cross-entropy loss for classification, and SGD as the optimization method."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d5fea12e-7122-46e1-b0f1-ecd76ea6ab7f",
"metadata": {},
"outputs": [],
"source": [
"import torch.optim as optim\n",
"\n",
"criterion = nn.CrossEntropyLoss()\n",
"optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)"
]
},
{
"cell_type": "markdown",
"id": "578d51a5-d720-40f6-822c-4d34d1489ecb",
"metadata": {},
"source": [
"Let's start training."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4f1517e8-1542-4e41-953d-07ee274d65d0",
"metadata": {},
"outputs": [],
"source": [
"for epoch in range(5): # loop over the dataset multiple times\n",
"\n",
" running_loss = 0.0\n",
" for i, data in enumerate(train_dataloader, 0):\n",
" # get the inputs; data is a list of [inputs, labels]\n",
" inputs, labels = data\n",
"\n",
" # zero the parameter gradients\n",
" optimizer.zero_grad()\n",
"\n",
" # forward + backward + optimize\n",
" outputs = net(inputs)\n",
" loss = criterion(outputs, labels.to(torch.long))\n",
" loss.backward()\n",
" optimizer.step()\n",
"\n",
" # print statistics\n",
" running_loss += loss.item()\n",
" if i % 200 == 199: # print every 200 mini-batches\n",
" print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 200:.3f}')\n",
" running_loss = 0.0\n",
"\n",
"print('Finished Training')"
]
},
{
"cell_type": "markdown",
"id": "2b40d1d9-81f6-4dc3-bade-d736f055046d",
"metadata": {},
"source": [
"Let's feed in the test data and see how many predictions it gets right."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "aed11b62-7c88-4fd3-98a9-69d13cf8d9ba",
"metadata": {},
"outputs": [],
"source": [
"import torchvision\n",
"dataiter = iter(test_dataloader)\n",
"images, labels = next(dataiter)\n",
"outputs = net(images)\n",
"_, predicted = torch.max(outputs.data, 1)\n",
"\n",
"# Show the images\n",
"images = (images + 1) / 2 # map [-1, 1] back to [0, 1]\n",
"images = images.clamp(0, 1) # clip values into [0, 1]\n",
"plt.imshow(torchvision.utils.make_grid(images).numpy().transpose([1,2,0]))\n",
"classes = ('plane', 'car', 'bird', 'cat',\n",
" 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')\n",
"print('Ground truth:', ' '.join(f'{classes[labels[j]]:12s}' for j in range(5)))\n",
"print('Predicted: ', ' '.join(f'{classes[predicted[j]]:12s}' for j in range(5)))"
]
},
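{
"cell_type": "markdown",
"id": "a1b2c3d4-0001-4000-8000-accuracysketch1",
"metadata": {},
"source": [
"Five images only give a rough impression. As a minimal sketch (assuming `net` and `test_dataloader` are defined as in the cells above), we can also measure overall accuracy over the whole test loader:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a1b2c3d4-0002-4000-8000-accuracysketch2",
"metadata": {},
"outputs": [],
"source": [
"# Sketch: overall accuracy of net on the test loader\n",
"net.eval() # switch to evaluation mode\n",
"correct = 0\n",
"total = 0\n",
"with torch.no_grad(): # no gradients needed for evaluation\n",
" for images, labels in test_dataloader:\n",
" outputs = net(images)\n",
" _, predicted = torch.max(outputs, 1)\n",
" total += labels.size(0)\n",
" correct += (predicted == labels).sum().item()\n",
"print(f'Accuracy: {100 * correct / total:.1f} %')"
]
},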
{
"cell_type": "markdown",
"id": "219d8f7d-1450-41ef-b3ce-37e12abad9dc",
"metadata": {},
"source": [
"## Fine-tuning\n",
"\n",
"We customize a ready-made model and train it on our own data. Because such models have been pretrained on large-scale image datasets, they are assumed to already have basic capabilities for understanding images."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "aa7cc23a-7974-4988-b065-6d33342e9790",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torchvision.models as models\n",
"import torchvision.transforms as transforms\n",
"\n",
"# Hyperparameters\n",
"num_epochs = 1\n",
"batch_size = 16\n",
"learning_rate = 0.001\n",
"\n",
"num_classes = 10\n",
"\n",
"# Prepare the datasets (MobileNet expects 224x224 inputs normalized with ImageNet statistics)\n",
"mobilenet_transform = transforms.Compose([\n",
" transforms.Resize(256),\n",
" transforms.CenterCrop(224),\n",
" transforms.ToTensor(),\n",
" transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])\n",
"])\n",
"mobile_train_dataset = datasets.ImageFolder('cifar10/train', transform=mobilenet_transform)\n",
"mobile_train_dataloader = torch.utils.data.DataLoader(mobile_train_dataset, batch_size=batch_size, shuffle=True)\n",
"\n",
"mobile_test_dataset = datasets.ImageFolder('cifar10/test', transform=mobilenet_transform)\n",
"mobile_test_dataloader = torch.utils.data.DataLoader(mobile_test_dataset, batch_size=5, shuffle=True)\n",
"\n",
"# Prepare the model: load pretrained weights and replace the final classifier layer\n",
"model = models.mobilenet_v3_large(weights=models.MobileNet_V3_Large_Weights.DEFAULT)\n",
"num_ftrs = model.classifier[-1].in_features\n",
"model.classifier[-1] = nn.Linear(num_ftrs, num_classes)\n",
"\n",
"# Loss function and optimizer\n",
"criterion = nn.CrossEntropyLoss()\n",
"optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)\n",
"\n",
"# Training loop\n",
"running_loss = 0.0\n",
"for epoch in range(num_epochs):\n",
" for i, data in enumerate(mobile_train_dataloader, 0):\n",
" inputs, labels = data\n",
" # Forward pass\n",
" outputs = model(inputs)\n",
" loss = criterion(outputs, labels)\n",
" \n",
" # Backward pass and parameter update\n",
" optimizer.zero_grad()\n",
" loss.backward()\n",
" optimizer.step()\n",
"\n",
" # Print statistics\n",
" running_loss += loss.item()\n",
" if i % 20 == 19: # print every 20 mini-batches\n",
" print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 20:.3f}')\n",
" running_loss = 0.0\n",
"\n",
"# Save the model\n",
"torch.save(model.state_dict(), 'mobilenet_model.pth')"
]
},
{
"cell_type": "markdown",
"id": "f2616b00-8653-4e75-a2a5-a3ff77efad21",
"metadata": {},
"source": [
"Let's feed in the test data and see how many predictions the fine-tuned model gets right."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5c99d2c8-e7bf-4cd2-8bf0-2c3a0a11dfc7",
"metadata": {},
"outputs": [],
"source": [
"import torchvision\n",
"dataiter = iter(mobile_test_dataloader)\n",
"images, labels = next(dataiter)\n",
"outputs = model(images)\n",
"_, predicted = torch.max(outputs.data, 1)\n",
"\n",
"# Show the images: these were normalized with ImageNet statistics, so undo that normalization for display\n",
"mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)\n",
"std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)\n",
"images = (images * std + mean).clamp(0, 1)\n",
"plt.imshow(torchvision.utils.make_grid(images).numpy().transpose([1,2,0]))\n",
"classes = ('plane', 'car', 'bird', 'cat',\n",
" 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')\n",
"print('Ground truth:', ' '.join(f'{classes[labels[j]]:12s}' for j in range(5)))\n",
"print('Predicted: ', ' '.join(f'{classes[predicted[j]]:12s}' for j in range(5)))"
]
},
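{
"cell_type": "markdown",
"id": "a1b2c3d4-0003-4000-8000-accuracysketch3",
"metadata": {},
"source": [
"The same accuracy sketch works for the fine-tuned model (assuming `model` and `mobile_test_dataloader` are defined as in the cells above):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a1b2c3d4-0004-4000-8000-accuracysketch4",
"metadata": {},
"outputs": [],
"source": [
"# Sketch: overall accuracy of the fine-tuned model on the test loader\n",
"model.eval() # switch to evaluation mode\n",
"correct = 0\n",
"total = 0\n",
"with torch.no_grad(): # no gradients needed for evaluation\n",
" for images, labels in mobile_test_dataloader:\n",
" outputs = model(images)\n",
" _, predicted = torch.max(outputs, 1)\n",
" total += labels.size(0)\n",
" correct += (predicted == labels).sum().item()\n",
"print(f'Accuracy: {100 * correct / total:.1f} %')"
]
},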
{
"cell_type": "code",
"execution_count": null,
"id": "488b04cf-0a06-4122-9e64-53a8b22cda64",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
}
},
"nbformat": 4,
"nbformat_minor": 5
}