Skip to content

Instantly share code, notes, and snippets.

@chiragjn
Created June 25, 2017 12:31
Show Gist options
  • Save chiragjn/d042349b6e812d2b3da5d46d8f37d093 to your computer and use it in GitHub Desktop.
Save chiragjn/d042349b6e812d2b3da5d46d8f37d093 to your computer and use it in GitHub Desktop.
Good old MNIST with PyTorch
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 54,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"from __future__ import print_function\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as functional\n",
"import torch.optim as optim\n",
"from torch.autograd import Variable\n",
"import numpy as np\n",
"import os\n",
"import matplotlib.pyplot as plt\n",
"import matplotlib.cm as cm\n",
"import pandas as pd\n",
"import math\n",
"\n",
"import torchvision.datasets as dset\n",
"import torchvision.transforms as transforms"
]
},
{
"cell_type": "code",
"execution_count": 55,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"batch_size = 100\n",
"test_batch_size = 1000\n",
"num_epochs = 100000"
]
},
{
"cell_type": "code",
"execution_count": 56,
"metadata": {},
"outputs": [],
"source": [
"mnist_train_data = dset.MNIST(root='MNIST_pyt', train=True, transform=transforms.ToTensor(),\n",
" target_transform=None, download=False)"
]
},
{
"cell_type": "code",
"execution_count": 57,
"metadata": {},
"outputs": [],
"source": [
"mnist_test_data = dset.MNIST(root='MNIST_pyt', train=False, transform=transforms.ToTensor(),\n",
" target_transform=None, download=False)"
]
},
{
"cell_type": "code",
"execution_count": 58,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"train_loader = torch.utils.data.DataLoader(mnist_train_data, batch_size=batch_size, shuffle=True)\n",
"test_loader = torch.utils.data.DataLoader(mnist_test_data, batch_size=batch_size, shuffle=True)"
]
},
{
"cell_type": "code",
"execution_count": 62,
"metadata": {},
"outputs": [],
"source": [
"num_epochs = int((num_epochs * 1.0) / len(train_loader)) + 1"
]
},
{
"cell_type": "code",
"execution_count": 63,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"60000\n",
"784\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAP4AAAD8CAYAAABXXhlaAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJztfVuIrNl13re76159O3M0mgkaW07ISyCIIcF6UcBtHIwI\nBgU/KIogSI4RfogSE+tBil7mKPGDrYcBReAHK7LQBAvfwJEciCOHpDF6cCQ7USLHI8uQjGzZmjO3\n06e7rqeqeueha/39/avW/qu6qi/V9a8PNv+u6qr6d3X3t9fa6xpijHA4HOXC1m0vwOFw3Dyc+A5H\nCeHEdzhKCCe+w1FCOPEdjhLCie9wlBArET+E8N4QwrdDCN8JIXz8qhblcDiuF2FZP34IYQvAdwD8\nGIC/AvANAB+IMX5bvc4DBRyOW0KMMVjPryLx3w3gz2KM340xjgD8GoD3JW6ejRdeeCH3eN2Gr29z\n17fOa7uO9RVhFeK/A8Bf0OPvTZ9zOBxrDjfuORwlRGWF9/4lgB+kx89Nn5vBgwcPsvnBwcEKt7x+\nHB4e3vYSCuHrWx7rvDZg9fUdHR3h6OhoodeuYtzbBvCnODfufR/A1wH84xjjy+p1cdl7OByO5RFC\nQEwY95aW+DHGSQjhowC+ivMjw+c16R0Ox3piaYm/8A1c4jsct4Iiie/GPYejhHDiOxwlhBPf4Sgh\nnPgORwnhxHc4SggnvsNRQjjxHY4SwonvcJQQTnyHo4Rw4jscJYQT3+EoIZz4DkcJ4cR3OEoIJ77D\nUUI48R2OEsKJ73CUEE58h6OEcOI7HCWEE9/hKCGc+A5HCeHEdzhKCCe+w1FCOPEdjhLCie9wlBBO\nfIejhHDiOxwlhBPf4SghVmmT7VhDrEOfwhhjtg6Z8+NF3p+ahxAKh76nNVLrWmRt8+7P60g9Xgc4\n8UuOq94oYow4OzvDZDLB2dlZbkwmk5n7WY+LRggBlUoF29vb2N7enpnHGDGZTAoHr09f9Xo0Ube2\ntmbuq9eytbWVvY6vW1tbTnzH7WEe+VbB2dkZxuMxxuMxJpPJzPzs7Cx3Ty3RZchmoedbW1uo1Wqo\nVqszVyHVeDzGkydPMBqNzCFr4rkMWR8TlOfb29vJ+9dqNVQqlZlRrVZRqVSwtbU+J2snfslQpEZf\nBUR6psjHUt+6WloCzyuVCur1OhqNRjbk/dvb2wCA0WiE4XCYG4PBAMPhEE+ePDHHaDTCkydPMJlM\ncqq5vlYqFTQajWwNvJZ6vY56vZ7bCGRtIYRsfeuAlYgfQngFwGMAZwBGMcZ3X8WiHNcDi+ip67IQ\niS9EErLJdTweF561heCpUalU0Gq10Gq1chJaNIEYY3bvfr+Pfr+PXq+Xzfv9fm4j0NfxeJw8l4cQ\nUK1Ws/vzaDab2VU2ACa9HEPWBatK/DMAhzHGR1exGMf1IUX66yD+ZDLJEb/f72MwGGAwGOSIbw3Z\nOERr0MeFarWa20BEktZqNUwmEwAXEr/f76Pb7aLT6WRX3gT0htDv9zEejwGkDXP1eh07OzvmSB0Z\ntre3TfvGbWJV4ge4S3DtYf3DpSzaV0l8kaRMtNFolCO5Jj4T3TqHV6vVjFhCqmq1ikajkST+6ekp\nTk9PcXJykm0CqcES37LONxoN7O3tYX9/H3t7e5kGIWtiA+HW1lZ2xpeNYF2wKvEjgN8LIUwA/HKM\n8XNXsCbHFSJlNdePF3VnzYNW9YX4QqwnT57MEJ4fy6aRGqxCC+nr9XpmPwBs4h8fH+Px48c4OTnJ\nNoJOp5O7np6eYjQaFbrpWq0WDg4O0O12Z0jPEPVezvqWx+A2sSrx3xNj/H4I4WmcbwAvxxi/pl/0\n4MGDbH54eIjDw8MVb7u+WOaPm3rPZc7kKSu59X62llvXVfDkyZMZKcpS9smTJ6a0l8dMfNlAeF6v\n1zPXmZCqVqvlztW9Xg+dTicj9MnJCU5OTnB8fLwQ8cXtZl3Pzs6yzabZbGa2C1mjdmNe5aY6D0dH\nRzg6OlroteGqFhRCeAHAaYzxRfV8XKed7rpxWat5Sg231OCUq0s/x59rGdHkH9Qaq/6tRqMRer1e\nNkTFl/loNCr8brKO1Dm/Xq/j3r17ODg4yAY/jjHi+PjYHCLxe71ebmPix+Px2PTFy7zVauGpp57C\nvXv38NRTT+Xm9+7dw/7+Ptrtdnbul7lca7XaSr/fyyCEgBijGTiwtMQPIbQAbMUYOyGENoAfB/Cp\nZT9v07CI4ayI9JrQlotLB6TInO+v1yEk0pJUrqueRcfjcabei0FPhhjPiox7skbLlSe/k16vl/nG\nhZj8nYXkMkTiyxlfbA5iJJQNT2wG1pB7NZtNNBqN7MqDLfrs01+34B1gNVX/GQC/HUKI08/51Rjj\nV69mWXcbluHMUvcs4lvk5quWgDpARv6Ji0iV8mWLH3sVjMfjnAtPz8Vqbm1KlmZjaTpMeg7TlZ+f\nnJzkiM9n+263m/Pvc2yBqPNCdGtoolubQL1eR7VaRbVazTaOdSI9sALxY4z/D8DzV7iWjUOKfPxz\nDSa3Vnu1tLaGRXwmEFvbU37sVaCNczqIJxXAk9KKrN+PJr08L5ujSHct7U9OTrLjBq+NiS9SnyPu\nhMTiPUhJe5mL3WFTJb6jAEWSi1+j38ORb9aQf9jUYFXasgOMRqMZ37Xlx14WHLJraSd8lLA2gHnJ\nL2dnZzPqPd+Tic+GPRmDwcCM3U9JfCGxGBK1lLc2AY7e08eRdYET/xqQUl+11dyaW5Kd56w2W0P8\n5KlY9+FwmBmzxOjGBq7RaLTyd7eSdHQSTOrYw+QTScnGNfHfTw1XubgBkd5ioWf/vcyHw2FyQ9bE\nZ6+BeA4s0vNVzvgbq+o7imGRfhHis/TSKqmQXhvN2Jimia/HcDjMXF3WuAriWxuf/lkKnH1nZeEJ\n8eV3xZuk2Ci0i0676+Q+VoCOBN2wis+kX2TohB1X9TcAKUu8gC3s+myure7WVYxvmvAW8bXlXFT1\nFOlvgvirYmtrKyOO9qELkUQb0IY9rbrL54kUr1ar2XPymXrOSTiciCNjd3c3c9VxnL6l4ltrXRc4\n8ZdESpJJVhpngvFcG9/0+7VxTBvzLFVfDHMcRWap+6zqS9SZ3GNdIsuEfLVazcyCYxJaQ2wAsnmI\nCt5ut9Hr9TAej2fIzkPuLSTW13a7jf39/Sxkt91uZ8RPkV7ut07kd+IvgSLX02g0wmAwSCaDcKx6\n6nNSRr1VjXsxxiyMltNURRtZB4iqL6TVWXBsNRcy8ogxmqQX7WgymZiEZ62CP0+f85vNJnZ3d7Mh\nkl9U/LtAesCJvxQWsZpLqKqcMWUuIaspqz+fW60jg1VUwnKXpaLjdK76ukp8CYnl7Lfd3d2cZOUz\nuMxjjBnpLePnZDIxDYcs8fmz9Wg0Gmi327khRj1Zh1V5Z93gxL8ktHquz9Ai8bvdbhYfLkEkx8fH\nGAwGScJri7hVKioVwMPurNRRIsZ8AI9oD+tCfJGMInVFWu/t7WUZcc1mMydZ9TzGmCXtWHYSdgem\nwnL5M/V9ZEPiwZqIlvYu8TcImvBMTib+48eP8ejRI7z11lt488038dZbb6Hf7ydj7VPWeOt+Vg25\nefYD6yjBySW3SXwhhpb47XYbu7u7ODg4wFNPPYVWqzVj6edrjDEZ1SjHoKJY/NRns4svZfwT4qeS\nfNaJ/E78JWC56YR4coYWif/o0SO88cYbeO211/D666+j2+0WkrroGGC5Bq1NRNbIV163FQvPgTU3\nDSYEn7NZ4ksiTLvdNgtZcpBMKuxZviOTPPVZqbmszaq5J1qBfCer8u66wIm/BFKk1xJfiP/666/j\n4cOHePXVV3F6elrobptHXH0ter1es17/ZX3s1wFNCG3cE+IfHBzg/v37GfG1RNUus9R35CAdTW5L\nRbeCiCxNQ0fopYp1rguc+CsiJVGtQhQSRGJlnVkBPreJooKTRa9fZHOyPker+pzz3mq1MpV/Z2dn\nbljvvO9VJOH1udzaZOaV0L4LcOIvAR3pJdLfkg568D9GjDF7/zpBS7yUVLU2hXl2Cvmuqc9h6zqH\nzXLkXCrqblHiW+d7eawJbz1ed8PdInDirwD9jwEgqTby80BeFV2nfx5NDC0VLZLpJBptfLSMhylN\ngu+lQ2aF/Pqe87SR1PdL+fOLxrzX3hU48ZeA9c8g/9QpsvNjbahbp38Y+U767Mpx5ylJKMTXBT74\nGCAGtnnEL5L4vE7rM+Z9vyJtpuj4kFL/1+nvtyic+EtCk14k/iKGIyE/Z5qt0z8Pk4+t1RycklKB\nOdeAjwaiCWgSyVyu+t46M46Jz1c9T6Fo0yoiu37vXZb2gBP/0kipf4IilxBLfNYM1knqF5GvVqvl\niG9Zv6XpxXA4zEl6CZXlM36KXEXkr9frS5OeXzeP7POud5XwAif+krCIz/+0ltVXhhgC9T9OCGEt\nDH38PfQZm9NMrSFJMEA+a47LVss9Umr0PFXfIttlCLgowa3PXvQ96w4n/hIQgso/qhCZo8Isaz5f\nNfnX5R8nJXGFeNVqtdAqLmd6dm1y5VpR91OSdlHjnrz+Kr7vZZ6/ynvfJpz4S4L/8EJ6UeF12SbO\nMuMMOzaeyZmYI+hS0sSK6LtsDEBKinFaKveDk2utVkuSXiS+ldwi33VeaS/OuNNRcTIcq8OJf0UQ\n8gjxObtsb28Pg8EAT548Qa1Wm1vltsh4lkrb5Vj0eSj6/Hq9ngXMpOrCWyo+E7+oph/X1bf8/XIv\nyXHnFtN3XcquE5z4S8KSwiFc9HITAu3s7GQlsSaTSUZ8q/T0cDg0s8eYWGI11+8HLizn89ZdZIsQ\n7YRzznd3d7G3t4fd3V3U63XT983GPas0mAzO/7eSaCS/nVNdRStyXB2c+EuAjXB6A9ja2sqVYd7Z\n2cnlydfrdbOstczPzi7KR1tjMpnkOtT0+30AF5V7Fl2//lxuGCHSXVJheTQajblW/VR1IK5ElCot\nJtqGEN8l/vXAib8C2Mgn0BKfSS8SNdVpRkpHpZo5yBlZCnuI9Vy0gEX92CzldU67SHwhvrSokjZV\nzWbTDGCR+dnZ2Uy5MV2CTG92fNV95nWqq+Nq4MRfEhbpgdkzPjdrEGIJ6fnsK4/Pzs6S1V8qlQpG\no1FmWQeQK/5xGV+2rNOqMMPE39/fx71793D//n3cv38fzWZzhvSa+LqRBs/199alyer1+kwdO6uB\nhmM1OPFXAAeosHFPJL4mvWwGuh5fo9FAr9fL2ikX1XwbDofZeVek62AwyKTiImtml5kOktGqvhTA\nuH//Pp5++mm0Wq1kEBMT36oVKH31dDfdRqOBbrebbT4s8VnVd1wdnPhXgJSqz6QXUsmQZhZcvUWI\nz8/p+WAwAICZ3P9KpXJpiZ/y1Vuq/tve9jY8/fTTaLfbSR+8aEFWp1u5DgaDXLMLrqEnm1G73c6V\nq3ZV/+rhxL8kFgnsEPJzxRchv67jpocUi+SS0fxYLONyLu71ejkJafXOAy6yAbWar7vDcHFLbvfM\nz1mEl7kE7eiSYDIGg0HScCm2B3bpuap/PXDiXzGs0NNUpRtL8gKYUe95bG1tZeWi2T4gxrJKpZIk\nnuQEcO04Ue3Fmr6/v4/d3d3sOdl4dHKOrJ+/t/4d8O9CwAFCXORTNsytra3s/uzLd+JfLZz41wAm\ntCY6gBnJy2QQiZ8y8Gnia3eZBNHoszXfl48jfKaXopZS6UYb2FjdnhejbpEeQG7T0aSvVqsIIWTa\nBVv2nfhXCyf+FUO7y3Q2mtYIhNCicscYC8tHM/G5MYbUyN/a2sq5ztj6L3H02uWoS1hLhxjdKGIR\n0qckvfxMNjmL9JJ5x2HCYuBz4l8t5hI/hPB5AD8B4GGM8V3T5+4B+HUA7wTwCoD3xxgfX+M67xSY\n2PKYNwQr803cXTFGs5ijXLe2tmYIz0ExIYQZKz/nwgvROB6fpb2o+pbEtSS+nuvnOHNRtB/peKNJ\nL5l3um2W3ngcq2MRif8FAJ8F8BI99wkA/yXG+OkQwscB/Kvpcw4gJ+14E+DgnGq1asbaA7M5/Tq7\nj1V8Jr1IdE167hfHZOMzvkj7g4ODzLimJb5lWS8ivYCTmMTmoUnP67c65LjEv1rMJX6M8WshhHeq\np98H4Eem8y8COIITH4BNet0xRwxwVrccIF++ywqLFcu+Jr1sHJr0o9EoIw5rG1bTiv39fbM1lKXq\nF/0OeK7LWzPpre+vQ5bdqn/1WPaM//YY40MAiDG+GkJ4+xWu6c5Dn2+1RZ/TaPVVQntTATJbW1um\npNcFLYX0kv7LhTKtRKLd3d0sHp/jDdiqfhlfOlv/OWNQpL+4Li3Xo/W9+fMcq+OqjHuFuaAPHjzI\n5oeHhzg8PLyi264f5knFlH/d+sfnuTyWDD9dnEJGs9nM0n9FRdfSUm8q2iDJWsZlSbeI6u+4Hhwd\nHeHo6Gih1y5L/IchhGdijA9DCM8CeK3oxUx8xzlYBU6dm1OEKyKuRVgeXBXHavghGYB8DpejiWO9\noYXqpz71qeRrFw2ADtMh+AqAD0/nHwLw5csssOywLOJasqZew88x4TX5UyqyHCu400+v18t1+ul2\nu5kBcZ266TquDou4874E4BDA/RDCnwN4AcAvAPjNEMI/BfBdAO+/zkVuKkQC85V/VvS+Iomv02b5\n87TEHw6H6Pf76HQ6uUAZDjPm1GLHZmARq/4HEz/6+1e8llKhiPT8Gj0vkvqLSnzd208yAzlRhhtX\nrkMbbcfVwiP3bhGa9PPO/Pz4smd8gaj6+ozPhkAOq3VVfzPhxL9lMOkXVfXl59YZXz83T9UfDAY5\nd51E9UmlIHEZusTfLDjx1wSX9Y8vcsafp+oPh8OZQhfb29uZf9+Ne5sLJ/4N4yp82kJ0VsvZn69z\n+Tm7T8p3AxdluySrDziPmmu1WrmwYMkjkCi7Is+D427AiX/HIL51iWUX4xtHAHLZK13wcjQa5UJw\n5fVSortareZSfnVOgKTO6qg6WZvjbsCJf8fAQTVWMQuR4pr0nMknlnsxLkqMvxT6lJqAOgNQPleO\nE1xXwEtj3S048e8YhGgi8XWUnRCZyc+lrLUNgGvkjUYjVCoVM9dfyM8punKVdaXcko71gxP/joEl\nvs5rr1TO/5ws7XVzCyGoThaSsb29XajqSwss7mzD4ceOuwEn/h0Dk1zntddqNYQQZkgvBO73+5mE\n56AcroIr+f5Fqr6ONxAtxHF34MS/Y2BVn0k/Ho+zCjYW4fv9PlqtFs7OznK99oT4slmEEGa6/GiJ\nz2sRuwJnFzrWH078OwYhu1yly6yo7SLxU+RnQ56o7dKCq9/vI4QwY9zjGn5MfDbyOenvFpz4dxBs\nyBMVW8g3Ho+zQpXcrVeCceQ9HMjD6bp8TODMvW63m/Xrk3oAYuyTwS2+ZZ18lbn188sEMDlWhxP/\nDoKNadqaLll1XFZLou/EeMfFOThuX0jLGkC328Xjx49Rr9dRqVSy5p5cDITnuv6+NZ+XVOS4fjjx\n7zCsDD9Op5VqPOLrF01BSKaTdXRPPknXFdKHcF7BV3f64SHhv1zJR8+tIZuR42bgxL/j0OQXw59I\nfJH0Qnomt5BesvPkZyzxO51OLspPOtpy+WsuAaZTey2C614BAHJS33H9cOLfUaTUfS3xtctPDIJM\nem5aIfH4o9EI/X4/R/rRaJR1900N3R/QIvp4PM5iEGT93HzEcf1w4t9haPID+TM+k17i+re3t3Pq\nvRBZuvQI8eXnTPrBYJCrvstXmcsmkhrcKkzWLZuR4+bgxL/j0Cm3THyW9FJUQ1psSXttqZ0varlk\n4YlPX7QDeb0QXdpc6SFSnwc3AeW8fs4ylHLbjpuBE3+DkAru4fJZ0pBDimsK8blbjRCfJb1sDJKr\nL4Ndh0L+lPFPJxTJJiUNNZz4Nwcn/h3DIpV5hFDcukvGeDzOuuPKlQffQ2f7SXCQhPfK8zo3gGsB\naPKLwVHWw2uWc74uIOJ+/quHE3/DUFShB0B23pdmmXt7e+j1ehgMBlnwj0TzCTm5xZVsBOwW5HJe\nw+FwRtXn0Ww2Z4p8cN5ArVZL+vktX78TfTk48TcUmvRnZ2e583+j0ciaZQ6Hw8ztJ75/luIyF8Of\nGAc16bmcV2q0Wq1c+K+QXlT9RqOR7J0HXPQFtL6vY3E48TcQ2r0npJfIPZb4QnpRuxuNRiaRJb5f\nn/sl3l+TXrvttFtPynppSc8q/2Qyyd4rV91ajI8DXgNgOTjxNxRMBt2mWojfbrezRB3g/BjQbDbR\n6XTQ6XRyAT2S0SettLjMl6j+VtCOftxqtZKkl01KjgVs8LNKfcnzqbLkjjSc+BsKixya+BzKyw00\npMoOcBHFJ4Y3OYuLFmHF3heF68p95XNYM+EKQnpTECOlpeo74S8PJ/4GIqUOi6rPrjUAOdJLog2Q\nD93ljL6iSrvzknOkJgCn8nLFYFH3+Wfs67fI78S/PJz4GwrL8s0Sn5N2hPTtdjsL/GHS6zh+q6+9\naAF8P+vKxAcuIg0lyEc2GHm9HBGE9KxtyGtc1b88nPgbhqJ/fvbxS6QcB/rU63UAyHLx5SpGvl6v\nlxFPJL8MTX6W6Dw/OzvL+fklsk+Mf1wajDUDbajUx4hFE3x8cziHE79kEPWbjW8cJy+Sf3d3dyal\nt1KpoNPpZEE7OohHwn05YEgPIbZ06e12uzmNQsKJrZJf4/E4Cy/mwYVGrQ3AyT4LJ34JwcY0qaUv\naDQaaLVa2N3dzZXuliNCt9vNVe7l0lziphMXH2sEEifANf4Gg0FW1Qe4sClYBT7l88T4KEcDuQIX\npGcXn7v8bDjxSwhWlbW7bDKZoN1uz9TrF09Ap9OZ8fPLqFQqOekspAWQk/raaCg/5z4ALOnl88bj\nMVqtVhaAZFn+rfBeJ/8s5hI/hPB5AD8B4GGM8V3T514A8BEAr01f9skY4+9e2yodVwa2tIt6zS45\n8c1rSS9+f6m/x4OTfKS2n0T2AReBPjLnUl/ABelZxdeEZ82h2WyapC/K8HPS57GIxP8CgM8CeEk9\n/2KM8cWrX5LjusGx+7wJMDl1y2wJ7+10Ojg9Pc2uQnqrRbd8Fhfh5DM+kCe9aBJsM2CfPhsVLdLL\na3R+gmMWc4kfY/xaCOGdxo98C72DsEgvRjfJ6NNVfNrtdnbuluKbjx8/zs7Y3IdPPhe4IH2lUsmy\n+zjkV6f9VioV9Hq9me68THzL/8/txMTIJ/fn7+24wCpn/I+GEP4JgD8E8LEY4+MrWpPjmsHkl6Ae\nmQsBuRMvn9m5eAdb4zmqTx5zSC+r9XKVzYCDfSSyj9V7TXirgxCn+vJ3k+/rZ/w8liX+LwH41zHG\nGEL4eQAvAvjp1IsfPHiQzQ8PD3F4eLjkbR2rwoq4Y3BgjxTJ4CHnaO2i44Ye7GZjX3sIIcsN4PeI\nTUE+L/Ve1iR0lCC7JrmpJ3+OZfjbJBwdHeHo6Gih14ZFqp5MVf3fEePeoj+b/jx6ZZW7A+sszdd+\nv4/Hjx/j5OQkU/n5cafTQa/XSw7JBmTDHc8rlQp2d3exs7OTu/J8b29vZsjzrVbLrPMn8zKV8J5q\nOuZOt6jED6AzfQjh2Rjjq9OHPwngj1dbomPdkDKOcT4/Z/aJXaDRaGQtuKwrRwTqTr5yVOCiHmJ3\nAGZdglYloH6/n1X8lXXK9+Hzf9mxiDvvSwAOAdwPIfw5gBcA/GgI4XkAZwBeAfAz17hGxw2CVWht\nHQ8hZBJUKvVwGLCU1tL+fR4s/bvdbi4HQHz+Mh8OhznSW649Xf5rMBjkagICF5uSa54XWMSq/0Hj\n6S9cw1ocawRtAJTnhOBa0stm0Gq1cq45brktIbqnp6dZ6i8nBInVX4ivSc/Wfm105Jp/8hq9Pif+\nBVz3ccyADWCSLcc5/Rwiy6q/EE+32ebBpAcuwnQHg0EWOSgWfwC5SD8JDioq9CklwnSpcfb9O5z4\nDgOpCjcsMZlUbJzTSTZ63mw2zXx/ycXnKD/u+CPWeS7bxSG+8vmpqEN53nEOJ77DhHafCbTrTA85\nm6cMeGJsYykuGXpc7IMj8Hg9tVotJ+n154/H49yRhBuHOvEv4MR35GD5ufk5zo23inGMx+OcG42z\n9xqNBkIIWesu6cTL0X8AClVyLvklj+U4IKTnbj+y+bBR0CpSUvT9NxFOfMfSsEjCIb8c4COvlzbb\nUnxDB9jw51oSWkcEViqVzPofQkC9Xs/V7bdy+tljIffb9OAeDSe+YymkSl5J1J9E4Om0X/GvM/lT\nUXV8DwFXAOIsP3ltrVbLeRKs3H5OKNLuyrLAie+4NHTsO8+Lcv2l0CdH0onUt2ropcAuP9kwZEMQ\n4hdV8ZE1iuVf1l0mOPEdK8GS0lauvxgEWeIXqfo6wUagJT6f9SeTSY74ukqQkF/uaWUTlkXqO/Ed\nS8Eipw7zZdILMa0im1axzJTkZ+LzY+7mw6q++P45v18MkeJCLGPuvhPfsTRSklJIrLPwYoxZhV2u\nrMslsyzC6+eE+LwJcG5+kao/Go1yGxZ7KMoEJ75jZaQktQ75jTFmxTK1Zd9S9YuMe3I+DyHkLPXV\najVLCrKMe1IEhEmv8/3LACe+I4eUeq0f62g+TRzr5wDMOnrLEk8X5pC4fyu+gEfqe5YJTnyHiSJy\nF9XM11c9f/z4MR49eoSTk5OsYi+36db3syCS2qrzxzYEHpyfz0cMq9BHGeDEd8ygSFqyRV1X59H1\n8azmGkz8breLfr+fa9xhkZ7nrKJzlR0uDmoRPhU0tGgHnk2DE98xA010vor/XFfPsVR43VBjMpng\n9PQUx8fHuWo9IvG5/RavRYNJz510OE5AN93QcQNlJ78T32FCq/TabSbGMraYswEtNTqdjqnqi7V9\nHumtOns8arXaXPLr3ntlU/MBJ75DwVLteVjpsDpQRtfF58cW8VnV53WkwMQX8guprTO+lvjaPuCx\n+g4HYKr5cl4XInOuPVfcYfKLNsCPuS6/nPFZ1Z9nbbckvkj1eUY9GVZ5sTKRHnDiOxKwJD6XvJJc\n+lQxTa6Iw9qApONKNx7Lqj8P+ozP5C9S8+W1OgtQz8sAJ/6GIUUefXbWLjomOhvp+Hw+mUwwHA6T\npbMlVFZlzOIlAAAR2UlEQVSTXRfL6Ha7uSAbbq89D1ria1WfSa5ddzIcTvyNRconbrnauG4+W+ut\nITX1UoNDZC2VnyvzyNn+svXw2J2nN4AiH33ZpHoRnPgbiEUlumWB10TVQ4ibKqjJnW4tA58uyyU/\nX+R8D+SLZuhuOVYHHzbgOS7gxN9gcMScjBS5iyz1WlUvmosUTw3L6Mdhuyno87iW+FqdL1PrrGXg\nxN9AWFKerfNa5dYEtqS5lfTCaj1b83XQjg7ksTSCIlVfk57HPIlvlfRyOPE3FhbpLXccd7jRHW+s\nuRjjUgE83LVWh+ymovmKJL5V6GORMz7/3FX9WTjxNwyp+HomHxvppJUVD25xpa98htfSmw11qUSd\nVIKPRfxUBVwrXl8b9sqcgLMInPgbCk36VABOt9vN+dVlyPP6Kj731JiXaGOlxy5q1JOrJr028vEo\nc5BOEZz4a4Z5JEi56eSaUqVl3u/3M5Kfnp7OjBTpmfhFn79oEE6qtDUTW8+3trZQq9XQarXQbrcL\nr1zwgzvuOvnP4cRfc1iBN6kEGt1c0hqDwSAn2ZngMueAHMvnrtNvL1vUwipvzVJcZ9yx+l6v13Pd\ncPXY3d3FwcEB9vb2sLOzk20ATH6HE39tURSAk8qDl/N7KnJO+tTxmZ3P9XJNxd9rY5w+yy8CtsZb\n+fQcfmtduUsOX2W+s7ODvb097O3tod1uo9lsol6v50J1HU78tYZ1DrYi7PiMnUqeYQs+96nXc12W\nmjcNPsdrjWPRszqAGUs8X2u1WlaCm688F6Lr0Wg00G630W63sbOzg52dHTSbzVxLbsc55hI/hPAc\ngJcAPAPgDMDnYoz/NoRwD8CvA3gngFcAvD/G+Pga11oazIu8E8nObjUrcaYonp571qey6/Tn62Cb\nZWrlaZWe4+yl7r6W6Fq6y2tkQ9BzPUTVd+JfICyQBvksgGdjjN8MIewA+CMA7wPwUwDejDF+OoTw\ncQD3YoyfMN4fL3sGLDOKyl5J5J2lwnP2W8oiL9lwqcAdLdl1kg5b7TXxF90AxEDHKjzPm80mdnZ2\nclKbH7darZwWYGkFkqHHmXrcxKMsmFYpNne7ub+FGOOrAF6dzjshhJcBPIdz8v/I9GVfBHAEYIb4\njuVQ5I/nszz3o5cutCcnJzlLPT/u9XrJqDvOi08l8lhEv4zU1xKfC2hIW+t2u52d03d3d7P53t5e\njvjyHn7MGXl6uHHvApfa/kIIPwTgeQB/AOCZGOND4HxzCCG8/cpXV1JYlnwddpuKwBM3nRS70EPS\nZlNHBe5QswjBL6vNsXFPl8tqNBqZS253dxf7+/s4ODjIjXa7nb2eJTtLdMsV6H78PBYm/lTN/y0A\nPzuV/PovnvwPePDgQTY/PDzE4eHh5VZ5h7AoEVIEijHm0mT1XCzzVkitFLkQkp+cnJjEtwpkLhIz\nzxASaf94qrqNjEqlkjuT8xA1X5PdIn5qlDnf/ujoCEdHRwu9du4ZHwBCCBUA/xHAf4oxfmb63MsA\nDmOMD6d2gP8WY/xbxntLdcafJxFTEpQl+zw/PBvk9Oh2u2ZgDle8sc7uMhY9p6d88EVRdNLiyiK9\nDFbzeYjK32q1ZopucJWdMhNfY6Uz/hS/AuBPhPRTfAXAhwH8IoAPAfjyKovcNKTCVQE7XVaGlbOe\nyp7TG4Cc8bV/ngtksNFumS42RTXtU9VwNEkto5w8Fl88G/V0NJ5VcMPV+MthEav+ewD8PoBv4Vyd\njwA+CeDrAH4DwA8A+C7O3XnHxvtLKfG1NOfnrK4zMhfLvHbJsUqfSpm1/PV6cNqsDgZaJBAnlRSj\nz+vzRspHL+f8lDuvXq/P+P957ga8CxRJ/IVU/RVvXjrip3zwAGaCX/QYDodm0owMaUChq97wPDXE\nkJe6txj2irC1tWVKcpH0YpkvGpr81twK4hGrfarghlwd57gKVd9xSRT54q1ad3KVdFk5q4uBTq6S\nGmtVwbHq2uthBeBcNuS2qLQ1R89ZQ4iv3XE89Nmdr5xuy2R3df9ycOJfM4pIb/Wc43h6aTf16NGj\nbGji62AeVuVT90hpJYuCyc9lrfmMvru7aw5R14tIr8N5eW4V0HRX3eXhxL8mpMiVqkgjxjYJyOl2\nuzg5OcHx8THeeustvPnmm3jjjTfQ6XRmgm44EGeZNtHLkl7Xs5fzuSTK7O/vz4x2uz3jf+fHElpr\n+eAt1yGvy7E4nPgK80gwj1Sa3PrxvLr1nU4nayop4/j4ONdosqh89WX98JbULPLDc4acjpsXP7yQ\n/ODgIEf6g4MDtFqtmTBafjzPHecEvxo48RcAbwYcOVdUPy51LYqDH4/H6PV6OdWeu8oOBoNcY8rL\nWOMZVm95eWy56HiwO04nyLAfntV7jrHn2HytvjtuDk78AljqMAfYWAUndatoa4MoGhJrz0Ni7K3y\n1Zf1xet69FZtestab/nhi6LvOD2W/fCizvNZ3uvi3Tyc+AlYPngAGXlTLrV5deWLGlmMx+OsQo7O\nrGPi6wCcVazyenDCjBUPb7nY2P3GbjsrnVYCcLTxzkl/s3DiF8AKxhGJz1VqeQwGA7P6rH7MZOfn\nrBZVVp85fYS4bAuqVJRdKq+d1XrL957yyy/ih3d33M3DiW/Ayo5j4jNBxe3G+e58BLCumvB81amy\nRWmzVhWceSiyyks+PBeu5JBZ9sNbmkCqUy1frZr3TvqbhxM/gZQLTNel5zTYk5MTdLvdQvIW9ZWz\njgDaEGi565Y17mnVnt1xYpSzjHS8Ueir7lCbUustd52T/+bgxF8QOnuOA23E3358fJyVoE5F11lS\nX0fWWVF1Mpe18FXPi6AlPkfc6YKV+/v7OX+8FLC00mE5H569BVY1Xcv/7qS/WWwc8TUBrMeWJLck\nqFVJttfr5UJptb/dIj5vAPOk/WXO6hbm+eRFndeGN7lK+qsVfLO/v5+lxaZUejbU6TXwc47bxcYR\nHyju2qJ96zqcNWU4k3mv18tIzoQXVV9KW1nqvm41ZZWzuiw0keblw8+rSy/Vb3jw2Z5z31NhtLwu\nJ/p6YiOJD8zmvMtjOaOztNXzVJcYcbexf53np6enWeqrVaGW1XlN/mVgkUvXstOFKri8FV+1MU+P\nVD68lRPvpF9/bCTxUyo7W+S1tZwls+Vv1372VDcadudZfnpLy1imG01KnZbEmVQSjFjtiwjOTSr4\nGMDn+JQ7zuPo7wY2jvj6vM7k0oY5q4qNVsu15JYqN3pItRsdYGNF76Uq1y6K1NlZDGhitBPprknM\n5ar1VafNat98tVo1m1U66e8WNo74QLpTrHbFMWFlLtF3+nxulbLWXWrEeGfF6GuiWwbEy5LfMuAJ\n8UW6swpv1avXo9lsFrawYj+8FevvpL8b2EjiA7OVbkTqMvGtFtESGjtvpBpapJpOFNXZW9awB8xa\n7TlHXpJmrLx47auXq8TSW354XbraffF3FxtJ/FTRC12LnqvccIUbXdOOxzw/PJ/XU9ei+WVgpc1a\nEp9TZXXmnB71ej0pzd0PvznYOOJrwmuLPfeWY4kvrjkx0KWGzsDTkXVM4qIzr7aCL0ogTUhNzGaz\nORN4I3nx3D46JfVrtZp5jNBzx93GxhJfyKjP6Lr8tGWdTxWp1KmwOlZeUKQGW6SdF7Ou3XVFpak4\nCIfbT8lczvmcLSeJM26ZLw82jvjARSIN94qXsFmrN7yQnnvL6ZFyx+kz+jyCF5GWy05ZkM8typdv\nNBozRjyes5VfZ8x5AE55sHHEZ4nPraO1ip+S+L1eb24sfZG0BzBDcqvhRKrxBFeisUjH8fW6dJWE\n41r16GXo0tVa4vN9nfSbi40mvqj63IFGk15vAL1eL1kPj7PjrJTYGGNO4lsSXZej1uS1ilLwY+kh\nn+pEk/qZbiHNQ9bnan55sHHEB87Jr1V9LpoxT+JbMfta0qfccZr4uroNV6S1ilVw/3aLfNVq1Yyq\n42YVVvac9sXrdtKpVlRO+s3ExhG/SOJzwE7qjD8YDMwKudoXL/eyfPB8xuf010U6zcg5nz+Lr7Va\nzYy1lyAd8cOnSmtZxw+Zu4pfHmw08ZeR+IPBwPS/W75467El8fkML5JdV7qRM3itVst9lr7WarXC\nyDvxwxcNXqcH35QTG0d8RpErTReiEEKuCilBnRqp7Di5VqvVpGV9EeKzH97J7Uhh44hvha3q1NcQ\nwkwiy+7uLvb39zEcDle6v9w7NaxCljyq1WpunfparVZnylVzDL0T3bEINpL4OkNN1HEt6XXLp06n\ng9FotNL9OR9eG/b4vikrvBj3UlZ1bdxLBeC4Vd5RhLnEDyE8B+AlAM8AOAPwyzHGz4YQXgDwEQCv\nTV/6yRjj717bShcES/x6vZ6ztvOZm0nP/eevivhWoUntzmMLvMylhVTKui5x+FbJaivyzg12DguL\nSPwxgJ+LMX4zhLAD4I9CCL83/dmLMcYXr295lwcTXAxlrNoLYQaDAXZ2dmb6yy/SI36R+xd1qkkF\n70gAT1HYrtgQdGUdjvpzsjvmYS7xY4yvAnh1Ou+EEF4G8I7pj9fuP4slPnCh3tdqNYxGo0zSp2re\nXwXxiyzqKU1ArkU95PgYs2j5Kz13OAAgXLL4ww8BOALwtwF8DMCHATwG8IcAPhZjfGy8Jy6bdroM\nUsU0dXpuUcPLVWAZ11gFL9oQFmksoV/Ppa901N+8hB/HZiOEgBij+QdfmPhTNf8IwL+JMX45hPA0\ngDdijDGE8PMA/lqM8aeN990o8ecVu0iNZereFaFI8loq+aIWeOs9y7zfsfkoIv5CVv0QQgXAbwH4\n9zHGLwNAjPF1esnnAPxO6v0PHjzI5oeHhzg8PFzktkvBXViOsuLo6AhHR0cLvXYhiR9CeAnn0v3n\n6Llnp+d/hBD+JYAfjjF+0HjvjUp8h8NxjpVU/RDCewD8PoBvAYjT8UkAHwTwPM5dfK8A+JkY40Pj\n/U58h+MWcCVn/BVu7sR3OG4BRcRP+44cDsfGwonvcJQQTnyHo4Rw4jscJYQT3+EoIZz4DkcJ4cR3\nOEoIJ77DUUI48R2OEsKJ73CUEE58h6OEcOI7HCXEjRN/0Xzh24KvbzWs8/rWeW3Aza7Pia/g61sN\n67y+dV4bsOHEdzgctw8nvsNRQtxIIY5rvYHD4Uji1irwOByO9YOr+g5HCeHEdzhKiBsjfgjhvSGE\nb4cQvhNC+PhN3XdRhBBeCSH8rxDC/wwhfH0N1vP5EMLDEML/pufuhRC+GkL40xDCfw4h7K/Z+l4I\nIXwvhPA/puO9t7i+50II/zWE8H9CCN8KIfyL6fNr8Ts01vfPp8/fyO/wRs74IYQtAN8B8GMA/grA\nNwB8IMb47Wu/+YIIIfxfAH83xvjottcCACGEvwegA+ClGOO7ps/9IoA3Y4yfnm6e92KMn1ij9b0A\n4HQdGqmGEJ4F8Cw3ewXwPgA/hTX4HRas7x/hBn6HNyXx3w3gz2KM340xjgD8Gs6/5DohYI2OPjHG\nrwHQm9D7AHxxOv8igH94o4siJNYHrEkj1RjjqzHGb07nHQAvA3gOa/I7TKzvxprR3tQ/+jsA/AU9\n/h4uvuS6IAL4vRDCN0IIH7ntxSTwdmlaMu1i9PZbXo+Fj4YQvhlC+He3eRRhTJu9Pg/gDwA8s26/\nQ1rff58+de2/w7WRcGuA98QY/w6AfwDgn01V2XXHuvlifwnA34gxPo/z1urroPLv4Lzv489OJav+\nnd3q79BY3438Dm+K+H8J4Afp8XPT59YGMcbvT6+vA/htnB9P1g0PQwjPANkZ8bVbXk8OMcbXqW3S\n5wD88G2ux2r2ijX6Haaa0d7E7/CmiP8NAH8zhPDOEEINwAcAfOWG7j0XIYTWdOdFCKEN4McB/PHt\nrgrA+VmPz3tfAfDh6fxDAL6s33DDyK1vSiTBT+L2f4e/AuBPYoyfoefW6Xc4s76b+h3eWOTe1C3x\nGZxvNp+PMf7Cjdx4AYQQ/jrOpXzEeevwX73t9YUQvgTgEMB9AA8BvADgPwD4TQA/AOC7AN4fYzxe\no/X9KBZopHpD60s1e/06gN/ALf8OV21Gu/L9PWTX4Sgf3LjncJQQTnyHo4Rw4jscJYQT3+EoIZz4\nDkcJ4cR3OEoIJ77DUUI48R2OEuL/AytVNTNhmicLAAAAAElFTkSuQmCC\n",
"text/plain": [
"<matplotlib.figure.Figure at 0x10ce3f190>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"print(len(mnist_train_data))\n",
"\n",
"input_dim = mnist_train_data[0][0].view(1, -1).size()[1]\n",
"print(input_dim)\n",
"\n",
"hidden_dim = 100\n",
"output_dim = 10\n",
"\n",
"plt.imshow(mnist_train_data[0][0][0].numpy(), cmap = cm.Greys)\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 124,
"metadata": {},
"outputs": [],
"source": [
"class Model(nn.Module):\n",
" def __init__(self):\n",
" super(Model, self).__init__()\n",
" self.layer_hidden = nn.Linear(in_features=input_dim, out_features=hidden_dim, bias=True)\n",
" self.layer_hidden_sigmoid = nn.Sigmoid()\n",
" self.layer_out = nn.Linear(in_features=hidden_dim, out_features=output_dim, bias=True)\n",
" self.loss = nn.CrossEntropyLoss()\n",
" \n",
" def forward(self, x, y):\n",
" x = self.layer_hidden(x)\n",
" x = self.layer_hidden_sigmoid(x)\n",
" x = self.layer_out(x)\n",
" loss = self.loss(x, y)\n",
" accuracy = 100 * torch.mean((torch.max(x, 1)[1] == y.long()).float())\n",
" return loss, accuracy.detach()"
]
},
{
"cell_type": "code",
"execution_count": 125,
"metadata": {},
"outputs": [],
"source": [
"model = Model()\n",
"optimizer = optim.SGD(model.parameters(), lr=0.005)\n",
"\n",
"def train():\n",
" model.train()\n",
" avg_loss = 0\n",
" avg_accuracy = 0\n",
" for batch_idx, (data, target) in enumerate(train_loader):\n",
" data = data.view(data.size()[0], 1, -1).squeeze(1)\n",
" data, target = Variable(data), Variable(target)\n",
" optimizer.zero_grad()\n",
" loss, accuracy = model(data, target)\n",
" avg_loss += torch.sum(loss)\n",
" avg_accuracy += accuracy\n",
" loss.backward()\n",
" optimizer.step()\n",
" return avg_loss / len(train_loader), avg_accuracy / len(train_loader)\n",
" \n",
"def test():\n",
" model.eval()\n",
" avg_loss = 0\n",
" avg_accuracy = 0\n",
" for batch_idx, (data, target) in enumerate(test_loader):\n",
" data = data.view(data.size()[0], 1, -1).squeeze(1)\n",
" data, target = Variable(data), Variable(target)\n",
" loss, accuracy = model(data, target)\n",
" avg_loss += torch.sum(loss)\n",
" avg_accuracy += accuracy\n",
" return avg_loss / len(test_loader), avg_accuracy / len(test_loader)"
]
},
{
"cell_type": "code",
"execution_count": 126,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Training Loss: 2.266255 Accuracy: 24.325001\n",
"Testing Loss: 2.218460 Accuracy: 40.689999\n",
"Training Loss: 2.171618 Accuracy: 49.009998\n",
"Testing Loss: 2.111272 Accuracy: 53.200001\n",
"Training Loss: 2.046243 Accuracy: 57.998333\n",
"Testing Loss: 1.960946 Accuracy: 61.880001\n",
"Training Loss: 1.876979 Accuracy: 63.571667\n",
"Testing Loss: 1.769310 Accuracy: 67.160004\n",
"Training Loss: 1.678596 Accuracy: 67.385002\n",
"Testing Loss: 1.563231 Accuracy: 70.250000\n",
"Training Loss: 1.480505 Accuracy: 71.036667\n",
"Testing Loss: 1.372474 Accuracy: 73.860001\n",
"Training Loss: 1.305467 Accuracy: 74.333336\n",
"Testing Loss: 1.211227 Accuracy: 76.370003\n",
"Training Loss: 1.160578 Accuracy: 76.916664\n",
"Testing Loss: 1.080341 Accuracy: 78.589996\n",
"Training Loss: 1.043451 Accuracy: 78.908333\n",
"Testing Loss: 0.975428 Accuracy: 80.400002\n",
"Training Loss: 0.949120 Accuracy: 80.529999\n",
"Testing Loss: 0.890585 Accuracy: 81.489998\n",
"Training Loss: 0.872441 Accuracy: 81.699997\n",
"Testing Loss: 0.821108 Accuracy: 82.669998\n",
"Training Loss: 0.809431 Accuracy: 82.769997\n",
"Testing Loss: 0.763521 Accuracy: 83.699997\n",
"Training Loss: 0.757025 Accuracy: 83.471664\n",
"Testing Loss: 0.715862 Accuracy: 84.449997\n",
"Training Loss: 0.713034 Accuracy: 84.088333\n",
"Testing Loss: 0.675393 Accuracy: 85.250000\n",
"Training Loss: 0.675661 Accuracy: 84.680000\n",
"Testing Loss: 0.640902 Accuracy: 85.660004\n",
"Training Loss: 0.643608 Accuracy: 85.198334\n",
"Testing Loss: 0.611104 Accuracy: 86.089996\n",
"Training Loss: 0.615867 Accuracy: 85.608330\n",
"Testing Loss: 0.585131 Accuracy: 86.559998\n",
"Training Loss: 0.591616 Accuracy: 86.035004\n",
"Testing Loss: 0.562528 Accuracy: 86.750000\n",
"Training Loss: 0.570268 Accuracy: 86.415001\n",
"Testing Loss: 0.542781 Accuracy: 87.029999\n",
"Training Loss: 0.551381 Accuracy: 86.706665\n",
"Testing Loss: 0.524989 Accuracy: 87.430000\n",
"Training Loss: 0.534503 Accuracy: 86.968330\n",
"Testing Loss: 0.508842 Accuracy: 87.790001\n",
"Training Loss: 0.519369 Accuracy: 87.263336\n",
"Testing Loss: 0.494632 Accuracy: 88.040001\n",
"Training Loss: 0.505707 Accuracy: 87.471664\n",
"Testing Loss: 0.481940 Accuracy: 88.220001\n",
"Training Loss: 0.493364 Accuracy: 87.699997\n",
"Testing Loss: 0.470187 Accuracy: 88.400002\n",
"Training Loss: 0.482095 Accuracy: 87.856667\n",
"Testing Loss: 0.459625 Accuracy: 88.559998\n",
"Training Loss: 0.471838 Accuracy: 88.035004\n",
"Testing Loss: 0.449723 Accuracy: 88.709999\n",
"Training Loss: 0.462415 Accuracy: 88.209999\n",
"Testing Loss: 0.440977 Accuracy: 88.900002\n",
"Training Loss: 0.453709 Accuracy: 88.360001\n",
"Testing Loss: 0.432660 Accuracy: 89.029999\n",
"Training Loss: 0.445762 Accuracy: 88.489998\n",
"Testing Loss: 0.425294 Accuracy: 89.040001\n",
"Training Loss: 0.438364 Accuracy: 88.626663\n",
"Testing Loss: 0.418159 Accuracy: 89.139999\n",
"Training Loss: 0.431497 Accuracy: 88.713333\n",
"Testing Loss: 0.411646 Accuracy: 89.250000\n",
"Training Loss: 0.425114 Accuracy: 88.820000\n",
"Testing Loss: 0.405570 Accuracy: 89.309998\n",
"Training Loss: 0.419153 Accuracy: 88.963333\n",
"Testing Loss: 0.400075 Accuracy: 89.419998\n",
"Training Loss: 0.413556 Accuracy: 89.043335\n",
"Testing Loss: 0.394809 Accuracy: 89.580002\n",
"Training Loss: 0.408312 Accuracy: 89.131668\n",
"Testing Loss: 0.389739 Accuracy: 89.680000\n",
"Training Loss: 0.403451 Accuracy: 89.238335\n",
"Testing Loss: 0.385035 Accuracy: 89.739998\n",
"Training Loss: 0.398792 Accuracy: 89.311668\n",
"Testing Loss: 0.381032 Accuracy: 89.739998\n",
"Training Loss: 0.394433 Accuracy: 89.371666\n",
"Testing Loss: 0.376821 Accuracy: 89.809998\n",
"Training Loss: 0.390297 Accuracy: 89.461670\n",
"Testing Loss: 0.372925 Accuracy: 89.940002\n",
"Training Loss: 0.386393 Accuracy: 89.536667\n",
"Testing Loss: 0.369306 Accuracy: 90.000000\n",
"Training Loss: 0.382667 Accuracy: 89.620003\n",
"Testing Loss: 0.365625 Accuracy: 90.019997\n",
"Training Loss: 0.379072 Accuracy: 89.683334\n",
"Testing Loss: 0.362355 Accuracy: 90.080002\n",
"Training Loss: 0.375780 Accuracy: 89.723335\n",
"Testing Loss: 0.359415 Accuracy: 90.199997\n",
"Training Loss: 0.372569 Accuracy: 89.790001\n",
"Testing Loss: 0.356148 Accuracy: 90.160004\n",
"Training Loss: 0.369505 Accuracy: 89.878334\n",
"Testing Loss: 0.353425 Accuracy: 90.230003\n",
"Training Loss: 0.366551 Accuracy: 89.928337\n",
"Testing Loss: 0.350605 Accuracy: 90.300003\n",
"Training Loss: 0.363740 Accuracy: 89.986664\n",
"Testing Loss: 0.348175 Accuracy: 90.279999\n",
"Training Loss: 0.361077 Accuracy: 90.038330\n",
"Testing Loss: 0.345609 Accuracy: 90.339996\n",
"Training Loss: 0.358483 Accuracy: 90.074997\n",
"Testing Loss: 0.343176 Accuracy: 90.370003\n",
"Training Loss: 0.355956 Accuracy: 90.133331\n",
"Testing Loss: 0.340698 Accuracy: 90.449997\n",
"Training Loss: 0.353595 Accuracy: 90.154999\n",
"Testing Loss: 0.338623 Accuracy: 90.419998\n",
"Training Loss: 0.351268 Accuracy: 90.224998\n",
"Testing Loss: 0.336470 Accuracy: 90.470001\n",
"Training Loss: 0.349028 Accuracy: 90.253334\n",
"Testing Loss: 0.334464 Accuracy: 90.489998\n",
"Training Loss: 0.346867 Accuracy: 90.293335\n",
"Testing Loss: 0.332442 Accuracy: 90.510002\n",
"Training Loss: 0.344768 Accuracy: 90.348335\n",
"Testing Loss: 0.330666 Accuracy: 90.639999\n",
"Training Loss: 0.342760 Accuracy: 90.391670\n",
"Testing Loss: 0.328654 Accuracy: 90.680000\n",
"Training Loss: 0.340802 Accuracy: 90.456665\n",
"Testing Loss: 0.326742 Accuracy: 90.709999\n",
"Training Loss: 0.338914 Accuracy: 90.501663\n",
"Testing Loss: 0.325083 Accuracy: 90.720001\n",
"Training Loss: 0.337050 Accuracy: 90.535004\n",
"Testing Loss: 0.323265 Accuracy: 90.760002\n",
"Training Loss: 0.335297 Accuracy: 90.583336\n",
"Testing Loss: 0.321713 Accuracy: 90.739998\n"
]
},
{
"ename": "KeyboardInterrupt",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-126-d39865f22339>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnum_epochs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0m_loss\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_accuracy\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 3\u001b[0m \u001b[0mloss\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maccuracy\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtest\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;32mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Training Loss: %f Accuracy: %f\"\u001b[0m \u001b[0;34m%\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0m_loss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_accuracy\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0;32mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Testing Loss: %f Accuracy: %f\"\u001b[0m \u001b[0;34m%\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mloss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maccuracy\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m<ipython-input-125-da10126ff275>\u001b[0m in \u001b[0;36mtrain\u001b[0;34m()\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0mavg_loss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[0mavg_accuracy\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 8\u001b[0;31m \u001b[0;32mfor\u001b[0m \u001b[0mbatch_idx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtarget\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32min\u001b[0m \u001b[0menumerate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain_loader\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 9\u001b[0m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mview\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msqueeze\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtarget\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mVariable\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mVariable\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtarget\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/Users/chirag/venv/lib/python2.7/site-packages/torch/utils/data/dataloader.pyc\u001b[0m in \u001b[0;36m__next__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 150\u001b[0m \u001b[0;32mraise\u001b[0m \u001b[0mStopIteration\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 151\u001b[0m \u001b[0mindices\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_next_indices\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 152\u001b[0;31m \u001b[0mbatch\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcollate_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mindices\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 153\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpin_memory\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 154\u001b[0m \u001b[0mbatch\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpin_memory_batch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/Users/chirag/venv/lib/python2.7/site-packages/torchvision/datasets/mnist.pyc\u001b[0m in \u001b[0;36m__getitem__\u001b[0;34m(self, index)\u001b[0m\n\u001b[1;32m 50\u001b[0m \u001b[0;31m# doing this so that it is consistent with all other datasets\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 51\u001b[0m \u001b[0;31m# to return a PIL Image\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 52\u001b[0;31m \u001b[0mimg\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mImage\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfromarray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mimg\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnumpy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmode\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'L'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 53\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 54\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtransform\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/Users/chirag/venv/lib/python2.7/site-packages/PIL/Image.pyc\u001b[0m in \u001b[0;36mfromarray\u001b[0;34m(obj, mode)\u001b[0m\n\u001b[1;32m 2215\u001b[0m \u001b[0mshare\u001b[0m \u001b[0mmemory\u001b[0m\u001b[0;34m;\u001b[0m \u001b[0msupported\u001b[0m \u001b[0mmodes\u001b[0m \u001b[0minclude\u001b[0m \u001b[0;34m\"L\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"RGBX\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"RGBA\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0;34m\"CMYK\"\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2216\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2217\u001b[0;31m \u001b[0mNote\u001b[0m \u001b[0mthat\u001b[0m \u001b[0mthis\u001b[0m \u001b[0mfunction\u001b[0m \u001b[0mdecodes\u001b[0m \u001b[0mpixel\u001b[0m \u001b[0mdata\u001b[0m \u001b[0monly\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mentire\u001b[0m \u001b[0mimages\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2218\u001b[0m \u001b[0mIf\u001b[0m \u001b[0myou\u001b[0m \u001b[0mhave\u001b[0m \u001b[0man\u001b[0m \u001b[0mentire\u001b[0m \u001b[0mimage\u001b[0m \u001b[0mfile\u001b[0m \u001b[0;32min\u001b[0m \u001b[0ma\u001b[0m \u001b[0mstring\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mwrap\u001b[0m \u001b[0mit\u001b[0m \u001b[0;32min\u001b[0m \u001b[0ma\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2219\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mBytesIO\u001b[0m\u001b[0;34m**\u001b[0m \u001b[0mobject\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0muse\u001b[0m \u001b[0;34m:\u001b[0m\u001b[0mpy\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mfunc\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m`\u001b[0m\u001b[0;34m~\u001b[0m\u001b[0mPIL\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mImage\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mopen\u001b[0m\u001b[0;34m`\u001b[0m \u001b[0mto\u001b[0m \u001b[0mload\u001b[0m \u001b[0mit\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/Users/chirag/venv/lib/python2.7/site-packages/PIL/Image.pyc\u001b[0m in \u001b[0;36mfrombuffer\u001b[0;34m(mode, size, data, decoder_name, *args)\u001b[0m\n\u001b[1;32m 2160\u001b[0m \u001b[0mcolor\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mImageColor\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgetcolor\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcolor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmode\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2161\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2162\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mImage\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_new\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcore\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfill\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmode\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msize\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcolor\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2163\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2164\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
]
}
],
"source": [
"for i in range(num_epochs):\n",
" _loss, _accuracy = train()\n",
" loss, accuracy = test()\n",
" print(\"Training Loss: %f Accuracy: %f\" % (_loss.data[0], _accuracy.data[0]))\n",
" print(\"Testing Loss: %f Accuracy: %f\" % (loss.data[0], accuracy.data[0]))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 2",
"language": "python",
"name": "python2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.10"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment