@Pabla007
Created December 14, 2019 03:10
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.4"
},
"colab": {
"name": "Part 2 - Neural Networks in PyTorch (Exercises).ipynb",
"provenance": []
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "bdWgMNvJnqTF",
"colab_type": "text"
},
"source": [
"# Neural networks with PyTorch\n",
"\n",
"Deep learning networks tend to be massive with dozens or hundreds of layers, that's where the term \"deep\" comes from. You can build one of these deep networks using only weight matrices as we did in the previous notebook, but in general it's very cumbersome and difficult to implement. PyTorch has a nice module `nn` that provides a nice way to efficiently build large neural networks."
]
},
{
"cell_type": "code",
"metadata": {
"id": "SKYVvUltnqTH",
"colab_type": "code",
"outputId": "bf8e7034-b3a2-4010-f571-2f4c447d34db",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 202
}
},
"source": [
"# Import necessary packages\n",
"!wget -c https://raw.githubusercontent.com/udacity/deep-learning-v2-pytorch/master/intro-to-pytorch/helper.py\n",
"%matplotlib inline\n",
"%config InlineBackend.figure_format = 'retina'\n",
"\n",
"import numpy as np\n",
"import torch\n",
"\n",
"import helper\n",
"\n",
"import matplotlib.pyplot as plt"
],
"execution_count": 1,
"outputs": [
{
"output_type": "stream",
"text": [
"--2019-12-14 03:04:59-- https://raw.githubusercontent.com/udacity/deep-learning-v2-pytorch/master/intro-to-pytorch/helper.py\n",
"Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 151.101.0.133, 151.101.64.133, 151.101.128.133, ...\n",
"Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|151.101.0.133|:443... connected.\n",
"HTTP request sent, awaiting response... 200 OK\n",
"Length: 2813 (2.7K) [text/plain]\n",
"Saving to: ‘helper.py’\n",
"\n",
"\rhelper.py 0%[ ] 0 --.-KB/s \rhelper.py 100%[===================>] 2.75K --.-KB/s in 0s \n",
"\n",
"2019-12-14 03:04:59 (86.8 MB/s) - ‘helper.py’ saved [2813/2813]\n",
"\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "k5_utV0inqTM",
"colab_type": "text"
},
"source": [
"\n",
"Now we're going to build a larger network that can solve a (formerly) difficult problem, identifying text in an image. Here we'll use the MNIST dataset which consists of greyscale handwritten digits. Each image is 28x28 pixels, you can see a sample below\n",
"\n",
"<img src='assets/mnist.png'>\n",
"\n",
"Our goal is to build a neural network that can take one of these images and predict the digit in the image.\n",
"\n",
"First up, we need to get our dataset. This is provided through the `torchvision` package. The code below will download the MNIST dataset, then create training and test datasets for us. Don't worry too much about the details here, you'll learn more about this later."
]
},
{
"cell_type": "code",
"metadata": {
"id": "u0uPFAVAnqTO",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 252
},
"outputId": "33ad9207-205a-4b1c-fa98-609ab48ea19e"
},
"source": [
"### Run this cell\n",
"\n",
"from torchvision import datasets, transforms\n",
"\n",
"# Define a transform to normalize the data\n",
"transform = transforms.Compose([transforms.ToTensor(),\n",
" transforms.Normalize((0.5,), (0.5,)),\n",
" ])\n",
"\n",
"# Download and load the training data\n",
"trainset = datasets.MNIST('~/.pytorch/MNIST_data/', download=True, train=True, transform=transform)\n",
"trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)"
],
"execution_count": 2,
"outputs": [
{
"output_type": "stream",
"text": [
"\r0it [00:00, ?it/s]"
],
"name": "stderr"
},
{
"output_type": "stream",
"text": [
"Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to /root/.pytorch/MNIST_data/MNIST/raw/train-images-idx3-ubyte.gz\n"
],
"name": "stdout"
},
{
"output_type": "stream",
"text": [
"9920512it [00:02, 3402398.54it/s] \n"
],
"name": "stderr"
},
{
"output_type": "stream",
"text": [
"Extracting /root/.pytorch/MNIST_data/MNIST/raw/train-images-idx3-ubyte.gz to /root/.pytorch/MNIST_data/MNIST/raw\n"
],
"name": "stdout"
},
{
"output_type": "stream",
"text": [
"\r0it [00:00, ?it/s]"
],
"name": "stderr"
},
{
"output_type": "stream",
"text": [
"Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to /root/.pytorch/MNIST_data/MNIST/raw/train-labels-idx1-ubyte.gz\n"
],
"name": "stdout"
},
{
"output_type": "stream",
"text": [
"32768it [00:00, 48909.58it/s] \n",
"0it [00:00, ?it/s]"
],
"name": "stderr"
},
{
"output_type": "stream",
"text": [
"Extracting /root/.pytorch/MNIST_data/MNIST/raw/train-labels-idx1-ubyte.gz to /root/.pytorch/MNIST_data/MNIST/raw\n",
"Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to /root/.pytorch/MNIST_data/MNIST/raw/t10k-images-idx3-ubyte.gz\n"
],
"name": "stdout"
},
{
"output_type": "stream",
"text": [
"1654784it [00:02, 814901.73it/s] \n",
"0it [00:00, ?it/s]"
],
"name": "stderr"
},
{
"output_type": "stream",
"text": [
"Extracting /root/.pytorch/MNIST_data/MNIST/raw/t10k-images-idx3-ubyte.gz to /root/.pytorch/MNIST_data/MNIST/raw\n",
"Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to /root/.pytorch/MNIST_data/MNIST/raw/t10k-labels-idx1-ubyte.gz\n"
],
"name": "stdout"
},
{
"output_type": "stream",
"text": [
"8192it [00:00, 18497.44it/s] "
],
"name": "stderr"
},
{
"output_type": "stream",
"text": [
"Extracting /root/.pytorch/MNIST_data/MNIST/raw/t10k-labels-idx1-ubyte.gz to /root/.pytorch/MNIST_data/MNIST/raw\n",
"Processing...\n",
"Done!\n"
],
"name": "stdout"
},
{
"output_type": "stream",
"text": [
"\n"
],
"name": "stderr"
}
]
},
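{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"As a rough sketch of what the `Normalize((0.5,), (0.5,))` step in the transform does, the snippet below applies the same arithmetic (subtract a mean of 0.5, divide by a standard deviation of 0.5) to a few hand-picked pixel values. The `pixels` tensor is made up for illustration, and it assumes `ToTensor()` has already scaled the image to the range [0, 1].\n",
"\n",
"```python\n",
"import torch\n",
"\n",
"# Hypothetical pixel intensities, as ToTensor() would produce them (range [0, 1])\n",
"pixels = torch.tensor([0.0, 0.25, 0.5, 1.0])\n",
"\n",
"# Same arithmetic as Normalize((0.5,), (0.5,)): (value - mean) / std\n",
"mean, std = 0.5, 0.5\n",
"normalized = (pixels - mean) / std\n",
"\n",
"print(normalized)\n",
"```\n",
"\n",
"With these settings, values in [0, 1] end up in [-1, 1], centered on zero."
]
},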
{
"cell_type": "markdown",
"metadata": {
"id": "N02rLy7YnqTR",
"colab_type": "text"
},
"source": [
"We have the training data loaded into `trainloader` and we make that an iterator with `iter(trainloader)`. Later, we'll use this to loop through the dataset for training, like\n",
"\n",
"```python\n",
"for image, label in trainloader:\n",
" ## do things with images and labels\n",
"```\n",
"\n",
"You'll notice I created the `trainloader` with a batch size of 64, and `shuffle=True`. The batch size is the number of images we get in one iteration from the data loader and pass through our network, often called a *batch*. And `shuffle=True` tells it to shuffle the dataset every time we start going through the data loader again. But here I'm just grabbing the first batch so we can check out the data. We can see below that `images` is just a tensor with size `(64, 1, 28, 28)`. So, 64 images per batch, 1 color channel, and 28x28 images."
]
},
{
"cell_type": "code",
"metadata": {
"id": "fNOQow70nqTS",
"colab_type": "code",
"outputId": "9242b68b-0fef-4392-9c0d-b67e100a937e",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 67
}
},
"source": [
"dataiter = iter(trainloader)\n",
"images, labels = dataiter.next()\n",
"print(type(images))\n",
"print(images.shape)\n",
"print(labels.shape)"
],
"execution_count": 3,
"outputs": [
{
"output_type": "stream",
"text": [
"<class 'torch.Tensor'>\n",
"torch.Size([64, 1, 28, 28])\n",
"torch.Size([64])\n"
],
"name": "stdout"
}
]
},
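{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"To see how the batch size determines what the loader hands back, here is a minimal sketch that uses random tensors in place of MNIST so it runs without a download. The names `fake_images` and `fake_labels` and the dataset size of 200 are assumptions chosen only for illustration.\n",
"\n",
"```python\n",
"import torch\n",
"from torch.utils.data import DataLoader, TensorDataset\n",
"\n",
"# Synthetic stand-in for MNIST: 200 random 1x28x28 images with labels 0-9\n",
"fake_images = torch.randn(200, 1, 28, 28)\n",
"fake_labels = torch.randint(0, 10, (200,))\n",
"loader = DataLoader(TensorDataset(fake_images, fake_labels), batch_size=64, shuffle=True)\n",
"\n",
"# Record the shape of every batch produced in one pass over the data\n",
"batch_shapes = [tuple(images.shape) for images, labels in loader]\n",
"\n",
"print(len(loader))\n",
"print(batch_shapes[0], batch_shapes[-1])\n",
"```\n",
"\n",
"Because 200 is not a multiple of 64, the last batch is smaller than the rest. Passing `drop_last=True` to `DataLoader` would discard it instead."
]
},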
{
"cell_type": "markdown",
"metadata": {
"id": "7nbdZqLKnqTU",
"colab_type": "text"
},
"source": [
"This is what one of the images looks like. "
]
},
{
"cell_type": "code",
"metadata": {
"id": "GXBuznALnqTV",
"colab_type": "code",
"outputId": "059064aa-10af-41be-a255-bb02c9980f8b",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 265
}
},
"source": [
"plt.imshow(images[1].numpy().squeeze(), cmap='Greys_r');"
],
"execution_count": 4,
"outputs": [
{
"output_type": "display_data",
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAfcAAAHwCAYAAAC7cCafAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAAWJQAAFiUBSVIk8AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjIsIGh0\ndHA6Ly9tYXRwbG90bGliLm9yZy8li6FKAAAcR0lEQVR4nO3dfaxtdXkn8O8jKAxUQamVGKa5iCi2\ntjhcS2+hIi/xbZpaLDAaU0obaQq0o1iddFK1Q21NTGN8KfjS1FQSSAYMRJtOUZzIVbDYNr03liFV\nXkREo4K8i4gC9zd/7HXa2+s593L23vesc37780l21tlrref8HhYLvmftvV6qtRYAoB9PGrsBAGC+\nhDsAdEa4A0BnhDsAdEa4A0BnhDsAdEa4A0BnhDsAdEa4A0BnhDsAdEa4A0BnhDsAdGbfsRvYG6rq\na0meluT2kVsBgGltSvJga+3w1RZ2Ge6ZBPszhhcALJReP5a/fewGAGAObp+maNRwr6rDquqvq+pb\nVfXDqrq9qt5fVU8fsy8A2MhG+1i+qo5Icn2Sn0ryN0m+kuTYJG9K8sqqOr61ds9Y/QHARjXmkfuH\nMgn2N7bWTm2t/c/W2slJ3pfk+UneNWJvALBhVWtt7QedHLXfmsl3CUe01nbstOypSb6dpJL8VGvt\n+1P8/m1JjplPtwAwmu2ttc2rLRrrY/mThulndg72JGmtfa+q/j7Jy5NsSfLZlX7JEOLLOWouXQLA\nBjTWx/LPH6Y3r7D8lmH6vDXoBQC6MtaR+0HD9IEVli/NP3h3v2Sljyp8LA/AIuv1OncAWFhjhfvS\nkflBKyxfmn//GvQCAF0ZK9xvGqYrfad+5DBd6Tt5AGAFY4X71mH68qr6Dz0Ml8Idn+ThJP+w1o0B\nwEY3Sri31r6a5DOZPPHm93ZZ/CdJDkxyyTTXuAPAohvzqXDnZXL72b+oqlOSfDnJL2ZyDfzNSd42\nYm8AsGGNdrb8cPT+4iQXZxLqb0lyRJIPJNnivvIAMJ1Rn+feWvtGkt8eswcA6I3r3AGgM8IdADoj\n3AGgM8IdADoj3AGgM8IdADoj3AGgM8IdADoj3AGgM8IdADoj3AGgM8IdADoj3AGgM8IdADoj3AGg\nM8IdADoj3AGgM8IdADoj3AGgM8IdADoj3AGgM8IdADoj3AGgM8IdADoj3AGgM8IdADoj3AGgM8Id\nADoj3AGgM8IdADoj3AGgM8IdADoj3AGgM8IdADoj3AGgM8IdADoj3AGgM8IdADoj3AGgM8IdADoj\n3AGgM8IdADoj3AGgM8IdADoj3AGgM8IdADoj3AGgM8IdADoj3AGgM8IdADoj3AGgM8IdADoj3AGg\nM8IdADoj3AGgM8IdADoj3AGgM8IdADoj3AGgM8IdADoj3AGgM8IdADozWrhX1e1V1VZ4fWesvgBg\no9t35PEfSPL+ZeY/tNaNAEAvxg73+1trF4zcAwB0xXfuANCZsY/c96uq30jy00m+n+SGJNe21h4f\nty0A2LjGDvdDk1yyy7yvVdVvt9Y+v6fiqtq2wqKjZu4MADaoMT+W/1iSUzIJ+AOT/FySv0yyKcmn\nquro8VoDgI2rWmtj9/AfVNV7krwlySdba6+Z8ndsS3LMXBsDgLW3vbW2ebVF6/GEuo8M0xNG7QIA\nNqj1GO7fHaYHjtoFAGxQ6zHctwzT20btAgA2qFHCvapeUFU/dmReVZuSXDS8vXQtewKAXox1Kdxr\nk7ylqq5N8vUk30tyRJJfSbJ/kquSvGek3gBgQxsr3LcmeX6S/5Lk+Ey+X78/yRcyue79krbeTuMH\ngA1ilHAfblCzx5vUAACrtx5PqAMAZiDcAaAzwh0AOiPcAaAzwh0AOiPcAaAzwh0AOiPcAaAzwh0A\nOiPcAaAzwh0AOiPcAaAzwh0AOiPcAaAzwh0A
OjPK89yB9eEnfuInpq795V/+5ZnGftOb3jRT/Ykn\nnjh17X777TfT2FU1de3NN98809gvfOELp6599NFHZxqbjcOROwB0RrgDQGeEOwB0RrgDQGeEOwB0\nRrgDQGeEOwB0RrgDQGeEOwB0RrgDQGeEOwB0RrgDQGeEOwB0RrgDQGc88hU2sC1btsxUf9FFF01d\ne8wxx8w09kbWWpu69sgjj5xp7Gc84xlT1955550zjc3G4cgdADoj3AGgM8IdADoj3AGgM8IdADoj\n3AGgM8IdADoj3AGgM8IdADoj3AGgM8IdADoj3AGgM8IdADoj3AGgM8IdADrjee4wsj//8z+fuva8\n886baewDDjhg6tpZnmmeJDfddNNM9du3b5+69vWvf/1MY89i27ZtM9U/8MADc+qEnjlyB4DOCHcA\n6IxwB4DOCHcA6IxwB4DOCHcA6IxwB4DOCHcA6IxwB4DOCHcA6IxwB4DOCHcA6IxwB4DOCHcA6IxH\nvkKSZz/72VPXbt26daaxn/vc505dW1UzjT3L40NPO+20mca+7rrrZqr/6le/OlP9LK688sqpa3/r\nt35rprEfeeSRmepZDI7cAaAzcwn3qjq9qi6squuq6sGqalV16R5qjquqq6rq3qr6QVXdUFXnV9U+\n8+gJABbVvD6Wf3uSo5M8lOSbSY7a3cpV9WtJrkzySJLLk9yb5FeTvC/J8UnOmFNfALBw5vWx/JuT\nPC/J05Kcu7sVq+ppSf4qyeNJTmytvaG19j+SvCjJF5OcXlWvm1NfALBw5hLurbWtrbVbWmvtCax+\nepJnJrmstfbPO/2ORzL5BCDZwx8IAMDKxjih7uRh+ullll2b5OEkx1XVfmvXEgD0Y4xL4Z4/TG/e\ndUFr7bGq+lqSn03ynCRf3t0vqqptKyza7Xf+ANCzMY7cDxqmK11guzT/4DXoBQC6s6FvYtNa27zc\n/OGI/pg1bgcA1oUxjtyXjswPWmH50vz716AXAOjOGOF+0zB93q4LqmrfJIcneSzJbWvZFAD0Yoxw\nv2aYvnKZZSckOSDJ9a21H65dSwDQjzHC/Yokdyd5XVW9eGlmVe2f5M+Gtx8eoS8A6MJcTqirqlOT\nnDq8PXSY/lJVXTz8fHdr7a1J0lp7sKp+J5OQ/1xVXZbJ7WdfncllcldkcktaAGAK8zpb/kVJztpl\n3nOGV5J8Pclblxa01j5ZVS9N8rYkpyXZP8mtSf4gyV88wTvdAQDLmEu4t9YuSHLBKmv+Psl/ncf4\nMKtZni1++OGHz7GT1bnmmmv2vNJunHfeeVPX3nrrrTONfdVVV81Uf9hhh81UP4u77rpr6tpjjz12\nprG3bt06Uz2LwfPcAaAzwh0AOiPcAaAzwh0AOiPcAaAzwh0AOiPcAaAzwh0AOiPcAaAzwh0AOiPc\nAaAzwh0AOiPcAaAzwh0AOjOv57nDzPbZZ5+pay+88MKZxt60adPUtY8++uhMY5999tlT115++eUz\njf2jH/1o6tpZ/n0lyc/8zM/MVD+mc889d+razZs3zzT2li1bZqpnMThyB4DOCHcA6IxwB4DOCHcA\n6IxwB4DOCHcA6IxwB4DOCHcA6IxwB4DOCHcA6IxwB4DOCHcA6IxwB4DOCHcA6IxwB4DOeJ4768Yz\nn/nMqWvPOeecOXayOm9729tmqr/kkkvm1MnqHXLIIVPXvva1r51p7IMPPnimemBljtwBoDPCHQA6\nI9wBoDPCHQA6I9wBoDPCHQA6I9wBoDPCHQA6I9wBoDPCHQA6I9wBoDPCHQA6I9wBoDPCHQA6U621\nsXuYu6raluSYsftgdQ499NCpa7/1rW/NsZPVefTRR2eq37Fjx5w6Wb2qmrr2KU95yhw72Vhm+Xf2\n+te/fqaxP/7xj89Uz4azvbW2ebVFjtwBoDPCHQA6I9wBoDPCHQA6I9wBoDPCHQA6I9wBoDPCHQA6\nI9wBoDPC
HQA6I9wBoDPCHQA6I9wBoDPCHQA6I9wBoDP7jt0ALLnnnnumrr366qtnGvsVr3jF1LVP\nfvKTZxp7o9q+fftM9fvss89M9UcfffRM9bP41Kc+NXWt57GzFhy5A0Bn5hLuVXV6VV1YVddV1YNV\n1arq0hXW3TQsX+l12Tx6AoBFNa+P5d+e5OgkDyX5ZpKjnkDNvyT55DLzb5xTTwCwkOYV7m/OJNRv\nTfLSJFufQM2XWmsXzGl8AGAwl3Bvrf1bmFfVPH4lADClMc+Wf3ZV/W6SQ5Lck+SLrbUbVvMLqmrb\nCoueyNcCANClMcP9ZcPr31TV55Kc1Vq7Y5SOAKADY4T7w0n+NJOT6W4b5v18kguSnJTks1X1otba\n9/f0i1prm5ebPxzRHzOXbgFgg1nz69xba3e11v64tba9tXb/8Lo2ycuT/GOS5yY5e637AoBerJub\n2LTWHkvy0eHtCWP2AgAb2boJ98F3h+mBo3YBABvYegv3LcP0tt2uBQCsaM3DvaqOqaofG7eqTsnk\nZjhJsuytawGAPZvL2fJVdWqSU4e3hw7TX6qqi4ef726tvXX4+b1Jjqyq6zO5q10yOVv+5OHnd7TW\nrp9HXwCwiOZ1KdyLkpy1y7znDK8k+XqSpXC/JMlrkvxCklcleXKSO5N8PMlFrbXr5tQTACykaq2N\n3cPcuc598TzpSbN9w3T++edPXXvsscfONPaNN07/rKSbbrppprE/8YlPTF27Y8eOmcb+2Mc+NlP9\nmWeeOXXtnXfeOdPYL3jBC6auvf/++2cam4WzfaV7uuzOejuhDgCYkXAHgM4IdwDojHAHgM4IdwDo\njHAHgM4IdwDojHAHgM4IdwDojHAHgM4IdwDojHAHgM4IdwDojHAHgM7M63nuMKpZHz/63ve+d06d\nLI6XvOQlM9XP8sjWWX3wgx+cqd5jW1nvHLkDQGeEOwB0RrgDQGeEOwB0RrgDQGeEOwB0RrgDQGeE\nOwB0RrgDQGeEOwB0RrgDQGeEOwB0RrgDQGeEOwB0RrgDQGc8zx0W2FOf+tSpa6+88so5drJ69913\n39S1l19++Rw7gfXHkTsAdEa4A0BnhDsAdEa4A0BnhDsAdEa4A0BnhDsAdEa4A0BnhDsAdEa4A0Bn\nhDsAdEa4A0BnhDsAdEa4A0BnPPIVFthxxx03de1P/uRPzjT2jh07Zqo/44wzpq695ZZbZhob1jtH\n7gDQGeEOAJ0R7gDQGeEOAJ0R7gDQGeEOAJ0R7gDQGeEOAJ0R7gDQGeEOAJ0R7gDQGeEOAJ0R7gDQ\nGeEOAJ0R7gDQGc9zhwX2rne9a7Sxr7/++pnqr7nmmjl1Av2Z+ci9qg6pqrOr6hNVdWtV/aCqHqiq\nL1TVG6pq2TGq6riquqqq7h1qbqiq86tqn1l7AoBFNo8j9zOSfDjJt5NsTXJHkmcl+fUkH03yqqo6\no7XWlgqq6teSXJnkkSSXJ7k3ya8meV+S44ffCQBMYR7hfnOSVyf5u9bajqWZVfVHSf4pyWmZBP2V\nw/ynJfmrJI8nObG19s/D/HckuSbJ6VX1utbaZXPoDQAWzswfy7fWrmmt/e3OwT7M/06SjwxvT9xp\n0elJnpnksqVgH9Z/JMnbh7fnztoXACyqvX22/KPD9LGd5p08TD+9zPrXJnk4yXFVtd/ebAwAerXX\nzpavqn2T/Obwducgf/4wvXnXmtbaY1X1tSQ/m+Q5Sb68hzG2rbDoqNV1CwD92JtH7u9O8sIkV7XW\nrt5p/kHD9IEV6pbmH7y3GgOAnu2VI/eqemOStyT5SpIz98YYSdJa27zC+NuSHLO3xgWA9WzuR+5V\n9ftJPpDkX5Oc1Fq7d5dVlo7MD8rylubfP+/eAGARzDXcq+r8JBcmuTGTYP/OMqvdNEyft0z9vkkO\nz+QEvNvm2RsALIq5hXtV/WEmN6H5UibBftcKqy7dM/KVyyw7IckBSa5vrf
1wXr0BwCKZS7gPN6B5\nd5JtSU5prd29m9WvSHJ3ktdV1Yt3+h37J/mz4e2H59EXACyimU+oq6qzkrwzkzvOXZfkjVW162q3\nt9YuTpLW2oNV9TuZhPznquqyTG4/++pMLpO7IpNb0gIAU5jH2fKHD9N9kpy/wjqfT3Lx0pvW2ier\n6qVJ3pbJ7Wn3T3Jrkj9I8hc734ceAFid6jFHXQrHojjnnHNmqv/Qhz40de2OHTv2vNJuHHzwbLey\neOihh2aqhw1i+0qXfe/O3r79LACwxoQ7AHRGuANAZ4Q7AHRGuANAZ4Q7AHRGuANAZ4Q7AHRGuANA\nZ4Q7AHRGuANAZ4Q7AHRGuANAZ4Q7AHRGuANAZ/YduwFgegcccMBoY993330z1XseO+w9jtwBoDPC\nHQA6I9wBoDPCHQA6I9wBoDPCHQA6I9wBoDPCHQA6I9wBoDPCHQA6I9wBoDPCHQA6I9wBoDPCHQA6\n45GvMLKnPOUpU9eee+65M439+OOPT137zne+c6axgb3HkTsAdEa4A0BnhDsAdEa4A0BnhDsAdEa4\nA0BnhDsAdEa4A0BnhDsAdEa4A0BnhDsAdEa4A0BnhDsAdEa4A0BnhDsAdMbz3GFkszwX/Ygjjphp\n7G984xtT11544YUzjQ3sPY7cAaAzwh0AOiPcAaAzwh0AOiPcAaAzwh0AOiPcAaAzwh0AOiPcAaAz\nwh0AOiPcAaAzwh0AOiPcAaAzwh0AOuORrzCyww47bOraRx55ZKax3/Oe98xUD6xPjtwBoDMzh3tV\nHVJVZ1fVJ6rq1qr6QVU9UFVfqKo3VNWTdll/U1W13bwum7UnAFhk8/hY/owkH07y7SRbk9yR5FlJ\nfj3JR5O8qqrOaK21Xer+Jcknl/l9N86hJwBYWPMI95uTvDrJ37XWdizNrKo/SvJPSU7LJOiv3KXu\nS621C+YwPgCwk5k/lm+tXdNa+9udg32Y/50kHxnenjjrOADAE7O3z5Z/dJg+tsyyZ1fV7yY5JMk9\nSb7YWrthL/cDAN3ba+FeVfsm+c3h7aeXWeVlw2vnms8lOau1dscTHGPbCouOeoJtAkB39ualcO9O\n8sIkV7XWrt5p/sNJ/jTJ5iRPH14vzeRkvBOTfLaqDtyLfQFA1/bKkXtVvTHJW5J8JcmZOy9rrd2V\n5I93Kbm2ql6e5AtJfjHJ2Uk+sKdxWmubVxh/W5JjVt85AGx8cz9yr6rfzySY/zXJSa21e59IXWvt\nsUwunUuSE+bdFwAsirmGe1Wdn+TCTK5VP2k4Y341vjtMfSwPAFOaW7hX1R8meV+SL2US7HdN8Wu2\nDNPb5tUXACyauYR7Vb0jkxPotiU5pbV2927WPWbXW9IO809J8ubh7aXz6AsAFtHMJ9RV1VlJ3pnk\n8STXJXljVe262u2ttYuHn9+b5Miquj7JN4d5P5/k5OHnd7TWrp+1LwBYVPM4W/7wYbpPkvNXWOfz\nSS4efr4kyWuS/EKSVyV5cpI7k3w8yUWttevm0BMALKz68ee5bHwuhWMjufTS6b+FetaznjXT2C97\n2cv2vBIwpu0rXfa9O57nDgCdEe4A0BnhDgCdEe4A0BnhDgCdEe4A0BnhDgCdEe4A0BnhDgCdEe4A\n0BnhDgCdEe4A0BnhDgCdEe4A0BmPfAWA9csjXwEA4Q4A3RHuANAZ4Q4AnRHuANAZ4Q4AnRHuANAZ\n4Q4AnRHuANAZ4Q4AnRHuANAZ4Q4AnRHuANAZ4Q4Anek13DeN3QAAzMGmaYr2nXMT68WDw/T2FZYf\nNUy/svdb6YZtNh3bbTq22+rZZtNZz9ttU/49z1alWmvzbWUDqKptSdJa2zx2LxuFbTYd2206ttvq\n2WbT6XW79fqxPAAsLOEOAJ0R7gDQGeEOAJ0R7gDQmYU8Wx4AeubIHQA6I9wBoDPCHQA6I9wBoDPC\nHQA6I9wBoDPCHQA6s1DhXlWHVdVfV9
W3quqHVXV7Vb2/qp4+dm/r1bCN2gqv74zd31iq6vSqurCq\nrquqB4ftcekeao6rqquq6t6q+kFV3VBV51fVPmvV99hWs92qatNu9r1WVZetdf9jqKpDqursqvpE\nVd067DsPVNUXquoNVbXs/8cXfX9b7XbrbX/r9XnuP6aqjkhyfZKfSvI3mTy799gkb0ryyqo6vrV2\nz4gtrmcPJHn/MvMfWutG1pG3Jzk6k23wzfz7M6GXVVW/luTKJI8kuTzJvUl+Ncn7khyf5Iy92ew6\nsqrtNviXJJ9cZv6Nc+xrPTsjyYeTfDvJ1iR3JHlWkl9P8tEkr6qqM9pOdySzvyWZYrsN+tjfWmsL\n8UpydZKW5L/vMv+9w/yPjN3jenwluT3J7WP3sd5eSU5KcmSSSnLisA9dusK6T0tyV5IfJnnxTvP3\nz+QPzpbkdWP/M63D7bZpWH7x2H2PvM1OziSYn7TL/EMzCayW5LSd5tvfpttuXe1vC/Gx/HDU/vJM\nguqDuyz+X0m+n+TMqjpwjVtjg2qtbW2t3dKG/yvswelJnpnkstbaP+/0Ox7J5Eg2Sc7dC22uO6vc\nbiRprV3TWvvb1tqOXeZ/J8lHhrcn7rTI/paptltXFuVj+ZOG6WeW+Rf9var6+0zCf0uSz651cxvA\nflX1G0l+OpM/hG5Icm1r7fFx29owTh6mn15m2bVJHk5yXFXt11r74dq1tWE8u6p+N8khSe5J8sXW\n2g0j97RePDpMH9tpnv1tz5bbbku62N8WJdyfP0xvXmH5LZmE+/Mi3JdzaJJLdpn3tar67dba58do\naINZcf9rrT1WVV9L8rNJnpPky2vZ2AbxsuH1b6rqc0nOaq3dMUpH60BV7ZvkN4e3Owe5/W03drPd\nlnSxvy3Ex/JJDhqmD6ywfGn+wWvQy0bzsSSnZBLwByb5uSR/mcn3U5+qqqPHa23DsP9N5+Ekf5pk\nc5KnD6+XZnJy1IlJPrvgX6W9O8kLk1zVWrt6p/n2t91babt1tb8tSrgzpdbanwzfXd3ZWnu4tXZj\na+2cTE5E/E9JLhi3Q3rVWrurtfbHrbXtrbX7h9e1mXzK9o9Jnpvk7HG7HEdVvTHJWzK56ufMkdvZ\nMHa33Xrb3xYl3Jf+Uj1oheVL8+9fg156sXRCygmjdrEx2P/mqLX2WCaXMiULuP9V1e8n+UCSf01y\nUmvt3l1Wsb8t4wlst2Vt1P1tUcL9pmH6vBWWHzlMV/pOnh/33WG6YT6mGtGK+9/w/d/hmZzYc9ta\nNrXBLeT+V1XnJ7kwk2uuTxrO/N6V/W0XT3C77c6G298WJdy3DtOXL3NXoqdmclOHh5P8w1o3toFt\nGaYL8z+IGVwzTF+5zLITkhyQ5PoFPnN5Ggu3/1XVH2ZyE5ovZRJQd62wqv1tJ6vYbruz4fa3hQj3\n1tpXk3wmk5PAfm+XxX+SyV9jl7TWvr/Gra1rVfWC5U4gqapNSS4a3u72lqskSa5IcneS11XVi5dm\nVtX+Sf5sePvhMRpbz6rqmOVurVpVpyR58/B2Ifa/qnpHJieCbUtySmvt7t2sbn8brGa79ba/1aLc\nS2KZ289+OckvZnIN/M1JjmtuP/sfVNUFmZx8cm2Sryf5XpIjkvxKJne7uirJa1prPxqrx7FU1alJ\nTh3eHprkFZn8VX/dMO/u1tpbd1n/ikxuB3pZJrcDfXUmly1dkeS/LcKNXVaz3YbLj47M5L/bbw7L\nfz7/fh33O1prS2HVrao6K8nFSR7P5KPl5c6Cv721dvFONQu/v612u3W3v419i7y1fCX5z5lc2vXt\nJD/KJLDen+TpY/e2Hl+ZXAbyvzM5s/T+TG788N0k/zeT60Rr7B5H3DYXZHKrypVety9Tc3wmfxDd\nl+QHSf5fJkcE+4z9z7Met1uSNyT5P5ncWfKhTG6nekcm90p/ydj/LOtom7Ukn7O/zbbdetvfFubI\nHQ
AWxUJ85w4Ai0S4A0BnhDsAdEa4A0BnhDsAdEa4A0BnhDsAdEa4A0BnhDsAdEa4A0BnhDsAdEa4\nA0BnhDsAdEa4A0BnhDsAdEa4A0BnhDsAdOb/A3iKnxlC8nJBAAAAAElFTkSuQmCC\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"tags": [],
"image/png": {
"width": 251,
"height": 248
}
}
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "HLtsipXjnqTX",
"colab_type": "text"
},
"source": [
"First, let's try to build a simple network for this dataset using weight matrices and matrix multiplications. Then, we'll see how to do it using PyTorch's `nn` module which provides a much more convenient and powerful method for defining network architectures.\n",
"\n",
"The networks you've seen so far are called *fully-connected* or *dense* networks. Each unit in one layer is connected to each unit in the next layer. In fully-connected networks, the input to each layer must be a one-dimensional vector (which can be stacked into a 2D tensor as a batch of multiple examples). However, our images are 28x28 2D tensors, so we need to convert them into 1D vectors. Thinking about sizes, we need to convert the batch of images with shape `(64, 1, 28, 28)` to a have a shape of `(64, 784)`, 784 is 28 times 28. This is typically called *flattening*, we flattened the 2D images into 1D vectors.\n",
"\n",
"Previously you built a network with one output unit. Here we need 10 output units, one for each digit. We want our network to predict the digit shown in an image, so what we'll do is calculate probabilities that the image is of any one digit or class. This ends up being a discrete probability distribution over the classes (digits) that tells us the most likely class for the image. That means we need 10 output units for the 10 classes (digits). We'll see how to convert the network output into a probability distribution next.\n",
"\n",
 **Exercise:** Flatten the batch">
"> **Exercise:** Flatten the batch of images `images`. Then build a multi-layer network with 784 input units, 256 hidden units, and 10 output units using random tensors for the weights and biases. For now, use a sigmoid activation for the hidden layer. Leave the output layer without an activation; we'll add one that gives us a probability distribution next."
]
},
{
"cell_type": "code",
"metadata": {
"id": "MSkYKxuNnqTY",
"colab_type": "code",
"outputId": "2dd96ea9-874e-4928-b9db-91078edf097d",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000
}
},
"source": [
"## Your solution\n",
"\n",
"def activation(x):\n",
" return 1/(1+torch.exp(-x))\n",
"\n",
"# Flatten the input images\n",
"inputs = images.view(images.shape[0], -1)\n",
"\n",
"# Create parameters\n",
"w1 = torch.randn(784, 256)\n",
"b1 = torch.randn(256)\n",
"\n",
"w2 = torch.randn(256, 10)\n",
"b2 = torch.randn(10)\n",
"\n",
"h = activation(torch.mm(inputs, w1) + b1)\n",
"\n",
"out = torch.mm(h, w2) + b2\n",
"# output of your network, should have shape (64,10)\n",
"print(out)"
],
"execution_count": 19,
"outputs": [
{
"output_type": "stream",
"text": [
"tensor([[-8.9814e+00, -1.6648e+01, -9.0729e+00, 8.1858e+00, -3.3657e+00,\n",
" -7.0763e+00, 9.3361e+00, -9.1638e+00, -6.0795e+00, -9.8519e+00],\n",
" [-3.1033e+00, -7.3122e+00, 1.4950e+00, 5.9985e-01, 1.1074e+00,\n",
" 3.7468e-01, 1.8607e+01, -1.0438e+01, -2.2650e+00, -2.8435e+00],\n",
" [-8.2143e+00, -1.4639e+01, -6.3387e+00, -2.3759e+00, 3.3334e+00,\n",
" 3.1429e+00, 5.7600e+00, -1.7282e+01, -8.7847e+00, -1.0255e+01],\n",
" [-1.0439e+00, -8.5052e+00, -9.0302e+00, -9.3048e-01, 6.5602e+00,\n",
" -7.5514e-01, 1.0762e+01, -1.4771e+01, 1.0293e+01, 2.5184e+00],\n",
" [-9.1020e-02, -1.8599e+01, -2.6547e+00, 3.2984e-01, 7.0065e+00,\n",
" 3.5095e+00, 5.5334e+00, -1.8461e+01, 1.1824e-01, -3.2717e+00],\n",
" [-2.7580e+00, -1.8433e+01, -6.4073e+00, 2.0300e+00, 2.8160e+00,\n",
" -7.2223e+00, 6.5661e+00, -5.9189e+00, -5.6133e+00, -1.2308e+01],\n",
" [-1.3840e+01, -4.0733e+00, -9.2137e+00, -4.3473e-01, 1.5257e+00,\n",
" 6.0447e+00, 2.6190e+01, -3.4640e+00, -4.6874e+00, -2.9113e+00],\n",
" [-7.2633e+00, -1.5883e+01, 5.4153e-01, -3.5073e+00, 4.8298e+00,\n",
" 1.0189e+01, 1.1245e+01, -1.5854e+01, 1.3878e+00, -1.3965e+00],\n",
" [-1.3144e+01, -2.3189e+01, -5.8868e+00, -1.1633e+01, -5.3399e+00,\n",
" 5.6378e+00, 4.2772e+00, -1.5797e+01, -3.7575e+00, -4.2485e+00],\n",
" [ 5.6529e+00, -1.5585e+01, -4.7061e-01, -1.0905e+00, 3.1402e+00,\n",
" 1.9380e+01, 9.8672e+00, -1.7743e+00, 1.7163e+00, -1.0045e+01],\n",
" [ 1.4043e+00, -1.8160e+01, 3.6816e+00, 1.0366e+01, 1.5369e+01,\n",
" 7.1249e+00, 5.4707e+00, -1.0047e+01, 1.1698e+00, -1.0802e+01],\n",
" [-2.2811e+00, -9.7534e+00, 4.8776e-01, 3.0159e+00, 9.7267e+00,\n",
" 2.9533e+00, 3.4848e+00, -1.9195e+01, 1.0160e+01, -2.2975e+00],\n",
" [-6.5838e+00, -1.0244e+01, -1.1836e+01, -4.3827e+00, 8.5692e+00,\n",
" 5.1098e+00, 7.6429e+00, -3.3090e+00, 9.0607e+00, -3.2510e-01],\n",
" [-9.3500e+00, -1.5793e+01, -5.0463e+00, -1.2946e+00, 3.4168e+00,\n",
" 1.1405e+00, 1.5416e+01, -1.7247e+01, 3.3923e-02, -7.4388e+00],\n",
" [-1.8152e+00, -1.3309e+01, 1.4454e+00, 5.2942e+00, 1.2195e+01,\n",
" -1.3049e+00, 7.3132e+00, -1.1939e+01, -4.5278e+00, -1.6379e+01],\n",
" [-2.4646e-01, -1.7254e+01, -2.5953e+00, 2.1574e+00, 4.3040e+00,\n",
" 4.6031e+00, 8.6153e+00, -1.1789e+01, 4.4378e+00, -9.4622e+00],\n",
" [-7.4286e+00, -1.3714e+01, -5.0032e+00, 2.9967e+00, -1.2270e+00,\n",
" 5.8362e+00, 1.2156e+01, -9.0166e+00, 1.4145e+00, 4.2820e+00],\n",
" [ 5.8823e+00, -9.7926e+00, -9.1872e+00, 2.5901e+00, 6.3030e+00,\n",
" 1.0537e+01, 1.2890e+01, -1.2061e+01, 2.9509e+00, -1.3719e+01],\n",
" [-1.3296e+01, -8.6652e+00, 6.9771e+00, 9.2071e+00, 8.1628e-01,\n",
" 2.0125e+00, 2.1903e+01, 6.2237e+00, -1.4455e+00, -5.8654e+00],\n",
" [-7.6209e-01, -1.3498e+01, -5.5896e+00, -6.2977e+00, 1.0641e+01,\n",
" 8.7656e+00, 2.6984e+00, -2.1343e+01, 1.4336e+01, -9.6059e+00],\n",
" [-2.5177e+00, -1.8531e+01, -7.9338e+00, -1.1004e+00, 1.0198e+01,\n",
" 2.9311e-01, 9.9629e-01, -1.4012e+01, -1.0661e+00, -9.9426e+00],\n",
" [-7.0427e+00, -1.4266e+01, -8.9358e+00, 2.5929e+00, -2.4176e+00,\n",
" 2.2577e+00, 9.4721e+00, -1.0505e+01, 4.1487e+00, -2.2372e+00],\n",
" [-7.7164e+00, -1.8757e+01, -4.3946e-01, 1.0958e+00, 7.6625e+00,\n",
" 5.4620e+00, 5.1767e+00, -1.0525e+01, 1.6387e+00, -2.0864e+01],\n",
" [-9.1393e+00, -1.6572e+01, -8.0304e+00, -5.6745e+00, -3.0539e+00,\n",
" 6.3855e+00, 7.1456e+00, -1.2734e+01, 9.3335e+00, -5.1786e+00],\n",
" [-6.2210e+00, -1.7723e+01, -5.2040e+00, 2.3593e+00, 6.2196e+00,\n",
" 4.0063e+00, 9.7848e+00, -1.6621e+01, 4.3381e+00, 1.5232e+00],\n",
" [-9.6877e+00, -1.5638e+01, -1.1743e+01, 1.2100e+00, 5.6268e+00,\n",
" 3.1302e+00, 8.8238e+00, -9.7051e+00, 2.4356e+00, -1.5250e+01],\n",
" [-2.1409e+00, -1.4721e+01, 5.1321e+00, 7.5552e+00, 1.2496e+01,\n",
" 5.8926e-01, 5.1757e+00, -1.7918e+01, 4.0343e+00, -1.3427e+01],\n",
" [-1.3561e+00, -1.6037e+01, -4.3979e+00, 1.1734e+00, 1.4232e+01,\n",
" 2.4828e-01, 1.1959e+01, -1.4596e+01, 4.8290e+00, -8.7997e+00],\n",
" [-1.2300e+01, -2.4365e+00, -1.4827e+01, 2.8632e+00, 5.4879e+00,\n",
" 2.1031e+00, 1.4717e+01, -1.1212e+01, 1.0257e+01, -1.3586e+01],\n",
" [ 2.1925e+00, -9.7994e+00, 1.2788e+00, 5.1254e+00, 3.2195e+00,\n",
" -1.8118e+00, 5.7633e+00, -9.0045e+00, 4.8066e+00, -2.1188e+01],\n",
" [-8.5935e+00, -1.2186e+01, -1.9912e+00, -1.1423e+01, -1.0163e+00,\n",
" 4.0948e+00, 1.4590e+01, -1.7241e+01, 8.1954e+00, -2.3131e+00],\n",
" [-1.3321e+01, -2.2347e+01, -1.1324e+01, -2.8585e+00, -1.7610e+00,\n",
" -1.4794e+00, -6.2066e+00, -1.2875e+01, 4.7304e-01, -8.0901e+00],\n",
" [-7.5044e+00, -1.4883e+01, -3.0873e+00, -1.0818e+01, 2.0494e+00,\n",
" 1.8913e+00, 7.3184e+00, -1.5951e+01, 1.4538e+01, -6.7872e+00],\n",
" [-1.2156e+00, -5.1311e+00, -4.3359e+00, -1.1958e+00, 1.0366e+01,\n",
" 5.5732e+00, 1.4195e+01, -4.5258e+00, 1.6332e+00, -7.1058e+00],\n",
" [-5.1448e+00, -1.6311e+01, -4.1506e+00, -1.6945e+00, 7.2501e+00,\n",
" 5.5562e+00, 1.2209e+01, -1.3599e+01, 8.7059e-01, -6.6816e+00],\n",
" [-8.9986e+00, -1.1688e+01, -3.2560e+00, -3.5475e+00, 7.3344e+00,\n",
" -1.1153e+00, 1.2423e+01, -1.7140e+01, -1.9215e-02, -1.0566e+01],\n",
" [-1.2606e+01, -1.0380e+01, 4.4335e+00, -7.6157e+00, 8.5795e+00,\n",
" -6.5387e-01, -1.2524e+00, -1.8248e+01, -1.0656e+00, -4.6619e+00],\n",
" [-7.3017e+00, -1.4231e+01, -2.9509e+00, 1.7765e+00, 4.6566e+00,\n",
" 7.5279e-01, 8.5910e+00, -1.3277e+01, 9.7889e+00, -5.5621e+00],\n",
" [ 5.1653e-01, -1.3260e+01, -1.4126e+01, 3.4647e+00, -2.2583e-01,\n",
" -1.1289e+00, 1.4153e+00, -1.2826e+01, 7.3447e-01, 2.6117e+00],\n",
" [-3.0226e+00, -9.9246e+00, 1.0827e+00, 3.9424e+00, 4.5778e+00,\n",
" 8.9808e+00, 9.8348e+00, -1.4746e+01, 1.5857e+00, -4.8893e+00],\n",
" [-9.3943e+00, -1.7144e+01, -4.7724e+00, 2.9853e+00, 1.3025e+00,\n",
" 1.1826e+01, 1.3799e+01, -1.4311e+01, -3.9770e+00, -1.4685e+01],\n",
" [-1.8395e+00, -6.3027e+00, -1.2104e+01, 1.6868e+00, -2.7905e+00,\n",
" -2.5153e+00, 1.0880e+01, -1.7858e+01, -1.0252e+00, -5.6273e+00],\n",
" [-2.8564e+00, -9.6303e+00, -4.4335e-01, -4.5604e+00, 5.3632e+00,\n",
" 3.5592e+00, 1.6210e+01, -1.5872e+01, 2.1672e+00, -4.8131e+00],\n",
" [-9.6116e+00, -1.4696e+01, 1.7920e+00, 1.1261e-01, 7.3586e+00,\n",
" 2.2418e+00, 1.1091e+01, -1.5304e+01, -3.7681e+00, -6.6355e+00],\n",
" [-1.4531e+01, -9.3600e+00, 2.8657e+00, 1.3826e+01, 7.2516e+00,\n",
" -6.7883e+00, 1.1779e+01, -1.0356e+01, -2.9370e+00, -1.0907e+00],\n",
" [-7.9405e+00, -1.5890e+01, -1.9004e+00, -1.0034e+01, 5.4257e+00,\n",
" 3.7308e+00, 3.1009e+00, -1.9428e+01, -4.8535e+00, -7.3776e+00],\n",
" [-9.4365e+00, -1.1759e+01, -8.0134e+00, 1.0730e+01, 5.9662e+00,\n",
" -5.9889e+00, 1.4758e+01, -1.1842e+01, -1.1102e+00, -1.4020e+01],\n",
" [-4.7297e+00, -1.2402e+01, 4.2223e+00, 4.0652e+00, 2.5618e+00,\n",
" -9.0349e-01, 7.8768e+00, -8.6835e+00, -2.4806e-01, -9.2723e+00],\n",
" [-1.8740e+01, -1.0421e+01, 1.0059e+00, 2.7625e+00, 1.1302e+00,\n",
" -6.0392e+00, 6.9080e-01, -1.7562e+01, -6.1762e+00, -8.8501e+00],\n",
" [-2.1741e+00, -1.0011e+01, 1.9196e-01, 2.2942e+00, 2.2026e+00,\n",
" -3.4104e+00, 1.8980e+01, -1.3834e+01, 5.2519e+00, -4.4947e+00],\n",
" [-7.8901e+00, -2.5314e+01, -1.4787e+01, -5.3091e+00, -4.0048e+00,\n",
" 5.7036e+00, 1.1725e+01, -3.1613e+00, -6.0867e+00, -1.0160e+01],\n",
" [-6.7829e+00, -1.5411e+01, -5.9446e+00, -1.0443e+00, 8.4399e+00,\n",
" 1.1174e+01, 1.7413e+01, -1.5182e+01, 6.2646e+00, -5.8763e+00],\n",
" [-5.3380e+00, -1.9453e+01, -3.7069e+00, -5.2592e+00, -1.3728e-01,\n",
" 8.3356e+00, 9.9740e+00, -1.2532e+01, 3.1678e+00, -1.0057e+01],\n",
" [-1.0315e+01, -1.9546e+01, -3.4752e+00, 9.3489e-01, 1.1225e+01,\n",
" 6.1291e+00, 5.2841e+00, -1.1586e+01, -4.4080e+00, -8.3770e+00],\n",
" [-4.7619e+00, -1.6749e+01, 4.3093e+00, 7.7299e+00, 6.5368e+00,\n",
" -3.4615e-01, 1.5660e+01, -8.6438e+00, 3.1611e-01, -1.4415e+01],\n",
" [-2.0070e+00, -2.7576e+00, 3.5579e-01, 1.0939e+01, 7.1884e+00,\n",
" 7.4513e-01, 6.5265e+00, -4.4714e+00, 1.5242e+00, -1.6862e+01],\n",
" [-9.2632e+00, -2.3852e+01, -1.3293e+01, -3.1975e+00, 1.5766e+00,\n",
" -4.3432e+00, 1.9281e+00, -2.1285e+01, -1.3180e+01, -7.5578e+00],\n",
" [-4.8375e+00, -4.0539e+00, 6.4372e+00, -5.8421e+00, 1.5553e-01,\n",
" 6.2599e+00, 7.5777e+00, -2.1493e+01, 3.1252e+00, 5.9446e+00],\n",
" [-6.2229e+00, -1.1374e+01, -1.5254e+00, 5.5453e+00, 5.2497e+00,\n",
" 1.4723e+00, 7.6991e+00, -2.1742e+01, -3.1884e+00, -6.5272e+00],\n",
" [-1.0195e+01, -2.0028e+01, -1.2318e+01, 1.1833e-01, 5.7554e+00,\n",
" 1.1150e+01, 7.0612e-01, -2.0626e+01, 1.9079e+00, 3.9760e+00],\n",
" [-4.9120e+00, -1.1047e+01, -2.0498e+00, 4.2257e+00, 8.1903e+00,\n",
" -2.2026e+00, 1.4212e+01, -9.2734e+00, 4.5625e+00, -1.2413e+01],\n",
" [ 2.6906e+00, -1.9524e+01, -6.4190e+00, 4.2679e+00, 7.3993e+00,\n",
" -1.0094e+00, 7.7773e+00, -1.5883e+01, 4.7779e+00, -7.7939e+00],\n",
" [-6.0641e+00, -4.0632e+00, -7.4551e+00, 1.0069e+01, 8.9059e+00,\n",
" -1.7938e+00, 9.2402e+00, -2.9644e+00, -3.3070e+00, -1.8040e+01],\n",
" [-2.0029e+00, -2.4876e+01, -1.4870e+01, -8.7239e+00, 1.6563e-01,\n",
" 2.9899e+00, 7.6359e+00, -1.6602e+01, -5.8871e+00, -3.0714e+00]])\n"
],
"name": "stdout"
}
]
},
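{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"The flattening step in the solution above can be written in a few equivalent ways. The sketch below, which uses a random tensor as a stand-in for one MNIST batch, checks that `view`, `reshape`, and `flatten` all produce the same `(64, 784)` result. The variable names are illustrative only.\n",
"\n",
"```python\n",
"import torch\n",
"\n",
"# Random stand-in for one batch of 64 greyscale 28x28 images\n",
"batch = torch.randn(64, 1, 28, 28)\n",
"\n",
"# Keep the batch dimension and let -1 infer the remaining 1*28*28 = 784 values\n",
"flat_view = batch.view(batch.shape[0], -1)\n",
"# Spell out the target shape explicitly\n",
"flat_reshape = batch.reshape(64, 784)\n",
"# Collapse every dimension from index 1 onward\n",
"flat_flatten = batch.flatten(start_dim=1)\n",
"\n",
"print(flat_view.shape)\n",
"print(torch.equal(flat_view, flat_reshape) and torch.equal(flat_view, flat_flatten))\n",
"```\n",
"\n",
"Using `batch.shape[0]` with `-1` avoids hard-coding the batch size, which helps when the final batch of an epoch is smaller than the others."
]
},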
{
"cell_type": "markdown",
"metadata": {
"id": "gljt4AhknqTa",
"colab_type": "text"
},
"source": [
"Now we have 10 outputs for our network. We want to pass in an image to our network and get out a probability distribution over the classes that tells us the likely class(es) the image belongs to. Something that looks like this:\n",
"<img src='assets/image_distribution.png' width=500px>\n",
"\n",
"Here we see that the probability for each class is roughly the same. This is representing an untrained network, it hasn't seen any data yet so it just returns a uniform distribution with equal probabilities for each class.\n",
"\n",
"To calculate this probability distribution, we often use the [**softmax** function](https://en.wikipedia.org/wiki/Softmax_function). Mathematically this looks like\n",
"\n",
"$$\n",
"\\Large \\sigma(x_i) = \\cfrac{e^{x_i}}{\\sum_k^K{e^{x_k}}}\n",
"$$\n",
"\n",
"What this does is squish each input $x_i$ between 0 and 1 and normalizes the values to give you a proper probability distribution where the probabilites sum up to one.\n",
"\n",
"> **Exercise:** Implement a function `softmax` that performs the softmax calculation and returns probability distributions for each example in the batch. Note that you'll need to pay attention to the shapes when doing this. If you have a tensor `a` with shape `(64, 10)` and a tensor `b` with shape `(64,)`, doing `a/b` will give you an error because PyTorch will try to do the division across the columns (called broadcasting) but you'll get a size mismatch. The way to think about this is for each of the 64 examples, you only want to divide by one value, the sum in the denominator. So you need `b` to have a shape of `(64, 1)`. This way PyTorch will divide the 10 values in each row of `a` by the one value in each row of `b`. Pay attention to how you take the sum as well. You'll need to define the `dim` keyword in `torch.sum`. Setting `dim=0` takes the sum across the rows while `dim=1` takes the sum across the columns."
]
},
{
"cell_type": "code",
"metadata": {
"id": "T4pHtiu7nqTb",
"colab_type": "code",
"outputId": "84f46796-f75a-45f0-a71e-f9a0954c496f",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 168
}
},
"source": [
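"def softmax(x):\n",
"    # Exponentiate, then divide each row by its own sum.\n",
"    # .view(-1, 1) reshapes the sums from (64,) to (64, 1) so the division broadcasts row by row.\n",
"    exp_x = torch.exp(x)\n",
"    return exp_x / torch.sum(exp_x, dim=1).view(-1, 1)\n",
"\n",
"# Here, out should be the output of the network in the previous exercise with shape (64,10)\n",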
"probabilities = softmax(out)\n",
"\n",
"# Does it have the right shape? Should be (64, 10)\n",
"print(probabilities.shape)\n",
"# Does it sum to 1?\n",
"print(probabilities.sum(dim=1))"
],
"execution_count": 6,
"outputs": [
{
"output_type": "stream",
"text": [
"torch.Size([64, 10])\n",
"tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,\n",
" 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,\n",
" 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,\n",
" 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,\n",
" 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,\n",
" 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,\n",
" 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,\n",
" 1.0000])\n"
],
"name": "stdout"
}
]
},
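{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"The shape bookkeeping in the exercise above can be checked on a small tensor. The following is a minimal sketch, not the course solution: `scores` is a made-up stand-in for `out`, and it uses `keepdim=True` as an alternative to `.view(-1, 1)`. PyTorch's built-in `torch.softmax` appears only as a sanity check.\n",
"\n",
"```python\n",
"import torch\n",
"\n",
"torch.manual_seed(0)\n",
"scores = torch.randn(4, 3)           # 4 examples, 3 classes (hypothetical stand-in for `out`)\n",
"\n",
"exp_scores = torch.exp(scores)\n",
"row_sums = exp_scores.sum(dim=1)     # shape (4,): one sum per example\n",
"\n",
"# keepdim=True keeps the summed dimension with size 1, giving shape (4, 1),\n",
"# which broadcasts against (4, 3) the same way .view(-1, 1) does\n",
"probs = exp_scores / exp_scores.sum(dim=1, keepdim=True)\n",
"\n",
"print(row_sums.shape)\n",
"print(probs.shape)\n",
"print(torch.allclose(probs, torch.softmax(scores, dim=1), atol=1e-6))\n",
"```\n",
"\n",
"Either `keepdim=True` or `.view(-1, 1)` works here; `keepdim=True` just avoids a separate reshape step."
]
},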
{
"cell_type": "markdown",
"metadata": {
"id": "dlvDJOcsnqTd",
"colab_type": "text"
},
"source": [
"## Building networks with PyTorch\n",
"\n",
"PyTorch provides a module `nn` that makes building networks much simpler. Here I'll show you how to build the same one as above with 784 inputs, 256 hidden units, 10 output units and a softmax output."
]
},
{
"cell_type": "code",
"metadata": {
"id": "VkVYf7funqTe",
"colab_type": "code",
"colab": {}
},
"source": [
"from torch import nn"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "7qH0l4dpnqTg",
"colab_type": "code",
"colab": {}
},
"source": [
"class Network(nn.Module):\n",
"    def __init__(self):\n",
"        super().__init__()\n",
"\n",
"        # Inputs to hidden layer linear transformation\n",
"        self.hidden = nn.Linear(784, 256)\n",
"        # Output layer, 10 units - one for each digit\n",
"        self.output = nn.Linear(256, 10)\n",
"\n",
"        # Define sigmoid activation and softmax output\n",
"        self.sigmoid = nn.Sigmoid()\n",
"        self.softmax = nn.Softmax(dim=1)\n",
"\n",
"    def forward(self, x):\n",
"        # Pass the input tensor through each of our operations\n",
"        x = self.hidden(x)\n",
"        x = self.sigmoid(x)\n",
"        x = self.output(x)\n",
"        x = self.softmax(x)\n",
"\n",
"        return x"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "yARJ802FnqTi",
"colab_type": "text"
},
"source": [
"Let's go through this bit by bit.\n",
"\n",
"```python\n",
"class Network(nn.Module):\n",
"```\n",
"\n",
"Here we're inheriting from `nn.Module`. Combined with `super().__init__()` this creates a class that tracks the architecture and provides a lot of useful methods and attributes. It is mandatory to inherit from `nn.Module` when you're creating a class for your network. The name of the class itself can be anything.\n",
"\n",
"```python\n",
"self.hidden = nn.Linear(784, 256)\n",
"```\n",
"\n",
"This line creates a module for a linear transformation, $x\\mathbf{W} + b$, with 784 inputs and 256 outputs and assigns it to `self.hidden`. The module automatically creates the weight and bias tensors which we'll use in the `forward` method. You can access the weight and bias tensors once the network (`net`) is created with `net.hidden.weight` and `net.hidden.bias`.\n",
"\n",
"```python\n",
"self.output = nn.Linear(256, 10)\n",
"```\n",
"\n",
"Similarly, this creates another linear transformation with 256 inputs and 10 outputs.\n",
"\n",
"```python\n",
"self.sigmoid = nn.Sigmoid()\n",
"self.softmax = nn.Softmax(dim=1)\n",
"```\n",
"\n",
"Here I defined operations for the sigmoid activation and softmax output. Setting `dim=1` in `nn.Softmax(dim=1)` calculates softmax across the columns.\n",
"\n",
"```python\n",
"def forward(self, x):\n",
"```\n",
"\n",
"PyTorch networks created with `nn.Module` must have a `forward` method defined. It takes in a tensor `x` and passes it through the operations you defined in the `__init__` method.\n",
"\n",
"```python\n",
"x = self.hidden(x)\n",
"x = self.sigmoid(x)\n",
"x = self.output(x)\n",
"x = self.softmax(x)\n",
"```\n",
"\n",
"Here the input tensor `x` is passed through each operation and reassigned to `x`. We can see that the input tensor goes through the hidden layer, then a sigmoid function, then the output layer, and finally the softmax function. It doesn't matter what you name the variables here, as long as the inputs and outputs of the operations match the network architecture you want to build. The order in which you define things in the `__init__` method doesn't matter, but you'll need to sequence the operations correctly in the `forward` method.\n",
"\n",
"Now we can create a `Network` object."
]
},
{
"cell_type": "code",
"metadata": {
"id": "YUaM45FVnqTi",
"colab_type": "code",
"outputId": "7c2dd7aa-eea2-4ac6-8f26-dd8642d691d5",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 118
}
},
"source": [
"# Create the network and look at its text representation\n",
"model = Network()\n",
"model"
],
"execution_count": 9,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"Network(\n",
" (hidden): Linear(in_features=784, out_features=256, bias=True)\n",
" (output): Linear(in_features=256, out_features=10, bias=True)\n",
" (sigmoid): Sigmoid()\n",
" (softmax): Softmax(dim=1)\n",
")"
]
},
"metadata": {
"tags": []
},
"execution_count": 9
}
]
},
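{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"To make the `weight` and `bias` attributes described above concrete, here is a small sketch on a standalone `nn.Linear` layer. The names `layer`, `x`, and `manual` are illustrative only, and the input is random data rather than MNIST. One detail worth seeing for yourself: the weight is stored with shape `(out_features, in_features)`, so reproducing the layer by hand needs a transpose.\n",
"\n",
"```python\n",
"import torch\n",
"from torch import nn\n",
"\n",
"layer = nn.Linear(784, 256)\n",
"\n",
"# The weight is stored as (out_features, in_features), the bias as (out_features,)\n",
"print(layer.weight.shape)\n",
"print(layer.bias.shape)\n",
"\n",
"x = torch.randn(64, 784)     # random stand-in for a batch of 64 flattened images\n",
"\n",
"# Reproduce the layer by hand: multiply by the transposed weight, then add the bias\n",
"manual = x @ layer.weight.t() + layer.bias\n",
"print(torch.allclose(manual, layer(x), atol=1e-4))\n",
"```\n",
"\n",
"The same attributes are available on the layers of `model`, for example `model.hidden.weight`."
]
},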
{
"cell_type": "markdown",
"metadata": {
"id": "AfIxcUErnqTk",
"colab_type": "text"
},
"source": [
"You can define the network somewhat more concisely and clearly using the `torch.nn.functional` module. This is the most common way you'll see networks defined, as many operations are simple element-wise functions. We normally import this module as `F`: `import torch.nn.functional as F`."
]
},
{
"cell_type": "code",
"metadata": {
"id": "n3XRGvgWnqTk",
"colab_type": "code",
"colab": {}
},
"source": [
"import torch.nn.functional as F\n",
"\n",
"class Network(nn.Module):\n",
"    def __init__(self):\n",
"        super().__init__()\n",
"        # Inputs to hidden layer linear transformation\n",
"        self.hidden = nn.Linear(784, 256)\n",
"        # Output layer, 10 units - one for each digit\n",
"        self.output = nn.Linear(256, 10)\n",
"\n",
"    def forward(self, x):\n",
"        # Hidden layer with sigmoid activation\n",
"        x = F.sigmoid(self.hidden(x))\n",
"        # Output layer with softmax activation\n",
"        x = F.softmax(self.output(x), dim=1)\n",
"\n",
"        return x"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "NmF-HHP3nqTm",
"colab_type": "text"
},
"source": [
"### Activation functions\n",
"\n",
"So far we've only been looking at the sigmoid activation function, but in general any function can be used as an activation function. The only requirement is that for a network to approximate a non-linear function, the activation functions must be non-linear. Here are two more examples of common activation functions: Tanh (hyperbolic tangent) and ReLU (rectified linear unit).\n",
"\n",
"<img src=\"assets/activation.png\" width=700px>\n",
"\n",
"In practice, the ReLU function is used almost exclusively as the activation function for hidden layers."
]
},
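{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"As a quick illustration of the three activation functions mentioned above, the sketch below applies each one to the same handful of values. The input tensor `x` is made up for the example.\n",
"\n",
"```python\n",
"import torch\n",
"\n",
"x = torch.tensor([-2.0, -0.5, 0.0, 0.5, 2.0])\n",
"\n",
"sig_out = torch.sigmoid(x)    # squashes every value into (0, 1)\n",
"tanh_out = torch.tanh(x)      # squashes every value into (-1, 1)\n",
"relu_out = torch.relu(x)      # negatives become 0, positives pass through unchanged\n",
"\n",
"print(sig_out)\n",
"print(tanh_out)\n",
"print(relu_out)\n",
"```\n",
"\n",
"All three are non-linear, which is the property the paragraph above asks for."
]
},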
{
"cell_type": "markdown",
"metadata": {
"id": "qkhgO-8-nqTn",
"colab_type": "text"
},
"source": [
"### Your Turn to Build a Network\n",
"\n",
"<img src=\"assets/mlp_mnist.png\" width=600px>\n",
"\n",
"> **Exercise:** Create a network with 784 input units, a hidden layer with 128 units and a ReLU activation, then a hidden layer with 64 units and a ReLU activation, and finally an output layer with a softmax activation as shown above. You can use a ReLU activation with the `nn.ReLU` module or `F.relu` function.\n",
"\n",
"It's good practice to name your layers by their type of network, for instance 'fc' to represent a fully-connected layer. As you code your solution, use `fc1`, `fc2`, and `fc3` as your layer names."
]
},
{
"cell_type": "code",
"metadata": {
"scrolled": true,
"id": "K2SQwMkdnqTn",
"colab_type": "code",
"outputId": "60a2e2ee-6dd5-4e46-c04b-249b8dc011a0",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 101
}
},
"source": [
"## Your solution here\n",
"import torch.nn.functional as F\n",
"\n",
"class Network(nn.Module):\n",
"    def __init__(self):\n",
"        super().__init__()\n",
"        # 28 x 28 MNIST images flatten to 784 inputs\n",
"        self.fc1 = nn.Linear(784, 128)\n",
"        self.fc2 = nn.Linear(128, 64)\n",
"        self.fc3 = nn.Linear(64, 10)\n",
"\n",
"    def forward(self, x):\n",
"        # Two hidden layers with ReLU activations\n",
"        x = F.relu(self.fc1(x))\n",
"        x = F.relu(self.fc2(x))\n",
"        # Output layer with softmax over the 10 classes\n",
"        x = F.softmax(self.fc3(x), dim=1)\n",
"\n",
"        return x\n",
"\n",
"model = Network()\n",
"model"
],
"execution_count": 11,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"Network(\n",
"  (fc1): Linear(in_features=784, out_features=128, bias=True)\n",
"  (fc2): Linear(in_features=128, out_features=64, bias=True)\n",
"  (fc3): Linear(in_features=64, out_features=10, bias=True)\n",
")"
]
},
"metadata": {
"tags": []
},
"execution_count": 11
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "cnOOggHgnqTp",
"colab_type": "text"
},
"source": [
"### Initializing weights and biases\n",
"\n",
"The weights and biases are automatically initialized for you, but it's possible to customize how they are initialized. They are tensors attached to the layer you defined; you can get them with `model.fc1.weight` and `model.fc1.bias`, for instance."
]
},
{
"cell_type": "code",
"metadata": {
"id": "uhTTysGmnqTp",
"colab_type": "code",
"outputId": "0c316224-03ba-4dcd-873d-484aebc57f10",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 470
}
},
"source": [
"print(model.fc1.weight)\n",
"print(model.fc1.bias)"
],
"execution_count": 12,
"outputs": [
{
"output_type": "stream",
"text": [
"Parameter containing:\n",
"tensor([[-0.0354, 0.0126, -0.0006, ..., 0.0265, 0.0192, 0.0349],\n",
" [-0.0096, -0.0307, -0.0369, ..., -0.0013, 0.0331, -0.0064],\n",
" [ 0.0194, 0.0026, -0.0294, ..., 0.0034, 0.0363, 0.0264],\n",
" ...,\n",
" [-0.0175, -0.0186, 0.0308, ..., 0.0022, 0.0229, 0.0083],\n",
" [ 0.0318, -0.0066, 0.0174, ..., 0.0290, -0.0243, -0.0181],\n",
" [-0.0038, -0.0358, 0.0072, ..., 0.0082, 0.0217, 0.0127]],\n",
" requires_grad=True)\n",
"Parameter containing:\n",
"tensor([ 0.0262, 0.0154, -0.0006, -0.0234, 0.0103, 0.0138, 0.0111, 0.0196,\n",
" -0.0320, -0.0287, 0.0133, 0.0035, 0.0144, -0.0370, -0.0249, -0.0244,\n",
" 0.0242, -0.0197, -0.0225, 0.0352, 0.0249, -0.0049, -0.0267, -0.0342,\n",
" 0.0287, 0.0236, 0.0062, -0.0223, 0.0241, 0.0190, 0.0094, 0.0104,\n",
" 0.0167, 0.0203, 0.0161, -0.0082, 0.0092, -0.0118, -0.0201, -0.0153,\n",
" 0.0247, 0.0077, -0.0339, 0.0273, -0.0128, 0.0271, -0.0149, -0.0210,\n",
" -0.0250, -0.0257, 0.0326, -0.0137, 0.0348, -0.0202, -0.0367, 0.0006,\n",
" -0.0046, -0.0165, -0.0173, -0.0154, -0.0274, 0.0243, -0.0203, 0.0248,\n",
" 0.0332, -0.0171, -0.0256, 0.0260, -0.0213, 0.0069, -0.0099, 0.0234,\n",
" 0.0349, -0.0319, 0.0192, 0.0357, 0.0187, -0.0210, -0.0200, 0.0143,\n",
" -0.0029, -0.0026, -0.0013, 0.0367, 0.0276, -0.0348, -0.0360, 0.0083,\n",
" -0.0040, -0.0261, -0.0242, -0.0012, -0.0128, -0.0343, -0.0226, 0.0307,\n",
" -0.0214, 0.0019, -0.0188, 0.0191, -0.0338, -0.0219, -0.0362, -0.0367,\n",
" 0.0328, -0.0160, 0.0025, 0.0135, -0.0216, 0.0193, 0.0225, 0.0212,\n",
" 0.0241, 0.0040, -0.0352, 0.0189, -0.0042, -0.0345, 0.0199, 0.0196,\n",
" 0.0174, 0.0128, 0.0287, 0.0128, -0.0068, -0.0108, 0.0131, -0.0220],\n",
" requires_grad=True)\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "pIx6CXV7nqTr",
"colab_type": "text"
},
"source": [
"For custom initialization, we want to modify these tensors in place. These are actually `nn.Parameter` objects tracked by autograd, so we need to get back the underlying tensors with `model.fc1.weight.data`. Once we have the tensors, we can fill them with zeros (for biases) or random normal values."
]
},
{
"cell_type": "code",
"metadata": {
"id": "67iYGuuNnqTs",
"colab_type": "code",
"outputId": "e6e2dc12-11fd-4c77-9574-b1eda020ff2d",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 118
}
},
"source": [
"# Set biases to all zeros\n",
"model.fc1.bias.data.fill_(0)"
],
"execution_count": 13,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0.])"
]
},
"metadata": {
"tags": []
},
"execution_count": 13
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "KUzE5S9ynqTu",
"colab_type": "code",
"outputId": "38a23bcc-3954-4cfb-ae5e-e65313bd02fd",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 134
}
},
"source": [
"# sample from random normal with standard dev = 0.01\n",
"model.fc1.weight.data.normal_(std=0.01)"
],
"execution_count": 14,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"tensor([[ 0.0136, -0.0094, 0.0071, ..., 0.0095, -0.0190, 0.0042],\n",
" [ 0.0012, 0.0125, 0.0167, ..., 0.0186, 0.0085, -0.0025],\n",
" [ 0.0020, 0.0200, 0.0023, ..., -0.0104, 0.0014, -0.0030],\n",
" ...,\n",
" [-0.0119, -0.0040, -0.0004, ..., -0.0118, -0.0165, 0.0062],\n",
" [-0.0010, 0.0185, 0.0024, ..., 0.0013, -0.0015, -0.0017],\n",
" [-0.0103, 0.0176, -0.0136, ..., -0.0011, -0.0083, 0.0204]])"
]
},
"metadata": {
"tags": []
},
"execution_count": 14
}
]
},
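{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"The same kind of in-place initialization can also be written with the helpers in `torch.nn.init`. This is a sketch of an alternative, not the approach used in the cells above; the standalone `layer` is a stand-in for `model.fc1`.\n",
"\n",
"```python\n",
"import torch\n",
"from torch import nn\n",
"\n",
"torch.manual_seed(0)\n",
"layer = nn.Linear(784, 128)   # stand-in for model.fc1\n",
"\n",
"# The trailing underscore marks functions that modify the tensor in place\n",
"nn.init.zeros_(layer.bias)\n",
"nn.init.normal_(layer.weight, mean=0.0, std=0.01)\n",
"\n",
"print(layer.bias.abs().sum().item())   # all biases are zero\n",
"print(layer.weight.std().item())       # close to the requested 0.01\n",
"```\n",
"\n",
"These helpers take the parameter itself, so there is no need to reach for `.data`."
]
},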
{
"cell_type": "markdown",
"metadata": {
"id": "X-1LleEFnqTx",
"colab_type": "text"
},
"source": [
"### Forward pass\n",
"\n",
"Now that we have a network, let's see what happens when we pass in an image."
]
},
{
"cell_type": "code",
"metadata": {
"id": "X_C9uH8AnqTy",
"colab_type": "code",
"outputId": "313f7b94-8b7b-4f91-c233-bdb077a1ebba",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 370
}
},
"source": [
"\n",
"# Grab some data \n",
"dataiter = iter(trainloader)\n",
"images, labels = next(dataiter)\n",
"\n",
"# Resize images into a 1D vector, new shape is (batch size, color channels, image pixels) \n",
"images.resize_(64, 1, 784)\n",
"# or images.resize_(images.shape[0], 1, 784) to automatically get batch size\n",
"\n",
"# Forward pass through the network\n",
"img_idx=0\n",
"ps = model.forward(images[img_idx,:])\n",
"\n",
"img = images[img_idx]\n",
"helper.view_classify(img.view(1, 28, 28), ps)"
],
"execution_count": 22,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "CgqCyX0UnqT0",
"colab_type": "text"
},
"source": [
"When you run the cell above, you'll see that our network has basically no idea what this digit is. That's because we haven't trained it yet, so all the weights are random!\n",
"\n",
"### Using `nn.Sequential`\n",
"\n",
"PyTorch provides a convenient way to build networks like this where a tensor is passed sequentially through operations, `nn.Sequential` ([documentation](https://pytorch.org/docs/master/nn.html#torch.nn.Sequential)). Using this to build the equivalent network:"
]
},
{
"cell_type": "code",
"metadata": {
"id": "7vQWjrARnqT1",
"colab_type": "code",
"colab": {}
},
"source": [
"# Hyperparameters for our network\n",
"input_size = 784\n",
"hidden_sizes = [128, 64]\n",
"output_size = 10\n",
"\n",
"# Build a feed-forward network\n",
"model = nn.Sequential(nn.Linear(input_size, hidden_sizes[0]),\n",
"                      nn.ReLU(),\n",
"                      nn.Linear(hidden_sizes[0], hidden_sizes[1]),\n",
"                      nn.ReLU(),\n",
"                      nn.Linear(hidden_sizes[1], output_size),\n",
"                      nn.Softmax(dim=1))\n",
"print(model)\n",
"\n",
"# Forward pass through the network and display output\n",
"images, labels = next(iter(trainloader))\n",
"images.resize_(images.shape[0], 1, 784)\n",
"ps = model.forward(images[0,:])\n",
"helper.view_classify(images[0].view(1, 28, 28), ps)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "bh81XCdcnqT4",
"colab_type": "text"
},
"source": [
"Here our model is the same as before: 784 input units, a hidden layer with 128 units, ReLU activation, 64 unit hidden layer, another ReLU, then the output layer with 10 units, and the softmax output.\n",
"\n",
"The operations are available by passing in the appropriate index. For example, if you want to get the first Linear operation and look at the weights, you'd use `model[0]`."
]
},
{
"cell_type": "code",
"metadata": {
"id": "mOAwk-MMnqT4",
"colab_type": "code",
"colab": {}
},
"source": [
"print(model[0])\n",
"model[0].weight"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "y4ikqDignqT6",
"colab_type": "text"
},
"source": [
"You can also pass in an `OrderedDict` to name the individual layers and operations, instead of using incremental integers. Note that dictionary keys must be unique, so _each operation must have a different name_."
]
},
{
"cell_type": "code",
"metadata": {
"id": "fFAKLFpMnqT6",
"colab_type": "code",
"colab": {}
},
"source": [
"from collections import OrderedDict\n",
"model = nn.Sequential(OrderedDict([\n",
"                      ('fc1', nn.Linear(input_size, hidden_sizes[0])),\n",
"                      ('relu1', nn.ReLU()),\n",
"                      ('fc2', nn.Linear(hidden_sizes[0], hidden_sizes[1])),\n",
"                      ('relu2', nn.ReLU()),\n",
"                      ('output', nn.Linear(hidden_sizes[1], output_size)),\n",
"                      ('softmax', nn.Softmax(dim=1))]))\n",
"model"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "m2esX72DnqT8",
"colab_type": "text"
},
"source": [
"Now you can access layers either by integer index or by name."
]
},
{
"cell_type": "code",
"metadata": {
"id": "vFzlOL0ZnqT9",
"colab_type": "code",
"colab": {}
},
"source": [
"print(model[0])\n",
"print(model.fc1)"
],
"execution_count": 0,
"outputs": []
},
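{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"Here is a self-contained sketch of the two access styles, with the layer sizes written out so it runs on its own. It also lists the named layers with `named_children()` and pushes a random batch through the model; the random batch is a stand-in for MNIST images and is an assumption of this example.\n",
"\n",
"```python\n",
"from collections import OrderedDict\n",
"\n",
"import torch\n",
"from torch import nn\n",
"\n",
"model = nn.Sequential(OrderedDict([\n",
"    ('fc1', nn.Linear(784, 128)),\n",
"    ('relu1', nn.ReLU()),\n",
"    ('fc2', nn.Linear(128, 64)),\n",
"    ('relu2', nn.ReLU()),\n",
"    ('output', nn.Linear(64, 10)),\n",
"    ('softmax', nn.Softmax(dim=1))]))\n",
"\n",
"# Index access and attribute access return the same module object\n",
"same_layer = model[0] is model.fc1\n",
"print(same_layer)\n",
"\n",
"# named_children() yields (name, module) pairs in definition order\n",
"layer_names = [name for name, _ in model.named_children()]\n",
"print(layer_names)\n",
"\n",
"# Random stand-in for a batch of 5 flattened images\n",
"ps = model(torch.randn(5, 784))\n",
"print(ps.shape)\n",
"print(ps.sum(dim=1))   # each row is a probability distribution\n",
"```\n",
"\n",
"Named layers make it easier to tell operations apart when printing or inspecting a larger model."
]
},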
{
"cell_type": "markdown",
"metadata": {
"id": "-NoeozKSnqUB",
"colab_type": "text"
},
"source": [
"In the next notebook, we'll see how we can train a neural network to accurately predict the numbers appearing in the MNIST images."
]
}
]
}