Created
August 3, 2018 18:35
-
-
Save demacdolincoln/56a00fd2d9111b433e6c2b791290cfc1 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
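"## Exemplo simples de uso do PyTorch para classificar dígitos manuscritos\n", | |
"\n", | |
"Usa o conjunto `load_digits` do scikit-learn (imagens 8×8 dos dígitos 0–9). Todo o conjunto é usado no treino, portanto a previsão mostrada ao final não mede a capacidade de generalização.\n", | |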
"\n", | |
"**Leituras recomendadas (que também consultei para escrever este script):**\n", | |
"\n", | |
"* https://matheusfacure.github.io/2017/05/15/deep-ff-ann-pytorch/\n", | |
"* http://deeplearningbook.com.br/funcao-de-ativacao/\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import torch\n", | |
"from torch import autograd, nn\n", | |
"from torch.nn import functional as F\n", | |
"\n", | |
"from sklearn import datasets\n", | |
"import numpy as np\n", | |
"\n", | |
"import matplotlib.pyplot as plt" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"class Model(nn.Module):\n", | |
"    \"\"\"Fully connected classifier with two ReLU hidden layers.\n", | |
"\n", | |
"    The output is raw logits (no softmax); nn.CrossEntropyLoss expects\n", | |
"    unnormalised scores.\n", | |
"    \"\"\"\n", | |
"\n", | |
"    def __init__(self, input_size, hidden_size, num_classes):\n", | |
"        super(Model, self).__init__()\n", | |
"        # Layers are created in this order on purpose: parameter\n", | |
"        # initialisation draws from the RNG in creation order.\n", | |
"        self.in_to_h1 = nn.Linear(input_size, hidden_size)\n", | |
"        self.h1_to_h2 = nn.Linear(hidden_size, hidden_size)\n", | |
"        self.h2_to_out = nn.Linear(hidden_size, num_classes)\n", | |
"\n", | |
"    def forward(self, x):\n", | |
"        \"\"\"Return class logits; the last dimension of x must equal input_size.\"\"\"\n", | |
"        hidden = F.relu(self.in_to_h1(x))\n", | |
"        hidden = F.relu(self.h1_to_h2(hidden))\n", | |
"        logits = self.h2_to_out(hidden)\n", | |
"        return logits" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# scikit-learn's handwritten-digits dataset, all ten classes (0-9).\n", | |
"ds = datasets.load_digits(n_class=10)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Convert the scikit-learn arrays straight to tensors. Since PyTorch 0.4,\n", | |
"# tensors carry autograd state themselves, so the legacy\n", | |
"# autograd.Variable wrapper is unnecessary.\n", | |
"x_batch = torch.from_numpy(ds.data.astype(np.float32))\n", | |
"# CrossEntropyLoss expects int64 class indices.\n", | |
"y_batch = torch.from_numpy(ds.target).long()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# The whole dataset is used as a single batch (full-batch training).\n", | |
"batch_size = x_batch.shape[0]\n", | |
"input_size = x_batch.shape[1]\n", | |
"hidden_size = 128\n", | |
"num_classes = len(ds.target_names)\n", | |
"# Keep equal to the lr passed to Adam in the training cell.\n", | |
"learning_rate = 1e-2" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"inpt: tensor([[ 0., 0., 5., ..., 0., 0., 0.],\n", | |
" [ 0., 0., 0., ..., 10., 0., 0.],\n", | |
" [ 0., 0., 0., ..., 16., 9., 0.],\n", | |
" ...,\n", | |
" [ 0., 0., 1., ..., 6., 0., 0.],\n", | |
" [ 0., 0., 2., ..., 12., 0., 0.],\n", | |
" [ 0., 0., 10., ..., 12., 1., 0.]])\n", | |
"target: tensor([[ 0, 1, 2, ..., 8, 9, 8]])\n" | |
] | |
} | |
], | |
"source": [ | |
"# Show the inputs and, on a single row, the target labels.\n", | |
"print('inpt: ', x_batch)\n", | |
"print('target: ', y_batch.unsqueeze(0))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Model(\n", | |
" (in_to_h1): Linear(in_features=64, out_features=128, bias=True)\n", | |
" (h1_to_h2): Linear(in_features=128, out_features=128, bias=True)\n", | |
" (h2_to_out): Linear(in_features=128, out_features=10, bias=True)\n", | |
")\n" | |
] | |
} | |
], | |
"source": [ | |
"# Fix the RNG so the random weight initialisation, and hence every\n", | |
"# result below, is reproducible on Restart & Run All.\n", | |
"torch.manual_seed(0)\n", | |
"model = Model(input_size, hidden_size, num_classes)\n", | |
"print(model)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"epoch: 0 | loss:2.4547500610351562\n", | |
"epoch: 10 | loss:0.1502559632062912\n", | |
"epoch: 20 | loss:0.040915049612522125\n", | |
"epoch: 30 | loss:0.010044598951935768\n", | |
"epoch: 40 | loss:0.0028940257616341114\n", | |
"epoch: 50 | loss:0.0011215369449928403\n", | |
"epoch: 60 | loss:0.000569865689612925\n", | |
"epoch: 70 | loss:0.00037850209628231823\n", | |
"epoch: 80 | loss:0.00028599362121894956\n", | |
"epoch: 90 | loss:0.0002365583786740899\n", | |
"CPU times: user 29 s, sys: 243 ms, total: 29.3 s\n", | |
"Wall time: 28.9 s\n" | |
] | |
} | |
], | |
"source": [ | |
"%%time\n", | |
"\n", | |
"criterion = nn.CrossEntropyLoss()  # cross-entropy cost computed on the raw logits\n", | |
"optimizer = torch.optim.Adam(model.parameters(), lr=0.01)\n", | |
"num_epochs = 100\n", | |
"log_every = 10\n", | |
"\n", | |
"# Full-batch training: each epoch is one optimisation step on the whole dataset.\n", | |
"for epoch in range(num_epochs):\n", | |
"    optimizer.zero_grad()\n", | |
"    logit = model(x_batch)\n", | |
"    loss = criterion(logit, y_batch)\n", | |
"    loss.backward()\n", | |
"    optimizer.step()\n", | |
"    # Also report the last epoch so the final training loss is visible.\n", | |
"    if epoch % log_every == 0 or epoch == num_epochs - 1:\n", | |
"        print(f'epoch: {epoch} | loss:{loss.item()}')\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Inference only: no_grad avoids building an autograd graph over the\n", | |
"# whole dataset. These are predictions on the training data itself.\n", | |
"with torch.no_grad():\n", | |
"    out = model(x_batch)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"tensor de saída: tensor([-36.7698, -6.9950, 31.8018, -7.3230, -18.5295, -4.9313,\n", | |
" -25.2274, -9.0327, -3.7554, 2.5035])\n" | |
] | |
}, | |
{ | |
"data": { | |
"image/png": "\n", | |
"text/plain": [ | |
"<Figure size 1080x360 with 2 Axes>" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"# Any sample index works; note the sample was also seen during training.\n", | |
"index = 268\n", | |
"\n", | |
"print(\"tensor de saída: \", out[index])\n", | |
"\n", | |
"fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))\n", | |
"ax1.matshow(ds.images[index], cmap=plt.cm.gray_r)\n", | |
"# .item() converts the 0-d tensors to plain Python ints for the titles.\n", | |
"ax1.set_title(f\"valor esperado: {y_batch[index].item()}\\n\")\n", | |
"\n", | |
"ax2.plot(out[index].detach().numpy())\n", | |
"ax2.grid(True)\n", | |
"ax2.set_xlabel(\"classe\")\n", | |
"ax2.set_ylabel(\"logit\")\n", | |
"ax2.set_title(f\"valor previsto: {out[index].argmax().item()}\\n\")\n", | |
"plt.setp(ax2, xticks=list(range(num_classes)));" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.6.6" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment