Skip to content

Instantly share code, notes, and snippets.

@yungwarlock
Last active June 26, 2024 10:48
Show Gist options
  • Save yungwarlock/34891fc23650f16048715934f1fa8e1b to your computer and use it in GitHub Desktop.
Save yungwarlock/34891fc23650f16048715934f1fa8e1b to your computer and use it in GitHub Desktop.
learning-pytorch.ipynb
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"authorship_tag": "ABX9TyNNZwvbehqHbsy47sJWfTC9",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/yungwarlock/34891fc23650f16048715934f1fa8e1b/learning-pytorch.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"id": "Ff54LIyrCObg"
},
"outputs": [],
"source": [
"import torch\n",
"import torchvision\n",
"\n",
"from torchvision import datasets, transforms"
]
},
{
"cell_type": "code",
"source": [
"# One shared preprocessing pipeline so train and test images are always transformed identically.\n",
"transform = transforms.Compose([\n",
"    transforms.ToTensor(),\n",
"])\n",
"\n",
"# root=\"\" downloads into the current working directory (files land in ./MNIST/raw).\n",
"train = datasets.MNIST(\"\", train=True, download=True, transform=transform)\n",
"test = datasets.MNIST(\"\", train=False, download=True, transform=transform)\n",
"\n",
"# Despite their names, trainset/testset are DataLoaders (batches of 10), not datasets.\n",
"# Evaluation order does not need randomising, so the test loader is left unshuffled\n",
"# (reproducible batches); the training loader reshuffles every epoch.\n",
"testset = torch.utils.data.DataLoader(test, batch_size=10, shuffle=False)\n",
"trainset = torch.utils.data.DataLoader(train, batch_size=10, shuffle=True)"
],
"metadata": {
"id": "mON-CfnJCZC_",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "5d9eeda1-8424-4711-a90d-41417732cefc"
},
"execution_count": 2,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz\n",
"Failed to download (trying next):\n",
"HTTP Error 403: Forbidden\n",
"\n",
"Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz\n",
"Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to MNIST/raw/train-images-idx3-ubyte.gz\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"100%|██████████| 9912422/9912422 [00:00<00:00, 15912637.29it/s]\n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"Extracting MNIST/raw/train-images-idx3-ubyte.gz to MNIST/raw\n",
"\n",
"Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz\n",
"Failed to download (trying next):\n",
"HTTP Error 403: Forbidden\n",
"\n",
"Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz\n",
"Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to MNIST/raw/train-labels-idx1-ubyte.gz\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"100%|██████████| 28881/28881 [00:00<00:00, 478366.11it/s]\n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"Extracting MNIST/raw/train-labels-idx1-ubyte.gz to MNIST/raw\n",
"\n",
"Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz\n",
"Failed to download (trying next):\n",
"HTTP Error 403: Forbidden\n",
"\n",
"Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz\n",
"Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to MNIST/raw/t10k-images-idx3-ubyte.gz\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"100%|██████████| 1648877/1648877 [00:03<00:00, 452011.72it/s]\n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"Extracting MNIST/raw/t10k-images-idx3-ubyte.gz to MNIST/raw\n",
"\n",
"Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz\n",
"Failed to download (trying next):\n",
"HTTP Error 403: Forbidden\n",
"\n",
"Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz\n",
"Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to MNIST/raw/t10k-labels-idx1-ubyte.gz\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"100%|██████████| 4542/4542 [00:00<00:00, 4720150.83it/s]"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"Extracting MNIST/raw/t10k-labels-idx1-ubyte.gz to MNIST/raw\n",
"\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"\n",
"INPUT = 28 * 28  # a flattened 28x28 image\n",
"OUTPUT = 10      # one score per digit class\n",
"\n",
"class Net(nn.Module):\n",
"    \"\"\"Fully connected classifier: input -> 3 hidden ReLU layers -> output.\n",
"\n",
"    Args:\n",
"        input: number of input features per sample (784 for a flattened 28x28 image).\n",
"        output: number of classes.\n",
"        hidden: width of each of the three hidden layers (default 64).\n",
"\n",
"    forward() returns log-probabilities (log_softmax over dim=1, the class dimension),\n",
"    so pair it with nn.NLLLoss / F.nll_loss when training.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, input, output, hidden=64):\n",
"        super().__init__()\n",
"\n",
"        self.fc1 = nn.Linear(input, hidden)\n",
"        self.fc2 = nn.Linear(hidden, hidden)\n",
"        self.fc3 = nn.Linear(hidden, hidden)\n",
"        self.fc4 = nn.Linear(hidden, output)\n",
"\n",
"    def forward(self, x):\n",
"        # Flatten everything after the batch dimension so image batches straight from a\n",
"        # DataLoader (expected (N, 1, 28, 28) for MNIST + ToTensor) work as well as\n",
"        # already-flat (N, 784) input; for flat input this is a no-op.\n",
"        x = x.flatten(start_dim=1)\n",
"\n",
"        x = F.relu(self.fc1(x))\n",
"        x = F.relu(self.fc2(x))\n",
"        x = F.relu(self.fc3(x))\n",
"\n",
"        # log_softmax rather than softmax: numerically stabler, and what NLL loss expects.\n",
"        return F.log_softmax(self.fc4(x), dim=1)\n",
"\n",
"\n",
"net = Net(INPUT, OUTPUT)\n",
"net"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "d869q_OfD40b",
"outputId": "f061baff-26c5-4679-c0c9-f36f002342b2"
},
"execution_count": 3,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"Net(\n",
" (fc1): Linear(in_features=784, out_features=64, bias=True)\n",
" (fc2): Linear(in_features=64, out_features=64, bias=True)\n",
" (fc3): Linear(in_features=64, out_features=64, bias=True)\n",
" (fc4): Linear(in_features=64, out_features=10, bias=True)\n",
")"
]
},
"metadata": {},
"execution_count": 3
}
]
},
{
"cell_type": "code",
"source": [
"# A single random fake \"image\", flattened with an explicit batch dimension: shape (1, 784).\n",
"X = torch.rand((28, 28)).reshape(1, 28 * 28)"
],
"metadata": {
"id": "tzlfWM-_IGkW"
},
"execution_count": 4,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Forward pass of the (untrained) network; each row holds 10 log-probabilities (log_softmax output).\n",
"output = net(X)\n",
"output"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "SPAY8H08IOtL",
"outputId": "e8909968-c6e9-4016-98a5-a0a884ab82e9"
},
"execution_count": 5,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"tensor([[-2.2331, -2.3998, -2.2531, -2.2027, -2.3512, -2.3702, -2.3776, -2.2974,\n",
" -2.2209, -2.3437]], grad_fn=<LogSoftmaxBackward0>)"
]
},
"metadata": {},
"execution_count": 5
}
]
},
{
"cell_type": "code",
"source": [
"# Largest single value in `output`, i.e. the highest log-probability (the model's confidence).\n",
"# This is NOT the predicted digit -- use torch.argmax(output, dim=1) for the class index.\n",
"torch.max(output)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "0Uk3DFLgI7wN",
"outputId": "8776cd04-3449-47dd-f826-b340dbdbd622"
},
"execution_count": 6,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"tensor(-2.2027, grad_fn=<MaxBackward1>)"
]
},
"metadata": {},
"execution_count": 6
}
]
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment