Skip to content

Instantly share code, notes, and snippets.

@sagorbrur
Created October 5, 2021 18:35
Show Gist options
  • Select an option

  • Save sagorbrur/0ce5a498ebbb76f63bd44dc7248ea3ba to your computer and use it in GitHub Desktop.

Select an option

Save sagorbrur/0ce5a498ebbb76f63bd44dc7248ea3ba to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"source": [
"import torch\n",
"import torch.nn as nn\n",
"from torch.utils.data import DataLoader\n",
"from torchvision import datasets, transforms\n",
"from torchvision.transforms import ToTensor, Lambda\n"
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 3,
"source": [
"def _load_fashion_mnist(train):\n",
"    \"\"\"Return the FashionMNIST training (`train=True`) or test split rooted at ./data, downloading it if needed.\"\"\"\n",
"    return datasets.FashionMNIST(\n",
"        root=\"data\",\n",
"        train=train,\n",
"        download=True,\n",
"        transform=ToTensor(),  # convert each image to a tensor\n",
"    )\n",
"\n",
"\n",
"training_data = _load_fashion_mnist(train=True)\n",
"test_data = _load_fashion_mnist(train=False)\n"
],
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz\n",
"Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to data/FashionMNIST/raw/train-images-idx3-ubyte.gz\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"26422272it [00:16, 1644521.31it/s] \n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"Extracting data/FashionMNIST/raw/train-images-idx3-ubyte.gz to data/FashionMNIST/raw\n",
"\n",
"Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz\n",
"Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw/train-labels-idx1-ubyte.gz\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"29696it [00:00, 165681.50it/s] \n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"Extracting data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw\n",
"\n",
"Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz\n",
"Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"4422656it [00:07, 565707.47it/s] \n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"Extracting data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw\n",
"\n",
"Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz\n",
"Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"6144it [00:00, 2824088.09it/s] "
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"Extracting data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw\n",
"\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"\n",
"/home/sagor/anaconda3/envs/torch1.9/lib/python3.7/site-packages/torchvision/datasets/mnist.py:498: UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at /pytorch/torch/csrc/utils/tensor_numpy.cpp:180.)\n",
" return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)\n"
]
}
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 9,
"source": [
"# Inspect a single training example: the image tensor's shape and its class label.\n",
"sample_data, sample_label = training_data[0]\n",
"print(sample_data.shape)\n",
"print(sample_label)"
],
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"torch.Size([1, 28, 28])\n",
"9\n"
]
}
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 10,
"source": [
"# Seed the global RNG so the shuffled batch order (and the weight initialisation in later cells) is reproducible.\n",
"torch.manual_seed(0)\n",
"\n",
"# Reshuffle the training set every epoch so SGD does not see the batches in a fixed order;\n",
"# evaluation order does not matter, so the test loader stays unshuffled.\n",
"# NOTE: the batch size 32 is hard-coded here and should match `batch_size` in the hyper-parameter cell.\n",
"train_dataloader = DataLoader(training_data, batch_size=32, shuffle=True)\n",
"test_dataloader = DataLoader(test_data, batch_size=32)"
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 13,
"source": [
"class Model(nn.Module):\n",
"    \"\"\"MLP classifier: flattens each input and maps 28*28 features to 10 logits\n",
"    through two 512-unit hidden layers with ReLU activations.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self):\n",
"        super().__init__()\n",
"        # The attribute names and layer order below determine the state_dict keys\n",
"        # (e.g. \"linear_relu_stack.0.weight\") stored in saved checkpoints, so keep them stable.\n",
"        self.flatten = nn.Flatten()\n",
"        self.linear_relu_stack = nn.Sequential(\n",
"            nn.Linear(28*28, 512),\n",
"            nn.ReLU(),\n",
"            nn.Linear(512, 512),\n",
"            nn.ReLU(),\n",
"            nn.Linear(512, 10),\n",
"        )\n",
"\n",
"    def forward(self, x):\n",
"        # Flatten everything after dim 0, then apply the MLP.\n",
"        # The result is raw (unnormalised) logits; no softmax is applied.\n",
"        return self.linear_relu_stack(self.flatten(x))"
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 14,
"source": [
"model = Model()  # new instance with freshly initialised weights\n",
"model  # last expression, so Jupyter displays the layer structure"
],
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"Model(\n",
" (flatten): Flatten(start_dim=1, end_dim=-1)\n",
" (linear_relu_stack): Sequential(\n",
" (0): Linear(in_features=784, out_features=512, bias=True)\n",
" (1): ReLU()\n",
" (2): Linear(in_features=512, out_features=512, bias=True)\n",
" (3): ReLU()\n",
" (4): Linear(in_features=512, out_features=10, bias=True)\n",
" )\n",
")"
]
},
"metadata": {},
"execution_count": 14
}
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 15,
"source": [
"# Training hyper-parameters: learning rate, mini-batch size, number of passes over the training set.\n",
"# NOTE: the DataLoaders are built with a hard-coded batch size of 32; keep it in sync with `batch_size`.\n",
"lr, batch_size, epochs = 1e-3, 32, 5"
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 16,
"source": [
"# Cross-entropy on the model's raw logits, optimised with plain SGD over all model parameters.\n",
"criterion = nn.CrossEntropyLoss()\n",
"optimizer = torch.optim.SGD(params=model.parameters(), lr=lr)"
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 17,
"source": [
"def train(dataloader, model, criterion, optimizer):\n",
"    \"\"\"Run one training epoch over `dataloader`, updating `model`'s parameters via `optimizer`.\n",
"\n",
"    Every 100 batches, prints the current batch loss and the number of samples trained on so far.\n",
"    \"\"\"\n",
"    size = len(dataloader.dataset)\n",
"    seen = 0  # samples trained on so far in this epoch, including the current batch\n",
"    model.train()  # training-mode behaviour for layers such as dropout/batch-norm\n",
"    for batch, (X, y) in enumerate(dataloader):\n",
"        # forward pass\n",
"        pred = model(X)\n",
"        # compute loss\n",
"        loss = criterion(pred, y)\n",
"        # backpropagation\n",
"        optimizer.zero_grad()\n",
"        loss.backward()\n",
"        optimizer.step()\n",
"        seen += len(X)\n",
"        if batch % 100 == 0:\n",
"            print(f\"loss: {loss.item():>7f} [{seen:>5d}/{size:>5d}]\")\n",
"\n",
"\n",
"def test(dataloader, model, criterion):\n",
"    \"\"\"Evaluate `model` on `dataloader` without gradient tracking.\n",
"\n",
"    Prints the accuracy over all samples and the criterion value averaged over batches.\n",
"    \"\"\"\n",
"    size = len(dataloader.dataset)\n",
"    num_batches = len(dataloader)\n",
"    model.eval()  # inference-mode behaviour for layers such as dropout/batch-norm\n",
"    test_loss, correct = 0, 0\n",
"    with torch.no_grad():\n",
"        for X, y in dataloader:\n",
"            pred = model(X)\n",
"            test_loss += criterion(pred, y).item()\n",
"            correct += (pred.argmax(1) == y).type(torch.float).sum().item()\n",
"    # mean of the per-batch criterion values (not weighted by batch size)\n",
"    test_loss /= num_batches\n",
"    accuracy = correct / size\n",
"    print(f\"Test Error: \\n Accuracy: {(100*accuracy):>0.1f}%, Avg loss: {test_loss:>8f} \\n\")"
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 18,
"source": [
"# Alternate one training epoch with an evaluation pass over the test set.\n",
"for epoch in range(epochs):\n",
"    header = f\"Epoch {epoch+1}\\n---------------\"\n",
"    print(header)\n",
"    train(train_dataloader, model, criterion, optimizer)\n",
"    test(test_dataloader, model, criterion)\n",
"print(\"Done!\")"
],
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Epoch 1\n",
"---------------\n",
"loss: 2.303873 [ 0/60000\n",
"loss: 2.286819 [ 3200/60000\n",
"loss: 2.286289 [ 6400/60000\n",
"loss: 2.263151 [ 9600/60000\n",
"loss: 2.248534 [12800/60000\n",
"loss: 2.233392 [16000/60000\n",
"loss: 2.224118 [19200/60000\n",
"loss: 2.209830 [22400/60000\n",
"loss: 2.211460 [25600/60000\n",
"loss: 2.156247 [28800/60000\n",
"loss: 2.109936 [32000/60000\n",
"loss: 2.093889 [35200/60000\n",
"loss: 2.171871 [38400/60000\n",
"loss: 2.074579 [41600/60000\n",
"loss: 2.032319 [44800/60000\n",
"loss: 2.002094 [48000/60000\n",
"loss: 2.083389 [51200/60000\n",
"loss: 1.976388 [54400/60000\n",
"loss: 1.935207 [57600/60000\n",
"Test Error: \n",
" Accuracy: 55.3%, Avg loss: 1.907281 \n",
"\n",
"Epoch 2\n",
"---------------\n",
"loss: 1.884537 [ 0/60000\n",
"loss: 1.920210 [ 3200/60000\n",
"loss: 1.853379 [ 6400/60000\n",
"loss: 1.795410 [ 9600/60000\n",
"loss: 1.718969 [12800/60000\n",
"loss: 1.774579 [16000/60000\n",
"loss: 1.603558 [19200/60000\n",
"loss: 1.596340 [22400/60000\n",
"loss: 1.604787 [25600/60000\n",
"loss: 1.562681 [28800/60000\n",
"loss: 1.379411 [32000/60000\n",
"loss: 1.328523 [35200/60000\n",
"loss: 1.575898 [38400/60000\n",
"loss: 1.408180 [41600/60000\n",
"loss: 1.299391 [44800/60000\n",
"loss: 1.276043 [48000/60000\n",
"loss: 1.403896 [51200/60000\n",
"loss: 1.369461 [54400/60000\n",
"loss: 1.217832 [57600/60000\n",
"Test Error: \n",
" Accuracy: 61.5%, Avg loss: 1.258526 \n",
"\n",
"Epoch 3\n",
"---------------\n",
"loss: 1.243421 [ 0/60000\n",
"loss: 1.253286 [ 3200/60000\n",
"loss: 1.240637 [ 6400/60000\n",
"loss: 1.294042 [ 9600/60000\n",
"loss: 1.141070 [12800/60000\n",
"loss: 1.307107 [16000/60000\n",
"loss: 1.018114 [19200/60000\n",
"loss: 1.049471 [22400/60000\n",
"loss: 1.115514 [25600/60000\n",
"loss: 1.199444 [28800/60000\n",
"loss: 0.998732 [32000/60000\n",
"loss: 0.959133 [35200/60000\n",
"loss: 1.254790 [38400/60000\n",
"loss: 1.082036 [41600/60000\n",
"loss: 0.986729 [44800/60000\n",
"loss: 0.902049 [48000/60000\n",
"loss: 1.040126 [51200/60000\n",
"loss: 1.143184 [54400/60000\n",
"loss: 0.917088 [57600/60000\n",
"Test Error: \n",
" Accuracy: 65.3%, Avg loss: 0.990519 \n",
"\n",
"Epoch 4\n",
"---------------\n",
"loss: 0.943217 [ 0/60000\n",
"loss: 0.986033 [ 3200/60000\n",
"loss: 0.987671 [ 6400/60000\n",
"loss: 1.122316 [ 9600/60000\n",
"loss: 0.895760 [12800/60000\n",
"loss: 1.104472 [16000/60000\n",
"loss: 0.781212 [19200/60000\n",
"loss: 0.803678 [22400/60000\n",
"loss: 0.901695 [25600/60000\n",
"loss: 1.086160 [28800/60000\n",
"loss: 0.827354 [32000/60000\n",
"loss: 0.785774 [35200/60000\n",
"loss: 1.142647 [38400/60000\n",
"loss: 0.961017 [41600/60000\n",
"loss: 0.867285 [44800/60000\n",
"loss: 0.744188 [48000/60000\n",
"loss: 0.880366 [51200/60000\n",
"loss: 1.035268 [54400/60000\n",
"loss: 0.796331 [57600/60000\n",
"Test Error: \n",
" Accuracy: 67.9%, Avg loss: 0.866105 \n",
"\n",
"Epoch 5\n",
"---------------\n",
"loss: 0.791266 [ 0/60000\n",
"loss: 0.881798 [ 3200/60000\n",
"loss: 0.858656 [ 6400/60000\n",
"loss: 1.008443 [ 9600/60000\n",
"loss: 0.772423 [12800/60000\n",
"loss: 1.004468 [16000/60000\n",
"loss: 0.661393 [19200/60000\n",
"loss: 0.667778 [22400/60000\n",
"loss: 0.787954 [25600/60000\n",
"loss: 1.028648 [28800/60000\n",
"loss: 0.738062 [32000/60000\n",
"loss: 0.694281 [35200/60000\n",
"loss: 1.074223 [38400/60000\n",
"loss: 0.907154 [41600/60000\n",
"loss: 0.817715 [44800/60000\n",
"loss: 0.664417 [48000/60000\n",
"loss: 0.790861 [51200/60000\n",
"loss: 0.963609 [54400/60000\n",
"loss: 0.736180 [57600/60000\n",
"Test Error: \n",
" Accuracy: 70.5%, Avg loss: 0.793863 \n",
"\n",
"Done!\n"
]
}
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 19,
"source": [
"# Save only the learned parameters (the state_dict), not the whole pickled module.\n",
"checkpoint_path = \"trained_model.pt\"\n",
"torch.save(model.state_dict(), checkpoint_path)"
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 20,
"source": [
"# inference: reload the saved weights into a fresh model and classify one test image\n",
"load_model = Model()\n",
"load_model.load_state_dict(torch.load(\"trained_model.pt\", map_location=\"cpu\"))\n",
"load_model.eval()  # inference-mode behaviour for layers such as dropout/batch-norm\n",
"sample_test_data, sample_test_label = test_dataloader.dataset[0]\n",
"\n",
"with torch.no_grad():\n",
"    # add a batch dimension: (1, 28, 28) -> (1, 1, 28, 28); without it the channel dim is treated as the batch\n",
"    pred = load_model(sample_test_data.unsqueeze(0))\n",
"    print(f\"prediction: {pred.argmax(1)}\")\n",
"    print(f\"real output: {sample_test_label}\")\n"
],
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"prediction: tensor([9])\n",
"real output: 9\n"
]
}
],
"metadata": {}
}
],
"metadata": {
"orig_nbformat": 4,
"language_info": {
"name": "python",
"version": "3.7.11",
"mimetype": "text/x-python",
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"pygments_lexer": "ipython3",
"nbconvert_exporter": "python",
"file_extension": ".py"
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3.7.11 64-bit ('torch1.9': conda)"
},
"interpreter": {
"hash": "6852f41e9441d5e9d418b4dd84330779d45168a2f143dd1c731588fbff184199"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment