Created
October 5, 2021 18:35
-
-
Save sagorbrur/0ce5a498ebbb76f63bd44dc7248ea3ba to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "cells": [ | |
| { | |
| "cell_type": "code", | |
| "execution_count": 2, | |
| "source": [ | |
| "import torch\n", | |
| "import torch.nn as nn\n", | |
| "from torch.utils.data import DataLoader\n", | |
| "from torchvision import datasets, transforms\n", | |
| "from torchvision.transforms import ToTensor, Lambda\n" | |
| ], | |
| "outputs": [], | |
| "metadata": {} | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 3, | |
| "source": [ | |
| "# Fetch the Fashion-MNIST training split, then the test split, into ./data.\n", | |
| "# Both use ToTensor() as their transform.\n", | |
| "training_data, test_data = (\n", | |
| "    datasets.FashionMNIST(\n", | |
| "        root=\"data\",\n", | |
| "        train=is_train_split,\n", | |
| "        download=True,\n", | |
| "        transform=ToTensor(),\n", | |
| "    )\n", | |
| "    for is_train_split in (True, False)\n", | |
| ")\n" | |
| ], | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "name": "stdout", | |
| "text": [ | |
| "Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz\n", | |
| "Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to data/FashionMNIST/raw/train-images-idx3-ubyte.gz\n" | |
| ] | |
| }, | |
| { | |
| "output_type": "stream", | |
| "name": "stderr", | |
| "text": [ | |
| "26422272it [00:16, 1644521.31it/s] \n" | |
| ] | |
| }, | |
| { | |
| "output_type": "stream", | |
| "name": "stdout", | |
| "text": [ | |
| "Extracting data/FashionMNIST/raw/train-images-idx3-ubyte.gz to data/FashionMNIST/raw\n", | |
| "\n", | |
| "Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz\n", | |
| "Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw/train-labels-idx1-ubyte.gz\n" | |
| ] | |
| }, | |
| { | |
| "output_type": "stream", | |
| "name": "stderr", | |
| "text": [ | |
| "29696it [00:00, 165681.50it/s] \n" | |
| ] | |
| }, | |
| { | |
| "output_type": "stream", | |
| "name": "stdout", | |
| "text": [ | |
| "Extracting data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw\n", | |
| "\n", | |
| "Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz\n", | |
| "Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz\n" | |
| ] | |
| }, | |
| { | |
| "output_type": "stream", | |
| "name": "stderr", | |
| "text": [ | |
| "4422656it [00:07, 565707.47it/s] \n" | |
| ] | |
| }, | |
| { | |
| "output_type": "stream", | |
| "name": "stdout", | |
| "text": [ | |
| "Extracting data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw\n", | |
| "\n", | |
| "Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz\n", | |
| "Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz\n" | |
| ] | |
| }, | |
| { | |
| "output_type": "stream", | |
| "name": "stderr", | |
| "text": [ | |
| "6144it [00:00, 2824088.09it/s] " | |
| ] | |
| }, | |
| { | |
| "output_type": "stream", | |
| "name": "stdout", | |
| "text": [ | |
| "Extracting data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw\n", | |
| "\n" | |
| ] | |
| }, | |
| { | |
| "output_type": "stream", | |
| "name": "stderr", | |
| "text": [ | |
| "\n", | |
| "/home/sagor/anaconda3/envs/torch1.9/lib/python3.7/site-packages/torchvision/datasets/mnist.py:498: UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at /pytorch/torch/csrc/utils/tensor_numpy.cpp:180.)\n", | |
| " return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)\n" | |
| ] | |
| } | |
| ], | |
| "metadata": {} | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 9, | |
| "source": [ | |
| "# Peek at the first training example: print the image tensor's shape, then its label.\n", | |
| "sample_data, sample_label = training_data[0]\n", | |
| "print(sample_data.shape, sample_label, sep=\"\\n\")" | |
| ], | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "name": "stdout", | |
| "text": [ | |
| "torch.Size([1, 28, 28])\n", | |
| "9\n" | |
| ] | |
| } | |
| ], | |
| "metadata": {} | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 10, | |
| "source": [ | |
| "# Wrap each split in a DataLoader that serves it in mini-batches of 32.\n", | |
| "train_dataloader, test_dataloader = (\n", | |
| "    DataLoader(split, batch_size=32) for split in (training_data, test_data)\n", | |
| ")" | |
| ], | |
| "outputs": [], | |
| "metadata": {} | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 13, | |
| "source": [ | |
| "class Model(nn.Module):\n", | |
| "    \"\"\"Fully connected classifier: flatten, two 512-unit ReLU layers, then 10 outputs.\"\"\"\n", | |
| "\n", | |
| "    def __init__(self):\n", | |
| "        super().__init__()\n", | |
| "        self.flatten = nn.Flatten()\n", | |
| "        # Layers are created in the same order as they are applied.\n", | |
| "        layers = [\n", | |
| "            nn.Linear(28*28, 512),\n", | |
| "            nn.ReLU(),\n", | |
| "            nn.Linear(512, 512),\n", | |
| "            nn.ReLU(),\n", | |
| "            nn.Linear(512, 10),\n", | |
| "        ]\n", | |
| "        self.linear_relu_stack = nn.Sequential(*layers)\n", | |
| "\n", | |
| "    def forward(self, x):\n", | |
| "        \"\"\"Flatten `x` from dimension 1 onward and return the stack's output.\"\"\"\n", | |
| "        return self.linear_relu_stack(self.flatten(x))" | |
| ], | |
| "outputs": [], | |
| "metadata": {} | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 14, | |
| "source": [ | |
| "# Build the network; the bare name on the last line makes the notebook display its repr.\n", | |
| "model = Model()\n", | |
| "model" | |
| ], | |
| "outputs": [ | |
| { | |
| "output_type": "execute_result", | |
| "data": { | |
| "text/plain": [ | |
| "Model(\n", | |
| " (flatten): Flatten(start_dim=1, end_dim=-1)\n", | |
| " (linear_relu_stack): Sequential(\n", | |
| " (0): Linear(in_features=784, out_features=512, bias=True)\n", | |
| " (1): ReLU()\n", | |
| " (2): Linear(in_features=512, out_features=512, bias=True)\n", | |
| " (3): ReLU()\n", | |
| " (4): Linear(in_features=512, out_features=10, bias=True)\n", | |
| " )\n", | |
| ")" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "execution_count": 14 | |
| } | |
| ], | |
| "metadata": {} | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 15, | |
| "source": [ | |
| "# Hyperparameters: learning rate, mini-batch size, number of epochs.\n", | |
| "# NOTE: the DataLoader cell above hard-codes 32 and does not read batch_size.\n", | |
| "lr, batch_size, epochs = 1e-3, 32, 5" | |
| ], | |
| "outputs": [], | |
| "metadata": {} | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 16, | |
| "source": [ | |
| "# Cross-entropy loss, and plain SGD over all of the model's parameters.\n", | |
| "criterion = nn.CrossEntropyLoss()\n", | |
| "optimizer = torch.optim.SGD(params=model.parameters(), lr=lr)" | |
| ], | |
| "outputs": [], | |
| "metadata": {} | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 17, | |
| "source": [ | |
| "def train(dataloader, model, criterion, optimizer):\n", | |
| "    \"\"\"Run one optimisation pass over `dataloader`, updating `model` in place.\n", | |
| "\n", | |
| "    Args:\n", | |
| "        dataloader: iterable of (inputs, targets) batches that exposes `.dataset`.\n", | |
| "        model: module being trained; it is switched to training mode here.\n", | |
| "        criterion: callable mapping (predictions, targets) to a scalar loss.\n", | |
| "        optimizer: optimizer whose `step()` applies the computed gradients.\n", | |
| "\n", | |
| "    Prints the current loss and the number of samples seen every 100 batches.\n", | |
| "    \"\"\"\n", | |
| "    size = len(dataloader.dataset)\n", | |
| "    # Make sure mode-dependent layers (dropout, batch-norm) are in training mode.\n", | |
| "    model.train()\n", | |
| "    for batch, (X, y) in enumerate(dataloader):\n", | |
| "        # forward pass\n", | |
| "        pred = model(X)\n", | |
| "        # compute loss\n", | |
| "        loss = criterion(pred, y)\n", | |
| "        # backpropagation\n", | |
| "        optimizer.zero_grad()\n", | |
| "        loss.backward()\n", | |
| "        optimizer.step()\n", | |
| "        if batch % 100 == 0:\n", | |
| "            # Use a separate name so the loss tensor is not shadowed by a float.\n", | |
| "            loss_value, current = loss.item(), batch * len(X)\n", | |
| "            print(f\"loss: {loss_value:>7f} [{current:>5d}/{size:>5d}]\")\n", | |
| "def test(dataloader, model, criterion):\n", | |
| "    \"\"\"Evaluate `model` on `dataloader` and print accuracy and mean per-batch loss.\n", | |
| "\n", | |
| "    Args:\n", | |
| "        dataloader: iterable of (inputs, targets) batches that exposes `.dataset`.\n", | |
| "        model: module to evaluate; run in eval mode, then restored to its prior mode.\n", | |
| "        criterion: callable mapping (predictions, targets) to a scalar loss.\n", | |
| "    \"\"\"\n", | |
| "    size = len(dataloader.dataset)\n", | |
| "    num_batches = len(dataloader)\n", | |
| "    test_loss, correct = 0, 0\n", | |
| "    # Evaluate with dropout/batch-norm in inference behaviour, then hand the model\n", | |
| "    # back in whatever mode the caller had it so later training is unaffected.\n", | |
| "    was_training = model.training\n", | |
| "    model.eval()\n", | |
| "    try:\n", | |
| "        with torch.no_grad():\n", | |
| "            for X, y in dataloader:\n", | |
| "                pred = model(X)\n", | |
| "                test_loss += criterion(pred, y).item()\n", | |
| "                correct += (pred.argmax(1) == y).sum().item()\n", | |
| "    finally:\n", | |
| "        model.train(was_training)\n", | |
| "    test_loss /= num_batches\n", | |
| "    correct /= size\n", | |
| "    print(f\"Test Error: \\n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \\n\")\n", | |
| "\n" | |
| ], | |
| "outputs": [], | |
| "metadata": {} | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 18, | |
| "source": [ | |
| "# For each epoch: one training pass over the training loader, then one evaluation pass.\n", | |
| "for epoch in range(epochs):\n", | |
| "    print(f\"Epoch {epoch+1}\", \"---------------\", sep=\"\\n\")\n", | |
| "    train(dataloader=train_dataloader, model=model, criterion=criterion, optimizer=optimizer)\n", | |
| "    test(dataloader=test_dataloader, model=model, criterion=criterion)\n", | |
| "print(\"Done!\")" | |
| ], | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "name": "stdout", | |
| "text": [ | |
| "Epoch 1\n", | |
| "---------------\n", | |
| "loss: 2.303873 [ 0/60000\n", | |
| "loss: 2.286819 [ 3200/60000\n", | |
| "loss: 2.286289 [ 6400/60000\n", | |
| "loss: 2.263151 [ 9600/60000\n", | |
| "loss: 2.248534 [12800/60000\n", | |
| "loss: 2.233392 [16000/60000\n", | |
| "loss: 2.224118 [19200/60000\n", | |
| "loss: 2.209830 [22400/60000\n", | |
| "loss: 2.211460 [25600/60000\n", | |
| "loss: 2.156247 [28800/60000\n", | |
| "loss: 2.109936 [32000/60000\n", | |
| "loss: 2.093889 [35200/60000\n", | |
| "loss: 2.171871 [38400/60000\n", | |
| "loss: 2.074579 [41600/60000\n", | |
| "loss: 2.032319 [44800/60000\n", | |
| "loss: 2.002094 [48000/60000\n", | |
| "loss: 2.083389 [51200/60000\n", | |
| "loss: 1.976388 [54400/60000\n", | |
| "loss: 1.935207 [57600/60000\n", | |
| "Test Error: \n", | |
| " Accuracy: 55.3%, Avg loss: 1.907281 \n", | |
| "\n", | |
| "Epoch 2\n", | |
| "---------------\n", | |
| "loss: 1.884537 [ 0/60000\n", | |
| "loss: 1.920210 [ 3200/60000\n", | |
| "loss: 1.853379 [ 6400/60000\n", | |
| "loss: 1.795410 [ 9600/60000\n", | |
| "loss: 1.718969 [12800/60000\n", | |
| "loss: 1.774579 [16000/60000\n", | |
| "loss: 1.603558 [19200/60000\n", | |
| "loss: 1.596340 [22400/60000\n", | |
| "loss: 1.604787 [25600/60000\n", | |
| "loss: 1.562681 [28800/60000\n", | |
| "loss: 1.379411 [32000/60000\n", | |
| "loss: 1.328523 [35200/60000\n", | |
| "loss: 1.575898 [38400/60000\n", | |
| "loss: 1.408180 [41600/60000\n", | |
| "loss: 1.299391 [44800/60000\n", | |
| "loss: 1.276043 [48000/60000\n", | |
| "loss: 1.403896 [51200/60000\n", | |
| "loss: 1.369461 [54400/60000\n", | |
| "loss: 1.217832 [57600/60000\n", | |
| "Test Error: \n", | |
| " Accuracy: 61.5%, Avg loss: 1.258526 \n", | |
| "\n", | |
| "Epoch 3\n", | |
| "---------------\n", | |
| "loss: 1.243421 [ 0/60000\n", | |
| "loss: 1.253286 [ 3200/60000\n", | |
| "loss: 1.240637 [ 6400/60000\n", | |
| "loss: 1.294042 [ 9600/60000\n", | |
| "loss: 1.141070 [12800/60000\n", | |
| "loss: 1.307107 [16000/60000\n", | |
| "loss: 1.018114 [19200/60000\n", | |
| "loss: 1.049471 [22400/60000\n", | |
| "loss: 1.115514 [25600/60000\n", | |
| "loss: 1.199444 [28800/60000\n", | |
| "loss: 0.998732 [32000/60000\n", | |
| "loss: 0.959133 [35200/60000\n", | |
| "loss: 1.254790 [38400/60000\n", | |
| "loss: 1.082036 [41600/60000\n", | |
| "loss: 0.986729 [44800/60000\n", | |
| "loss: 0.902049 [48000/60000\n", | |
| "loss: 1.040126 [51200/60000\n", | |
| "loss: 1.143184 [54400/60000\n", | |
| "loss: 0.917088 [57600/60000\n", | |
| "Test Error: \n", | |
| " Accuracy: 65.3%, Avg loss: 0.990519 \n", | |
| "\n", | |
| "Epoch 4\n", | |
| "---------------\n", | |
| "loss: 0.943217 [ 0/60000\n", | |
| "loss: 0.986033 [ 3200/60000\n", | |
| "loss: 0.987671 [ 6400/60000\n", | |
| "loss: 1.122316 [ 9600/60000\n", | |
| "loss: 0.895760 [12800/60000\n", | |
| "loss: 1.104472 [16000/60000\n", | |
| "loss: 0.781212 [19200/60000\n", | |
| "loss: 0.803678 [22400/60000\n", | |
| "loss: 0.901695 [25600/60000\n", | |
| "loss: 1.086160 [28800/60000\n", | |
| "loss: 0.827354 [32000/60000\n", | |
| "loss: 0.785774 [35200/60000\n", | |
| "loss: 1.142647 [38400/60000\n", | |
| "loss: 0.961017 [41600/60000\n", | |
| "loss: 0.867285 [44800/60000\n", | |
| "loss: 0.744188 [48000/60000\n", | |
| "loss: 0.880366 [51200/60000\n", | |
| "loss: 1.035268 [54400/60000\n", | |
| "loss: 0.796331 [57600/60000\n", | |
| "Test Error: \n", | |
| " Accuracy: 67.9%, Avg loss: 0.866105 \n", | |
| "\n", | |
| "Epoch 5\n", | |
| "---------------\n", | |
| "loss: 0.791266 [ 0/60000\n", | |
| "loss: 0.881798 [ 3200/60000\n", | |
| "loss: 0.858656 [ 6400/60000\n", | |
| "loss: 1.008443 [ 9600/60000\n", | |
| "loss: 0.772423 [12800/60000\n", | |
| "loss: 1.004468 [16000/60000\n", | |
| "loss: 0.661393 [19200/60000\n", | |
| "loss: 0.667778 [22400/60000\n", | |
| "loss: 0.787954 [25600/60000\n", | |
| "loss: 1.028648 [28800/60000\n", | |
| "loss: 0.738062 [32000/60000\n", | |
| "loss: 0.694281 [35200/60000\n", | |
| "loss: 1.074223 [38400/60000\n", | |
| "loss: 0.907154 [41600/60000\n", | |
| "loss: 0.817715 [44800/60000\n", | |
| "loss: 0.664417 [48000/60000\n", | |
| "loss: 0.790861 [51200/60000\n", | |
| "loss: 0.963609 [54400/60000\n", | |
| "loss: 0.736180 [57600/60000\n", | |
| "Test Error: \n", | |
| " Accuracy: 70.5%, Avg loss: 0.793863 \n", | |
| "\n", | |
| "Done!\n" | |
| ] | |
| } | |
| ], | |
| "metadata": {} | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 19, | |
| "source": [ | |
| "# Persist the model's state dict (not the module object itself) to disk.\n", | |
| "torch.save(obj=model.state_dict(), f=\"trained_model.pt\")" | |
| ], | |
| "outputs": [], | |
| "metadata": {} | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 20, | |
| "source": [ | |
| "# inference\n", | |
| "load_model = Model()\n", | |
| "# map_location keeps the load working on a CPU-only machine.\n", | |
| "load_model.load_state_dict(torch.load(\"trained_model.pt\", map_location=\"cpu\"))\n", | |
| "# Put mode-dependent layers (dropout, batch-norm) into inference behaviour.\n", | |
| "load_model.eval()\n", | |
| "sample_test_data, sample_test_label = test_dataloader.dataset[0]\n", | |
| "\n", | |
| "with torch.no_grad():\n", | |
| "    # A single dataset item has no batch axis; add one explicitly instead of\n", | |
| "    # relying on the leading channel axis being mistaken for the batch.\n", | |
| "    pred = load_model(sample_test_data.unsqueeze(0))\n", | |
| "    print(f\"prediction: {pred.argmax(1)}\")\n", | |
| "    print(f\"real output: {sample_test_label}\")\n" | |
| ], | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "name": "stdout", | |
| "text": [ | |
| "prediction: tensor([9])\n", | |
| "real output: 9\n" | |
| ] | |
| } | |
| ], | |
| "metadata": {} | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "source": [], | |
| "outputs": [], | |
| "metadata": {} | |
| } | |
| ], | |
| "metadata": { | |
| "orig_nbformat": 4, | |
| "language_info": { | |
| "name": "python", | |
| "version": "3.7.11", | |
| "mimetype": "text/x-python", | |
| "codemirror_mode": { | |
| "name": "ipython", | |
| "version": 3 | |
| }, | |
| "pygments_lexer": "ipython3", | |
| "nbconvert_exporter": "python", | |
| "file_extension": ".py" | |
| }, | |
| "kernelspec": { | |
| "name": "python3", | |
| "display_name": "Python 3.7.11 64-bit ('torch1.9': conda)" | |
| }, | |
| "interpreter": { | |
| "hash": "6852f41e9441d5e9d418b4dd84330779d45168a2f143dd1c731588fbff184199" | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 2 | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment