Skip to content

Instantly share code, notes, and snippets.

@SaiNikhileshReddy
Created March 9, 2022 03:41
Show Gist options
  • Save SaiNikhileshReddy/28b1d821c14ca48e6dca3fd93a2dbb64 to your computer and use it in GitHub Desktop.
Save SaiNikhileshReddy/28b1d821c14ca48e6dca3fd93a2dbb64 to your computer and use it in GitHub Desktop.
MNIST PyTorch
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 6,
"id": "525ba641-b696-4833-8829-2968104ff67d",
"metadata": {},
"outputs": [],
"source": [
"# Required Packages\n",
"import torch\n",
"from torch import nn, optim\n",
"from torch.nn import functional as F\n",
"from collections import OrderedDict\n",
"\n",
"from torchvision import datasets, transforms\n",
"from matplotlib import pyplot as plt\n",
"import numpy as np"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "1630189c-f853-4dbc-bc5e-375a54db2a1e",
"metadata": {},
"outputs": [],
"source": [
"# Data: MNIST train/test sets as normalized tensors, served in mini-batches\n",
"torch.manual_seed(0)  # reproducible shuffling, weight init and dropout on Restart & Run All\n",
"BATCH_SIZE = 64\n",
"\n",
"# ToTensor scales pixels to [0, 1]; Normalize((0.5,), (0.5,)) then maps them to [-1, 1]\n",
"transform = transforms.Compose([\n",
"    transforms.ToTensor(),\n",
"    transforms.Normalize((0.5,), (0.5,)),\n",
"])\n",
"\n",
"trainset = datasets.MNIST('datasets', download=True, train=True, transform=transform)\n",
"trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True)\n",
"\n",
"# download=True so this line does not depend on the train download above;\n",
"# the evaluation set is not shuffled, so validation and inference see a fixed order\n",
"testset = datasets.MNIST('datasets', download=True, train=False, transform=transform)\n",
"testloader = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "96880429-ae18-4242-b8ee-372f0e9d1c17",
"metadata": {},
"outputs": [],
"source": [
"# Fetch one (images, labels) mini-batch from the shuffled training loader\n",
"image, label = next(\n",
"    iter(trainloader)\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "2378e6fb-c815-4bf1-b0c0-9a67eefdae04",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train : 1\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAALfklEQVR4nO3dX4gd5R3G8edpqgjqRf7QZYlpjZKbJZBYYigkFIsoaRCiN2KEkNLQ9UJBoRcNFjRQBCnVUm+EFYOxWEVQaxBBkyBNcxOyaho3fzSpRJKw7lZyYQQhUX+9OBNZ4+6czZk5Z477+37gcOa87+zMjyFP3vlzdl9HhADMfT9qugAAvUHYgSQIO5AEYQeSIOxAEj/u5c5sc+sf6LKI8HTtlUZ22+tsf2j7hO2tVbYFoLvc6XN22/MkfSTpNkmnJR2QtDEijpT8DCM70GXdGNlXSzoRER9HxHlJL0naUGF7ALqoStgXSzo15fPpou07bA/bHrU9WmFfACrq+g26iBiRNCJxGg80qcrIfkbSkimfryvaAPShKmE/IGmZ7aW2r5R0j6Sd9ZQFoG4dn8ZHxFe2H5D0lqR5krZHxOHaKgNQq44fvXW0M67Zga7rypdqAPxwEHYgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5BEx1M244dh3rx5pf2PPPJIaf+9995b2r9mzZrS/snJydJ+9E6lsNs+KemcpK8lfRURq+ooCkD96hjZfxURn9WwHQBdxDU7kETVsIekt22/a3t4uhVsD9setT1acV8AKqh6Gr82Is7Y/omkXbaPRcTeqStExIikEUmyHRX3B6BDlUb2iDhTvE9Kek3S6jqKAlC/jsNu+2rb115clnS7pLG6CgNQL0d0dmZt+wa1RnOpdTnwj4h4rM3PcBrfY0NDQ6X9Y2PV/n9ut/1jx45V2j4uX0R4uvaOr9kj4mNJKzquCEBP8egNSIKwA0kQdiAJwg4kQdiBJPgV1zlucHCw6RLQJxjZgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJnrPPcZs2bWq6BPQJRnYgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgrADSbQNu+3ttidtj01pW2B7l+3jxfv87pYJoKrZjOzPSVp3SdtWSXsiYpmkPcVnAH2sbdgjYq+ks5c0b5C0o1jeIenOessCULdO/wbdQESMF8ufShqYaUXbw5KGO9wPgJpU/oOTERG2o6R/RNKIJJWtB6C7Or0bP2F7UJKK98n6SgLQDZ2GfaekzcXyZkmv11MOgG5xRPmZte0XJd0iaZGkCUmPSvqnpJcl/VTSJ5LujohLb+JNty1O43tsxYoVpf3vv/9+pe0PDQ2V9h87dqzS9nH5IsLTtbe9Zo+IjTN03VqpIgA9xTfogCQIO5AEYQeSIOxAEoQdSIIpm+e4CxcuNF0C+gQjO5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4kwXN2VLJo0aKmS8AsMbIDSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBI8Z5/jTp06Vdp/4MCB0v6bb765tH/Lli2l/fv27SvtR+8wsgNJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEjxnn+POnTtX2j8xMdGjStC0tiO77e22J22PTWnbZvuM7YPFa313ywRQ1WxO45+TtG6a9r9GxMri9Wa9ZQGoW9uwR8ReSWd7UAuALqpyg+4B24eK0/z5M61ke9j2qO3RCvsCUFGnYX9a0o2SVkoal/TETCtGxEhErIqIVR3uC0ANOgp7RExExNcR8Y2kZyStrrcsAHXrKOy2B6d8vEvS2EzrAugPbZ+z235R0i2SFtk+LelRSbfYXikpJJ
2UdF/3SgRQh7Zhj4iN0zQ/24VaAHQRX5cFkiDsQBKEHUiCsANJEHYgCX7FdY5buHBhaf/SpUt7VAmaxsgOJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0nwnH2OGxgYKO1fvnx5aX9E1FkOGsTIDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJ8Jx9jjt//nxp/5dfflnaf9VVV9VZDhrEyA4kQdiBJAg7kARhB5Ig7EAShB1IgrADSfCcfY47ceJEaf/u3btL+++44446y0GD2o7stpfYfsf2EduHbT9YtC+wvcv28eJ9fvfLBdCp2ZzGfyXp9xExJOkXku63PSRpq6Q9EbFM0p7iM4A+1TbsETEeEe8Vy+ckHZW0WNIGSTuK1XZIurNLNQKowWVds9u+XtJNkvZLGoiI8aLrU0nT/rEz28OShivUCKAGs74bb/saSa9IeigiPp/aF62/SjjtXyaMiJGIWBURqypVCqCSWYXd9hVqBf2FiHi1aJ6wPVj0D0qa7E6JAOowm7vxlvSspKMR8eSUrp2SNhfLmyW9Xn95AOoym2v2NZI2SfrA9sGi7WFJj0t62fYWSZ9IursrFQKoRduwR8Q+SZ6h+9Z6ywHQLXxdFkiCsANJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkmLI5uaeeeqq0v92Uzfv376+zHHQRIzuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJOGIKF/BXiLpeUkDkkLSSET8zfY2Sb+T9L9i1Ycj4s022yrfGYDKImLaWZdnE/ZBSYMR8Z7tayW9K+lOteZj/yIi/jLbIgg70H0zhX0287OPSxovls/ZPippcb3lAei2y7pmt329pJskXfyO5AO2D9nebnv+DD8zbHvU9mi1UgFU0fY0/tsV7Wsk/UvSYxHxqu0BSZ+pdR3/J7VO9X/bZhucxgNd1vE1uyTZvkLSG5Leiognp+m/XtIbEbG8zXYIO9BlM4W97Wm8bUt6VtLRqUEvbtxddJeksapFAuie2dyNXyvp35I+kPRN0fywpI2SVqp1Gn9S0n3FzbyybTGyA11W6TS+LoQd6L6OT+MBzA2EHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJHo9ZfNnkj6Z8nlR0daP+rW2fq1LorZO1Vnbz2bq6Onvs39v5/ZoRKxqrIAS/Vpbv9YlUVunelUbp/FAEoQdSKLpsI80vP8y/Vpbv9YlUVunelJbo9fsAHqn6ZEdQI8QdiCJRsJue53tD22fsL21iRpmYvuk7Q9sH2x6frpiDr1J22NT2hbY3mX7ePE+7Rx7DdW2zfaZ4tgdtL2+odqW2H7H9hHbh20/WLQ3euxK6urJcev5NbvteZI+knSbpNOSDkjaGBFHelrIDGyflLQqIhr/AobtX0r6QtLzF6fWsv1nSWcj4vHiP8r5EfGHPqltmy5zGu8u1TbTNOO/UYPHrs7pzzvRxMi+WtKJiPg4Is5LeknShgbq6HsRsVfS2UuaN0jaUSzvUOsfS8/NUFtfiIjxiHivWD4n6eI0440eu5K6eqKJsC+WdGrK59Pqr/neQ9Lbtt+1Pdx0MdMYmDLN1qeSBposZhptp/HupUumGe+bY9fJ9OdVcYPu+9ZGxM8l/VrS/cXpal+K1jVYPz07fVrSjWrNATgu6YkmiymmGX9F0kMR8fnUviaP3TR19eS4NRH2M5KWTPl8XdHWFyLiTPE+Kek1tS47+snExRl0i/fJhuv5VkRMRMTXEfGNpGfU4LErphl/RdILEfFq0dz4sZuurl4dtybCfkDSMttLbV8p6R5JOxuo43tsX13cOJHtqyXdrv6binqnpM3F8mZJrzdYy3f0yzTeM00zroaPXePTn0dEz1+S1qt1R/6/kv7YRA0z1HWDpP8Ur8NN1ybpRbVO6y6odW
9ji6SFkvZIOi5pt6QFfVTb39Wa2vuQWsEabKi2tWqdoh+SdLB4rW/62JXU1ZPjxtdlgSS4QQckQdiBJAg7kARhB5Ig7EAShB1IgrADSfwfSwmW1y1rvbkAAAAASUVORK5CYII=\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Test : 7\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAALr0lEQVR4nO3dT6gd5R3G8eep1Y26SCq9hBiqlWyk0FguoVBpEkRJs4luxCxKCtLrQouCiwa7uMlOSlW6Eq4YjMUqgopZSGsaYtJuJFdJY/6gSUPEhJhbycK4stFfF2cix3jPmZszM2fm3t/3A5dzzjvnnPk5+jhz5p13XkeEACx932u7AADjQdiBJAg7kARhB5Ig7EAS3x/nymxz6h9oWER4vvZKe3bbG21/aPuk7W1VvgtAszxqP7vtayR9JOluSWckHZS0JSKODfkMe3agYU3s2ddKOhkRpyLiS0mvSNpc4fsANKhK2FdK+qTv9Zmi7VtsT9metT1bYV0AKmr8BF1EzEiakTiMB9pUZc9+VtKqvtc3F20AOqhK2A9KWm37VtvXSXpA0u56ygJQt5EP4yPiku1HJP1d0jWSdkbE0doqA1CrkbveRloZv9mBxjVyUQ2AxYOwA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kMTI87NLku3Tki5K+krSpYiYrKMoAPWrFPbChoj4rIbvAdAgDuOBJKqGPSS9bfs921PzvcH2lO1Z27MV1wWgAkfE6B+2V0bEWds/lLRH0u8i4sCQ94++MgALEhGer73Snj0izhaPc5LekLS2yvcBaM7IYbd9ve0bLz+XdI+kI3UVBqBeVc7GT0h6w/bl7/lrRPytlqpQm+3btw9d/s4771RaXmb9+vUjLZOkdevWjfzdkrRhw4aBy6r+cy1GI4c9Ik5J+mmNtQBoEF1vQBKEHUiCsANJEHYgCcIOJFHpCrqrXhlX0M2rrHusiqrdV0tV0WW8JDVyBR2AxYOwA0kQdiAJwg4kQdiBJAg7kARhB5Kgn70DxvnvAD30swNYsgg7kARhB5Ig7EAShB1IgrADSRB2IIk6JnZERWW3NS4bc75jx46By8rGs+/fv3/o8jJVbkVdNo5/enr66gvqM2y7ZMSeHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSYDw7WtP0f3tLecz6MCOPZ7e90/ac7SN9bctt77F9onhcVmexAOq3kMP4FyRtvKJtm6S9EbFa0t7iNYAOKw17RByQdOGK5s2SdhXPd0m6t96yANRt1GvjJyLiXPH8U0kTg95oe0rS1IjrAVCTygNhIiKGnXiLiBlJMxIn6IA2jdr1dt72CkkqHufqKwlAE0YN+25JW4vnWyW9WU85AJpSehhv+2VJ6yXdZPuMpGlJT0p61faDkj6WdH+TRWLxanL+97Kx9Pi20rBHxJYBi+6quRYADeJyWSAJwg4kQdiBJAg7kARhB5LgVtJoVNXbQQ/DraKvDnt2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiCfnZUUjbtcpUhrmX96AxxvTrs2YEkCDuQBGEHkiDsQBKEHUiCsANJEHYgCfrZUcm6desa+2760evFnh1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkqCfHUOVjUevMl69rB+9aj972Vj7rn53U0r37LZ32p6zfaSvbbvts7YPFX+bmi0TQFULOYx/QdLGedqfiYg1xd9b9ZYFoG6lYY+IA5IujKEWAA2qcoLuEduHi8P8ZYPeZHvK9qzt2QrrAlDRqGF/VtJtktZIOifpqUFvjIiZiJiMiMkR1wWgBiOFPSLOR8RXEfG1pOckra23LAB1Gynstlf0vbxP0pFB7wXQDaX97LZflrRe0k22z0ialrTe9h
pJIem0pIeaKxFt2rdvX2PfXdZHHxGNrbusD3///v2NrbstpWGPiC3zND/fQC0AGsTlskAShB1IgrADSRB2IAnCDiTBENcaVB0GWnY75irDSJeysu6zDRs2jKeQRYI9O5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4kQT/7Ag0b6kk/+GDD+sLLhpEuxts1dxl7diAJwg4kQdiBJAg7kARhB5Ig7EAShB1Ign72Qtktk9vsS69y2+Omx8rv2LFj6HL6yruDPTuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJEE/e6Gsv3hYf3TZZ8v6ycuWV1HWz121n73J2lGv0j277VW299k+Zvuo7UeL9uW299g+UTwua75cAKNayGH8JUmPR8Ttkn4u6WHbt0vaJmlvRKyWtLd4DaCjSsMeEeci4v3i+UVJxyWtlLRZ0q7ibbsk3dtQjQBqcFW/2W3fIukOSe9KmoiIc8WiTyVNDPjMlKSpCjUCqMGCz8bbvkHSa5Iei4jP+5dFREiK+T4XETMRMRkRk5UqBVDJgsJu+1r1gv5SRLxeNJ+3vaJYvkLSXDMlAqhD6WG8bUt6XtLxiHi6b9FuSVslPVk8vtlIhWNS1oXU2wyLT9kQ1zJVuxXRHQv5zf4LSb+W9IHtQ0XbE+qF/FXbD0r6WNL9jVQIoBalYY+If0katFu7q95yADSFy2WBJAg7kARhB5Ig7EAShB1Iwr2L38a0Mnt8K0tk2DDVsltkl1ms1xdkFhHz/ktjzw4kQdiBJAg7kARhB5Ig7EAShB1IgrADSXAr6SVgenq67RKwCLBnB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEk6GdPruy+8Fg62LMDSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBILmZ99laQXJU1ICkkzEfFn29sl/VbSf4u3PhERbzVVKAYb1lc+7J7yyGUhF9VckvR4RLxv+0ZJ79neUyx7JiL+1Fx5AOqykPnZz0k6Vzy/aPu4pJVNFwagXlf1m932LZLukPRu0fSI7cO2d9peNuAzU7Znbc9WKxVAFQsOu+0bJL0m6bGI+FzSs5Juk7RGvT3/U/N9LiJmImIyIiarlwtgVAsKu+1r1Qv6SxHxuiRFxPmI+Coivpb0nKS1zZUJoKrSsLs3jefzko5HxNN97Sv63nafpCP1lwegLqVTNtu+U9I/JX0g6eui+QlJW9Q7hA9JpyU9VJzMG/ZdTNkMNGzQlM3Mzw4sMczPDiRH2IEkCDuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSGLcUzZ/Junjvtc3FW1d1NXaulqXRG2jqrO2Hw1aMNbx7N9ZuT3b1XvTdbW2rtYlUduoxlUbh/FAEoQdSKLtsM+0vP5hulpbV+uSqG1UY6mt1d/sAMan7T07gDEh7EASrYTd9kbbH9o+aXtbGzUMYvu07Q9sH2p7frpiDr0520f62pbb3mP7RPE47xx7LdW23fbZYtsdsr2ppdpW2d5n+5jto7YfLdpb3XZD6hrLdhv7b3bb10j6SNLdks5IOihpS0QcG2shA9g+LWkyIlq/AMP2LyV9IenFiPhJ0fZHSRci4snif5TLIuL3Haltu6Qv2p7Gu5itaEX/NOOS7pX0G7W47YbUdb/GsN3a2LOvlXQyIk5FxJeSXpG0uYU6Oi8iDki6cEXzZkm7iue71PuPZewG1NYJEXEuIt4vnl+UdHma8Va33ZC6xqKNsK+U9Enf6zPq1nzvIelt2+/Znmq7mHlM9E2z9amkiTaLmUfpNN7jdMU0453ZdqNMf14VJ+i+686I+JmkX0l6uDhc7aTo/QbrUt/pgqbxHpd5phn/RpvbbtTpz6tqI+xnJa3qe31z0dYJEXG2eJyT9Ia6NxX1+csz6BaPcy3X840uTeM93zTj6sC2a3P68zbCflDSatu32r5O0gOSdrdQx3fYvr44cSLb10u6R92binq3pK3F862S3myxlm/pyj
Teg6YZV8vbrvXpzyNi7H+SNql3Rv4/kv7QRg0D6vqxpH8Xf0fbrk3Sy+od1v1PvXMbD0r6gaS9kk5I+oek5R2q7S/qTe19WL1grWiptjvVO0Q/LOlQ8bep7W03pK6xbDculwWS4AQdkARhB5Ig7EAShB1IgrADSRB2IAnCDiTxf7GM27LIURnHAAAAAElFTkSuQmCC\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# Checking Images: print the label and plot the image of the first sample of one batch per loader\n",
"def show_first_sample(loader, title):\n",
"    \"\"\"Fetch one batch from `loader`, print its first label and display its first image.\"\"\"\n",
"    images, labels = next(iter(loader))\n",
"    index = 0\n",
"    print(f'{title} : {labels[index].item()}')\n",
"    fig, ax = plt.subplots()\n",
"    ax.imshow(images[index].numpy().squeeze(), cmap='gray')\n",
"    ax.set_title(f'{title} label: {labels[index].item()}')\n",
"    ax.axis('off')\n",
"    plt.show()\n",
"\n",
"\n",
"show_first_sample(trainloader, 'Train')\n",
"show_first_sample(testloader, 'Test')"
]
},
{
"cell_type": "code",
"execution_count": 43,
"id": "b4094178-464d-441c-8550-49dcda3ea140",
"metadata": {},
"outputs": [],
"source": [
"class Network(nn.Module):\n",
"    \"\"\"Fully connected classifier with ReLU and dropout after every hidden layer.\n",
"\n",
"    Args:\n",
"        input_layer: number of input features per sample (inputs are flattened).\n",
"        output_layer: number of output classes.\n",
"        hidden_layers: sizes of the hidden layers, in order; must be non-empty.\n",
"        drop_p: dropout probability applied after each hidden layer.\n",
"\n",
"    forward() returns log-probabilities (log_softmax over dim 1), so pair the\n",
"    model with nn.NLLLoss and use torch.exp() to recover probabilities.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, input_layer, output_layer, hidden_layers, drop_p=0.5):\n",
"        super().__init__()\n",
"        if not hidden_layers:\n",
"            raise ValueError('hidden_layers must contain at least one layer size')\n",
"\n",
"        # input -> hidden_layers[0], then one Linear per consecutive pair of hidden sizes\n",
"        layer_sizes = zip(hidden_layers[:-1], hidden_layers[1:])\n",
"        self.hidden_layers = nn.ModuleList([nn.Linear(input_layer, hidden_layers[0])])\n",
"        self.hidden_layers.extend([nn.Linear(h1, h2) for h1, h2 in layer_sizes])\n",
"        self.output_layer = nn.Linear(hidden_layers[-1], output_layer)\n",
"\n",
"        self.dropout = nn.Dropout(p=drop_p)\n",
"\n",
"    def forward(self, images):\n",
"        # Flatten every sample to a vector, e.g. MNIST (N, 1, 28, 28) -> (N, 784);\n",
"        # reshape (unlike view) also accepts non-contiguous inputs\n",
"        x = images.reshape(images.shape[0], -1)\n",
"\n",
"        for hidden_layer in self.hidden_layers:\n",
"            x = F.relu(hidden_layer(x))\n",
"            x = self.dropout(x)\n",
"\n",
"        x = self.output_layer(x)\n",
"\n",
"        return F.log_softmax(x, dim=1)\n",
"\n",
"\n",
"def validation(model, testloader, criterion):\n",
"    \"\"\"Evaluate `model` on every batch of `testloader` without tracking gradients.\n",
"\n",
"    The caller is responsible for putting the model in eval mode first.\n",
"\n",
"    Returns:\n",
"        (loss, accuracy) as Python floats: the sums over batches of the per-batch\n",
"        loss and of the per-batch accuracy. Divide by len(testloader) for the mean\n",
"        per batch (every batch is weighted equally, including a smaller last one).\n",
"    \"\"\"\n",
"    loss = 0.0\n",
"    accuracy = 0.0\n",
"\n",
"    with torch.no_grad():\n",
"        for images, labels in testloader:\n",
"            log_probs = model(images)\n",
"            loss += criterion(log_probs, labels).item()\n",
"            # argmax over log-probabilities equals argmax over probabilities\n",
"            predictions = log_probs.argmax(dim=1)\n",
"            accuracy += (predictions == labels).float().mean().item()\n",
"\n",
"    return loss, accuracy\n",
"\n",
"\n",
"def train(model, trainloader, testloader, criterion, optimizer, epochs=5):\n",
"    \"\"\"Train `model` for `epochs` epochs, evaluating on `testloader` after each one.\n",
"\n",
"    Returns:\n",
"        (model, train_loss, test_loss, accuracy), where each list holds one value\n",
"        per epoch: mean training loss per batch, mean test loss per batch and mean\n",
"        test accuracy per batch. When epochs >= 1 the model is left in eval mode.\n",
"    \"\"\"\n",
"    train_loss = []\n",
"    test_loss = []\n",
"    accuracy = []\n",
"\n",
"    for epoch in range(1, epochs + 1):\n",
"        running_loss = 0.0\n",
"        model.train()  # enable dropout\n",
"\n",
"        for images, labels in trainloader:\n",
"            optimizer.zero_grad()\n",
"            log_probs = model(images)\n",
"            loss = criterion(log_probs, labels)\n",
"            loss.backward()\n",
"            optimizer.step()\n",
"\n",
"            running_loss += loss.item()\n",
"\n",
"        model.eval()  # disable dropout for evaluation\n",
"        epoch_loss, epoch_accuracy = validation(model, testloader, criterion)\n",
"        test_loss.append(epoch_loss / len(testloader))\n",
"        accuracy.append(epoch_accuracy / len(testloader))\n",
"        train_loss.append(running_loss / len(trainloader))\n",
"\n",
"        print('Epoch: {}/{}'.format(epoch, epochs),\n",
"              'Train: {:.5f}..'.format(train_loss[-1]),\n",
"              'Test: {:.5f}..'.format(test_loss[-1]),\n",
"              'Accuracy: {:.7f}..'.format(accuracy[-1]))\n",
"\n",
"    return model, train_loss, test_loss, accuracy"
]
},
{
"cell_type": "code",
"execution_count": 49,
"id": "def868f2-ffaf-48a1-8a67-2bd768faae74",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Network(\n",
" (hidden_layers): ModuleList(\n",
" (0): Linear(in_features=784, out_features=64, bias=True)\n",
" (1): Linear(in_features=64, out_features=32, bias=True)\n",
" )\n",
" (output_layer): Linear(in_features=32, out_features=10, bias=True)\n",
" (dropout): Dropout(p=0.5, inplace=False)\n",
")\n"
]
}
],
"source": [
"# 784 inputs (flattened 28x28 MNIST images), 10 digit classes, hidden layers of 64 and 32 units\n",
"model = Network(input_layer=784, output_layer=10, hidden_layers=[64, 32], drop_p=0.5)\n",
"# Network.forward ends in log_softmax, so NLLLoss on its output is the cross-entropy loss\n",
"criterion = nn.NLLLoss()\n",
"optimizer = optim.SGD(model.parameters(), lr=0.05)\n",
"print(model)"
]
},
{
"cell_type": "code",
"execution_count": 50,
"id": "c88193ff-64db-4a32-bdc5-8e431a76b530",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch: 1/5 Train: 1.12936.. Test: 0.42069.. Accuracy: 0.8886346..\n",
"Epoch: 2/5 Train: 0.71243.. Test: 0.32782.. Accuracy: 0.9113256..\n",
"Epoch: 3/5 Train: 0.61807.. Test: 0.30322.. Accuracy: 0.9084395..\n",
"Epoch: 4/5 Train: 0.57504.. Test: 0.26841.. Accuracy: 0.9211783..\n",
"Epoch: 5/5 Train: 0.55544.. Test: 0.27302.. Accuracy: 0.9183917..\n"
]
}
],
"source": [
"# Keep the per-epoch histories for later inspection; assigning to `_` instead would\n",
"# discard them and also stop IPython from updating its `_` output-cache variable\n",
"model, train_losses, test_losses, test_accuracies = train(\n",
"    model, trainloader, testloader, criterion, optimizer, epochs=5\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 60,
"id": "03b845df-0366-4fbe-b519-3c8348460817",
"metadata": {},
"outputs": [],
"source": [
"# Inference on one test batch\n",
"images, labels = next(iter(testloader))\n",
"\n",
"model.eval()  # disable dropout\n",
"with torch.no_grad():  # inference only: skip building the autograd graph\n",
"    logits = model(images)  # log-probabilities (Network ends in log_softmax)\n",
"    loss = criterion(logits, labels)\n",
"\n",
"ps = torch.exp(logits)  # class probabilities\n",
"equality = (labels == ps.max(1)[1])\n",
"# Per-sample correctness: 1.0 where the most probable class matches the label, else 0.0\n",
"accuracy = equality.float()"
]
},
{
"cell_type": "code",
"execution_count": 62,
"id": "bb819858-78f9-4ca1-93bf-82f4d1abe624",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Label : 9\n",
"Prediciton : 9 | Accuracy : 1.0\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAANFElEQVR4nO3dbcgd9ZnH8d9vtXlhWiWubAhp2NTGB6qgNSEsGESprW4QY0C0EZZUxVSoS4sFFRXiS1m0YX1h4a4PTZZuSrH1gVBrsiHi6otqImmM0SYmRJoQE2PAPPgim3jti3sst+Y+/3PnzMw5J7m+H7g558x1ZubikF9mzsyc+TsiBOD09w+DbgBAfxB2IAnCDiRB2IEkCDuQxJn9XJltDv0DLYsIjze91pbd9vW2/2r7A9sP1FkWgHa51/Psts+QtFXS9yXtkvSWpEURsaUwD1t2oGVtbNnnSvogInZExFFJv5W0oMbyALSoTtinS/rbmNe7qmlfYnuJ7fW219dYF4CaWj9AFxEjkkYkduOBQaqzZd8tacaY19+spgEYQnXC/pakC2x/y/YkST+U9FIzbQFoWs+78RFxzPY9kl6RdIakZyLi3cY6A9Conk+99bQyvrMDrWvlohoApw7CDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBI9j88uSbZ3Sjok6bikYxExp4mmADSvVtgr10TE/gaWA6BF7MYDSdQNe0habXuD7SXjvcH2Etvrba+vuS4ANTgiep/Znh4Ru23/k6Q1kv49Il4rvL/3lQGYkIjweNNrbdkjYnf1uE/S85Lm1lkegPb0HHbbk21/44vnkn4gaXNTjQFoVp2j8VMlPW/7i+X8d0T8qZGu0Jh58+YV6wsWLCjW77rrrmL9nHPOKdZfeOGFjrWlS5cW5920aVOxjpPTc9gjYoekyxrsBUCLOPUGJEHYgSQIO5AEYQeSIOxAErWuoDvplXEFXd+tXbu2WL/mmmv61MmJjhw5UqzfeOONxfq6deuabOe00coVdABOHYQdSIKwA0kQdiAJwg4kQdiBJAg7kEQTN5xEy7r9TPWJJ57oWLvssvIPEw8dOlSsP/zww8X6yy+/XKyvXr26Y23mzJnFeRcuXFisc5795LBlB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEk+D37EOh2Hn3VqlXF+tlnn92xtm3btuK8d999d7Fe91z2k08+2fO69+8vjxd68cUXF+sHDhwo1k9X/J4dSI6wA0kQdiAJwg4kQdiBJAg7kARhB5LgPHsfTJkypVgv/eZbkmbPnl2sHzx4sGPthhtuKM77+uuvF+t1le5L3+2e9t3cd999xfpjjz1Wa/mnqp7Ps9t+xvY+25vHTDvX9hrb26rH8r9mAAM3kd34X0u6/ivTHpC0NiIukLS2eg1giHUNe0S8Jumr1x0ukLS8er5c0k3NtgWgab3eg25qROypnn8kaWqnN9peImlJj+sB0JDaN5yMiCgdeIuIEUkjUt4DdMAw6PXU217b0ySpetzXXEsA2tBr2F+StLh6vljSi820A6AtXXfjba+UdLWk82zvkrRU0qOSfmf7TkkfSrqlzSZPdXfccUex3u08ejf3339/x1rb59G72bFjR2vLPv/881tb9umoa9gjYlGH0vca7gVAi7hcFkiCsANJEHYgCcIOJEHYgSQYsrkPZs2aVWv+p556qlZ9kEpDRm/durU474UXXth0O6mxZQeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJDjP3gfdbufczZYtW4r148eP11p+yaWXXlqs33bbbcX6vffe27E2adKknnpCb9iyA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EASnGfvg+3btxfr06dPL9bnz5
9frL/55psda5999llx3iuvvLJYf/zxx4t1zpWfOtiyA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EASnGfvg+eee65Yv+qqq4r1a6+9tla9TR9//HGxvmrVqo6122+/vel2UNB1y277Gdv7bG8eM+0R27ttb6z+yld9ABi4iezG/1rS9eNMXxYRl1d/f2y2LQBN6xr2iHhN0oE+9AKgRXUO0N1je1O1mz+l05tsL7G93vb6GusCUFOvYf+lpG9LulzSHkkdfy0RESMRMSci5vS4LgAN6CnsEbE3Io5HxOeSfiVpbrNtAWhaT2G3PW3My4WSNnd6L4Dh0PU8u+2Vkq6WdJ7tXZKWSrra9uWSQtJOST9ur8VT37PPPlusHz16tFh///33i/Vly5Z1rJ111lnFeTdvLv8/3e0agVdffbVYv+KKKzrW6p5nf+WVV2rNn03XsEfEonEmP91CLwBaxOWyQBKEHUiCsANJEHYgCcIOJMFPXPvgyJEjxfrIyEit5c+ePbvW/G3qNuRzHZ9++mlryz4dsWUHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQ4z45WXXTRRT3Pu2vXrmJ9w4YNPS87I7bsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AE59lRy6xZs4r1W2+9tedlr1y5slg/dOhQz8vOiC07kARhB5Ig7EAShB1IgrADSRB2IAnCDiTBeXYUnXlm+Z/I4sWLi/XJkyd3rH3yySfFeeveTx9f1nXLbnuG7XW2t9h+1/ZPq+nn2l5je1v1OKX9dgH0aiK78cck/TwiviPpXyT9xPZ3JD0gaW1EXCBpbfUawJDqGvaI2BMRb1fPD0l6T9J0SQskLa/etlzSTS31CKABJ/Wd3fZMSd+V9GdJUyNiT1X6SNLUDvMskbSkRo8AGjDho/G2vy7p95J+FhEHx9YiIiTFePNFxEhEzImIObU6BVDLhMJu+2saDfpvIuIP1eS9tqdV9WmS9rXTIoAmdN2Nt21JT0t6LyJ+Mab0kqTFkh6tHl9spUMM1HXXXVesP/TQQz0ve8WKFcX69u3be142TjSR7+xXSvo3Se/Y3lhNe1CjIf+d7TslfSjpllY6BNCIrmGPiNcluUP5e822A6AtXC4LJEHYgSQIO5AEYQeSIOxAEvzEFUU333xza8s+duxYa8vGidiyA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EASnGdH0Zw59W4wdPjw4Y41bhXdX2zZgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJzrOjVfv37+9Y477w/cWWHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSmMj47DMkrZA0VVJIGomI/7T9iKS7JH1cvfXBiPhjW41iMN54441i/ZJLLinWly1b1mQ7qGEiF9Uck/TziHjb9jckbbC9pqoti4jH2msPQFMmMj77Hkl7queHbL8naXrbjQFo1kl9Z7c9U9J3Jf25mnSP7U22n7E9pcM8S2yvt72+XqsA6phw2G1/XdLvJf0sIg5K+qWkb0u6XKNb/sfHmy8iRiJiTkTUu5kZgFomFHbbX9No0H8TEX+QpIjYGxHHI+JzSb+SNLe9NgHU1TXsti3paUnvRcQvxkyfNuZtCyVtbr49AE1xRJTfYM+T9L+S3pH0eTX5QUmLNLoLH5J2SvpxdTCvtKzyygDUFhEeb3rXsDeJsAPt6xR2rqADkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4k0e8hm/dL+nDM6/OqacNoWHsb1r4keutVk739c6dCX3/PfsLK7fXDem+6Ye1tWPuS6K1X/eqN3XggCcIOJDHosI8MeP0lw9rbsPYl0Vuv+tLbQL+zA+ifQW/ZAfQJYQeSGEjYbV9v+6+2P7D9wCB66MT2Ttvv2N446PHpqjH09tnePGbaubbX2N5WPY47xt6AenvE9u7qs9toe/6Aepthe53tLbbftf3TavpAP7tCX3
353Pr+nd32GZK2Svq+pF2S3pK0KCK29LWRDmzvlDQnIgZ+AYbtqyQdlrQiIi6tpv2HpAMR8Wj1H+WUiLh/SHp7RNLhQQ/jXY1WNG3sMOOSbpL0Iw3wsyv0dYv68LkNYss+V9IHEbEjIo5K+q2kBQPoY+hFxGuSDnxl8gJJy6vnyzX6j6XvOvQ2FCJiT0S8XT0/JOmLYcYH+tkV+uqLQYR9uqS/jXm9S8M13ntIWm17g+0lg25mHFPHDLP1kaSpg2xmHF2H8e6nrwwzPjSfXS/Dn9fFAboTzYuIKyT9q6SfVLurQylGv4MN07nTCQ3j3S/jDDP+d4P87Hod/ryuQYR9t6QZY15/s5o2FCJid/W4T9LzGr6hqPd+MYJu9bhvwP383TAN4z3eMOMags9ukMOfDyLsb0m6wPa3bE+S9ENJLw2gjxPYnlwdOJHtyZJ+oOEbivolSYur54slvTjAXr5kWIbx7jTMuAb82Q18+POI6PufpPkaPSK/XdJDg+ihQ1/nS/pL9ffuoHuTtFKju3X/p9FjG3dK+kdJayVtk/Q/ks4dot7+S6NDe2/SaLCmDai3eRrdRd8kaWP1N3/Qn12hr758blwuCyTBATogCcIOJEHYgSQIO5AEYQeSIOxAEoQdSOL/AatwDybOkQE8AAAAAElFTkSuQmCC\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"index = 0\n",
"predicted = ps.argmax(dim=1)  # most probable class of every sample in the batch\n",
"print(f'Label : {labels[index].item()}')\n",
"print(f'Prediction : {predicted[index].item()} | Accuracy : {accuracy[index].item()}')\n",
"plt.imshow(images[index].numpy().squeeze(), cmap='gray')\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2fbf88f6-d9de-4b86-9618-f1942362c929",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment