@jbsilva
Created November 8, 2020 21:07
MLP_Classification_MNIST.ipynb
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"accelerator": "GPU",
"colab": {
"name": "MLP_Classification_MNIST.ipynb",
"provenance": [],
"collapsed_sections": [],
"include_colab_link": true
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.6"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/jbsilva/b36081452488f743d311f3fe338c4308/mlp_classification_mnist.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Il7H73VGJ2No"
},
"source": [
"# MNIST Classification\n",
"\n",
"Let's start building a model that can classify handwritten digits with more than 90% accuracy!\n",
"\n",
"As a reminder of the dataset, the image below shows a few samples from the ten classes of the problem (digits 0 to 9) <br>\n",
"![](https://upload.wikimedia.org/wikipedia/commons/thumb/2/27/MnistExamples.png/440px-MnistExamples.png)"
]
},
{
"cell_type": "code",
"metadata": {
"id": "5LaoPWO9RhEq",
"outputId": "619b7fb6-c68f-44bc-9548-a73cb17e91fd",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"source": [
"import torch\n",
"import torchvision\n",
"from torch import nn\n",
"from torch.utils.data import DataLoader\n",
"\n",
"from sklearn import metrics\n",
"import matplotlib.pyplot as plt\n",
"import seaborn as sns\n",
"import numpy as np\n",
"import time\n",
"\n",
"sns.set_style('darkgrid')\n",
"\n",
"device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')\n",
"print(device)"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"cuda\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5yNikfN0J2Np"
},
"source": [
"## Loading the data\n",
"\n",
"To focus on implementing and training the neural network, we will use the automatic dataset loading from **```torchvision```**. PyTorch's data-loading utilities are well worth learning; they are very efficient!\n",
"\n",
"Remind me to publish a tutorial later on how to load your own data ;)"
]
},
{
"cell_type": "code",
"metadata": {
"id": "GWHd8oW15X2F"
},
"source": [
"train_set = torchvision.datasets.MNIST('.', train=True,\n",
"                                       download=True,\n",
"                                       transform=torchvision.transforms.ToTensor())\n",
"\n",
"# Test split: train=False selects the 10,000 held-out images\n",
"test_set = torchvision.datasets.MNIST('.', train=False,\n",
"                                      download=False,\n",
"                                      transform=torchvision.transforms.ToTensor())\n",
"\n",
"train_loader = DataLoader(train_set,\n",
"                          batch_size=32,\n",
"                          shuffle=True)\n",
"\n",
"test_loader = DataLoader(test_set,\n",
"                         batch_size=32,\n",
"                         shuffle=False)\n"
],
"execution_count": null,
"outputs": []
},
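{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick sanity check on what the loaders yield, as a sketch: a synthetic ```TensorDataset``` of zeros stands in for the 10,000-image MNIST test split (an assumption made so that nothing needs to be downloaded). With ```batch_size=32```, 10,000 samples give 312 full batches plus a final batch of 16. The 60,000 training images divide evenly into 1,875 batches. Per-batch results should therefore be joined with ```np.concatenate``` rather than stacked into a 2-D array.\n",
"\n",
"```python\n",
"import torch\n",
"from torch.utils.data import DataLoader, TensorDataset\n",
"\n",
"# Synthetic stand-in for the MNIST test split: 10,000 images of shape (1, 28, 28)\n",
"fake_images = torch.zeros(10000, 1, 28, 28)\n",
"fake_labels = torch.zeros(10000, dtype=torch.long)\n",
"loader = DataLoader(TensorDataset(fake_images, fake_labels), batch_size=32, shuffle=False)\n",
"\n",
"batch_sizes = [images.size(0) for images, _ in loader]\n",
"print(len(batch_sizes), batch_sizes[-1])  # 313 16\n",
"```"
]
},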
{
"cell_type": "code",
"metadata": {
"id": "lVFJWjBR6HC_",
"outputId": "87154fbf-207e-47a7-9567-2902734cde33",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 298
}
},
"source": [
"fig, axs = plt.subplots(1,10, figsize=(14, 3))\n",
"\n",
"k = -1\n",
"for data, label in test_loader:\n",
" k += 1\n",
" if k > 9: break\n",
"\n",
" print(data.size(), label.size())\n",
" axs[k].imshow( data[0][0], cmap='gray' )\n",
" axs[k].set(title = str(label[0].item()), xticks=[], yticks=[] )\n",
"\n",
"plt.show() \n",
" "
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"torch.Size([32, 1, 28, 28]) torch.Size([32])\n",
"torch.Size([32, 1, 28, 28]) torch.Size([32])\n",
"torch.Size([32, 1, 28, 28]) torch.Size([32])\n",
"torch.Size([32, 1, 28, 28]) torch.Size([32])\n",
"torch.Size([32, 1, 28, 28]) torch.Size([32])\n",
"torch.Size([32, 1, 28, 28]) torch.Size([32])\n",
"torch.Size([32, 1, 28, 28]) torch.Size([32])\n",
"torch.Size([32, 1, 28, 28]) torch.Size([32])\n",
"torch.Size([32, 1, 28, 28]) torch.Size([32])\n",
"torch.Size([32, 1, 28, 28]) torch.Size([32])\n"
],
"name": "stdout"
},
{
"output_type": "display_data",
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 1008x216 with 10 Axes>"
]
},
"metadata": {
"tags": []
}
}
]
},
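{
"cell_type": "markdown",
"metadata": {},
"source": [
"The ```torch.Size([32, 1, 28, 28])``` shapes printed above come from ```ToTensor()```. It turns each 28x28 ```uint8``` image into a ```float32``` tensor with a leading channel dimension and scales pixel values from [0, 255] to [0.0, 1.0]. Here is a minimal sketch of that conversion, using a synthetic NumPy array in place of a real MNIST digit:\n",
"\n",
"```python\n",
"import numpy as np\n",
"import torchvision\n",
"\n",
"# Synthetic 28x28 grayscale image (uint8, values 0..255) standing in for one MNIST digit\n",
"img = np.zeros((28, 28), dtype=np.uint8)\n",
"img[14, 14] = 255\n",
"\n",
"t = torchvision.transforms.ToTensor()(img)\n",
"print(t.shape, t.dtype, t.max().item())  # torch.Size([1, 28, 28]) torch.float32 1.0\n",
"```"
]
},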
{
"cell_type": "markdown",
"metadata": {
"id": "EGtnadiqJ2Nu"
},
"source": [
"## Implementing our architecture!\n",
"\n",
"Define the architecture of your neural network. The torch.nn package contains the implementations of all the layers used in this part (nn.Linear):\n",
"\n",
"- https://pytorch.org/docs/stable/nn.html"
]
},
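{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here is a quick look at the building block named above. ```nn.Linear(in_features, out_features)``` stores a weight of shape ```(out_features, in_features)``` and a bias of shape ```(out_features,)```. On a ```(B, in_features)``` input it computes ```x @ W.T + b```. The sketch below uses the sizes of a 784 -> 64 layer and random inputs, purely for illustration:\n",
"\n",
"```python\n",
"import torch\n",
"from torch import nn\n",
"\n",
"torch.manual_seed(0)\n",
"layer = nn.Linear(784, 64)\n",
"print(layer.weight.shape, layer.bias.shape)  # torch.Size([64, 784]) torch.Size([64])\n",
"\n",
"x = torch.randn(32, 784)  # (B, F) input, as expected by the forward pass\n",
"y = layer(x)\n",
"manual = x @ layer.weight.t() + layer.bias  # the affine map nn.Linear computes\n",
"print(y.shape, torch.allclose(y, manual, atol=1e-5))  # torch.Size([32, 64]) True\n",
"```"
]
},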
{
"cell_type": "code",
"metadata": {
"id": "P4oFGFAKJ2Nv",
"outputId": "f456d924-0520-4680-d7c1-6c1e930e7397",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"source": [
"class MinhaRede(nn.Module):\n",
" \n",
" def __init__(self, input_size, hidden_layers, n_classes):\n",
"\n",
" super(MinhaRede, self).__init__()\n",
"\n",
" '''\n",
" Exercício 1.1: Construa a sua rede neural. Ela deve ter entre 2\n",
" e 4 camadas escondidas e ir diminuindo o número de features (hidden size)\n",
" progressivamente ao longo essas camadas. Cada camada intermediária deve\n",
" receber o output da última camada. A primeira camada deve receber\n",
" input_size features e a última camada deve gerar n_classes features que\n",
" farão as predições entre as classes do MNIST. Adicione ativações ReLU\n",
" após todas as camadas nn.Linear, menos na última, que não deve ter\n",
" nenhuma ativação.\n",
" '''\n",
" # ###################################################\n",
" # # Neural Network architecture. 2~4 hidden layers. #\n",
" # ###################################################\n",
" # self.layer1 = # TO DO... Criar primeira camada linear.\n",
" # self.ativ1 = # TO DO... Criar ativação da primeira camada.\n",
" # self.layer2 = # TO DO... Criar segunda camada linear.\n",
" # self.ativ2 = # TO DO... Criar ativação da segunda camada.\n",
"\n",
" # # TO DO... Criar outras camadas/ativações, se achar necessário\n",
"\n",
" # Definir a arquitetura\n",
" self.layer1 = nn.Linear(input_size, hidden_layers[0])\n",
" self.ativ1 = nn.ReLU()\n",
"\n",
" self.layer2 = nn.Linear(hidden_layers[0], hidden_layers[1])\n",
" self.ativ2 = nn.ReLU()\n",
"\n",
" self.layer3 = nn.Linear(hidden_layers[1], hidden_layers[2])\n",
" self.ativ3 = nn.ReLU()\n",
"\n",
" self.output = nn.Linear(hidden_layers[2], n_classes)\n",
"\n",
" # Forward function.\n",
" def forward(self, x):\n",
"\n",
" '''\n",
" Exercício 1.2: Alimente os dados para a camada da sua rede. Para\n",
" alimentar um dado x para uma certa camada self.camada_n:\n",
" saida = self.camada_n(x). \n",
" Lembre-se de verificar a dimensão do dado de entrada, ela deve estar\n",
" na forma (B, F), no nosso caso (32, 784).\n",
" '''\n",
" # ###################################################\n",
" # # Forwarding through all layers. ##################\n",
" # ###################################################\n",
" # out = # TO DO... Passar o input x sequencialmente pelas camadas da rede.\n",
"\n",
" tns_flat = x.view(x.size(0), -1)\n",
"\n",
" layer1 = self.ativ1(self.layer1(tns_flat))\n",
" layer2 = self.ativ2(self.layer2(layer1))\n",
" layer3 = self.ativ3(self.layer3(layer2))\n",
" out = self.output(layer3)\n",
"\n",
" # Returning output.\n",
" return out # TO DO... Lembre-se de sempre retornar a saída da última camada.\n",
" \n",
"# Instancing Network.\n",
"input_size = 784 # Input size (number of features).\n",
"n_classes = 10 # Number of classes.\n",
"hidden_layers = [64, 32, 16]\n",
"model = MinhaRede(input_size, hidden_layers, n_classes).to(device) # GPU casting.\n",
"\n",
"# Printing NN.\n",
"print(model)"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"MinhaRede(\n",
" (layer1): Linear(in_features=784, out_features=64, bias=True)\n",
" (ativ1): ReLU()\n",
" (layer2): Linear(in_features=64, out_features=32, bias=True)\n",
" (ativ2): ReLU()\n",
" (layer3): Linear(in_features=32, out_features=16, bias=True)\n",
" (ativ3): ReLU()\n",
" (output): Linear(in_features=16, out_features=10, bias=True)\n",
")\n"
],
"name": "stdout"
}
]
},
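{
"cell_type": "markdown",
"metadata": {},
"source": [
"The exercise allows 2 to 4 hidden layers, while ```MinhaRede``` hardcodes three. The sketch below shows one way to make the depth follow the length of ```hidden_layers``` with ```nn.Sequential```. ```build_mlp``` is a hypothetical helper and is not part of the original exercise. With ```[64, 32, 16]``` it has the same layer sizes as the model printed above. It therefore also has the same 53,018 trainable parameters: ```(784*64 + 64) + (64*32 + 32) + (32*16 + 16) + (16*10 + 10)```.\n",
"\n",
"```python\n",
"import torch\n",
"from torch import nn\n",
"\n",
"def build_mlp(input_size, hidden_layers, n_classes):\n",
"    # Hypothetical helper: one Linear + ReLU pair per hidden size, then a plain Linear output\n",
"    layers = [nn.Flatten()]  # (B, 1, 28, 28) -> (B, 784)\n",
"    in_features = input_size\n",
"    for hidden in hidden_layers:\n",
"        layers += [nn.Linear(in_features, hidden), nn.ReLU()]\n",
"        in_features = hidden\n",
"    layers.append(nn.Linear(in_features, n_classes))  # raw logits, no activation\n",
"    return nn.Sequential(*layers)\n",
"\n",
"mlp = build_mlp(784, [64, 32, 16], 10)\n",
"n_params = sum(p.numel() for p in mlp.parameters())\n",
"out = mlp(torch.randn(32, 1, 28, 28))\n",
"print(n_params, out.shape)  # 53018 torch.Size([32, 10])\n",
"```"
]
},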
{
"cell_type": "markdown",
"metadata": {
"id": "Ivyz5OiW-VS4"
},
"source": [
"Check here whether your network produces the expected output shape (B, C)\n",
"- B: `batch_size`\n",
"- C: number of classes"
]
},
{
"cell_type": "code",
"metadata": {
"id": "tK2wGwsG-UOs",
"outputId": "7f367e3c-a128-40ef-cf60-aa611cabe332",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"source": [
"k = -1\n",
"for data, label in test_loader:\n",
" k += 1\n",
" if k > 2: break\n",
"\n",
" data = data.to(device)\n",
" print(data.size())\n",
" saida = model(data)\n",
" print(saida.size())"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"torch.Size([32, 1, 28, 28])\n",
"torch.Size([32, 10])\n",
"torch.Size([32, 1, 28, 28])\n",
"torch.Size([32, 10])\n",
"torch.Size([32, 1, 28, 28])\n",
"torch.Size([32, 10])\n"
],
"name": "stdout"
}
]
},
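{
"cell_type": "markdown",
"metadata": {},
"source": [
"The ```(B, C)``` output holds one raw score (logit) per class. The sketch below shows how those scores become predictions, using random logits as a stand-in for ```model(data)```. ```torch.max(..., dim=1)``` returns the maximum values and their indices. The indices are the predicted classes and match ```argmax```. ```softmax``` is only needed when probabilities are wanted.\n",
"\n",
"```python\n",
"import torch\n",
"\n",
"torch.manual_seed(0)\n",
"saida = torch.randn(32, 10)  # stand-in for model(data): (B, C) logits\n",
"\n",
"values, pred = torch.max(saida, dim=1)  # max logit per row and its index\n",
"print(pred.shape, torch.equal(pred, saida.argmax(dim=1)))  # torch.Size([32]) True\n",
"\n",
"probs = torch.softmax(saida, dim=1)  # optional: logits -> probabilities\n",
"print(torch.allclose(probs.sum(dim=1), torch.ones(32)))  # True\n",
"```"
]
},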
{
"cell_type": "markdown",
"metadata": {
"id": "MlBTGdzPJ2N1"
},
"source": [
"## Defining a quality criterion (loss)\n",
"\n",
"The first step is to instantiate the loss function of your choice. This is a classification problem with 10 classes, so Cross Entropy is the recommended function. In PyTorch it is called ```nn.CrossEntropyLoss()```: https://pytorch.org/docs/stable/nn.html#crossentropyloss \n",
"\n",
"**Like the network, the inputs and the labels, the loss function can also be moved to the GPU.** For a plain ```nn.CrossEntropyLoss()``` this is a no-op, since it holds no tensors. It does matter when the loss stores tensors such as class weights (```weight=```)."
]
},
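{
"cell_type": "markdown",
"metadata": {},
"source": [
"```nn.CrossEntropyLoss``` expects raw logits of shape ```(B, C)``` and integer class indices of shape ```(B,)```. Internally it combines ```log_softmax``` with the negative log-likelihood, which is why the network's last layer should have no activation. Here is a minimal check of that equivalence, using random logits and made-up targets for illustration only:\n",
"\n",
"```python\n",
"import torch\n",
"import torch.nn.functional as F\n",
"from torch import nn\n",
"\n",
"torch.manual_seed(0)\n",
"logits = torch.randn(4, 10)  # (B, C) raw scores, like the network output\n",
"targets = torch.tensor([3, 0, 9, 1])  # (B,) class indices, like the MNIST labels\n",
"\n",
"ce = nn.CrossEntropyLoss()(logits, targets)\n",
"via_nll = F.nll_loss(F.log_softmax(logits, dim=1), targets)\n",
"# Explicit form: mean over the batch of -log(softmax probability of the true class)\n",
"explicit = -F.log_softmax(logits, dim=1)[torch.arange(4), targets].mean()\n",
"print(torch.allclose(ce, via_nll), torch.allclose(ce, explicit))  # True True\n",
"```"
]
},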
{
"cell_type": "code",
"metadata": {
"id": "H1WirBm_J2N2"
},
"source": [
"# Setting classification loss.\n",
"'''\n",
"Exercise 3: Define a loss function. Remember to move the loss to the GPU,\n",
"just as we did with the network.\n",
"'''\n",
"criterion = nn.CrossEntropyLoss().to(device)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "VOVXXfCiJ2Ny"
},
"source": [
"## Instantiating the optimizer\n",
"\n",
"> Package ```torch.optim```\n",
"\n",
"Let's get to work! We will now optimize our network using the most traditional algorithms in the field. The ```torch.optim``` library will be very useful here, since it implements the main optimization algorithms for neural networks.\n",
"\n",
"The first step is to instantiate the optimizer. With the ```optim``` package, you call the chosen optimizer and pass as arguments:\n",
"* The parameters of the network to be optimized (```params = model.parameters()```)\n",
"* The learning rate (```lr```)\n",
"* The regularization (```weight_decay```)\n",
"\n",
"Depending on the optimizer, other arguments may be needed, but these are the ones you will set almost every time. Only ```params``` is strictly required; in Adam, for example, ```lr``` defaults to 0.001 and ```weight_decay``` to 0.\n",
"\n",
"PyTorch offers many optimizers. They range from the simplest, such as SGD, to more modern ones with an adaptive learning rate for each network parameter (e.g. Adam, Adagrad, RMSprop...). All optimizers live in the torch.optim package.\n",
"\n",
"For more information about the package, visit: <https://pytorch.org/docs/stable/optim.html>."
]
},
{
"cell_type": "code",
"metadata": {
"id": "fhyopppRJ2Nz"
},
"source": [
"import torch.optim as optim\n",
"\n",
"lr = 0.0001  # TO DO: change the learning rate if you like\n",
"regularizer = 0.00005  # L2 regularization via weight decay.\n",
"\n",
"'''\n",
"Exercise 2: Define an optimizer. Use a more powerful optimizer\n",
"such as Adam. If you like, try different learning rates for the\n",
"optimizer and observe how they affect the optimization of the loss.\n",
"'''\n",
"## TO DO: define an optimizer\n",
"optimizer = optim.Adam(params=model.parameters(), lr=lr, weight_decay=regularizer)"
],
"execution_count": null,
"outputs": []
},
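{
"cell_type": "markdown",
"metadata": {},
"source": [
"The effect of ```weight_decay``` is easiest to see in a single step of plain SGD. SGD is used here instead of Adam only because its update rule is short, which is an assumption made for illustration. PyTorch adds ```weight_decay * p``` to the gradient, which is the gradient of the L2 penalty ```(weight_decay / 2) * p**2```. It then applies ```p <- p - lr * (grad + weight_decay * p)```. ```torch.optim.Adam``` adds the same term to the gradient before computing its adaptive moments; the decoupled variant is ```torch.optim.AdamW```. The toy loss and numbers below are made up:\n",
"\n",
"```python\n",
"import torch\n",
"from torch import optim\n",
"\n",
"# One scalar parameter and a toy loss L(w) = (w - 3)^2, so dL/dw = 2 * (w - 3)\n",
"w = torch.tensor([1.0], requires_grad=True)\n",
"opt = optim.SGD([w], lr=0.1, weight_decay=0.5)\n",
"\n",
"opt.zero_grad()\n",
"loss = ((w - 3) ** 2).sum()\n",
"loss.backward()  # w.grad = 2 * (1 - 3) = -4\n",
"opt.step()  # w <- w - lr * (grad + weight_decay * w) = 1 - 0.1 * (-4 + 0.5)\n",
"\n",
"print(round(w.item(), 4))  # 1.35\n",
"```"
]
},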
{
"cell_type": "markdown",
"metadata": {
"id": "d9uC1yHfJ2N4"
},
"source": [
"# Training & Validation Flow\n",
"\n",
"## Training\n",
"\n",
"Step-by-step training flow (one epoch):\n",
"* Iterate over the batches\n",
"* Move the data to the hardware device\n",
"* Run the forward pass through the network and compute the loss\n",
"* Zero the gradients through the optimizer\n",
"* Compute the gradient of the loss\n",
"* Update the model weights with the optimizer\n",
"\n",
"This set of steps drives the iterative optimization of a network. **Validation**, on the other hand, only applies the network to data it has never seen, to estimate how well the model will perform in the real world.\n"
]
},
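{
"cell_type": "markdown",
"metadata": {},
"source": [
"The steps above map one-to-one onto a few lines of code. Below is a minimal sketch on a single synthetic batch, with random inputs and labels standing in for MNIST and a single ```nn.Linear``` standing in for ```MinhaRede``` (both assumptions for illustration). The steps are repeated a few times so that the loss on that fixed batch visibly drops:\n",
"\n",
"```python\n",
"import torch\n",
"from torch import nn, optim\n",
"\n",
"torch.manual_seed(0)\n",
"device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')\n",
"\n",
"# Synthetic stand-in for one batch: 32 flattened 28x28 images and labels in 0..9\n",
"x = torch.randn(32, 784)\n",
"y = torch.randint(0, 10, (32,))\n",
"\n",
"net = nn.Linear(784, 10).to(device)\n",
"criterion = nn.CrossEntropyLoss()\n",
"optimizer = optim.Adam(net.parameters(), lr=1e-3)\n",
"\n",
"losses = []\n",
"for step in range(20):\n",
"    xb, yb = x.to(device), y.to(device)  # move the data to the device\n",
"    output = net(xb)  # forward pass\n",
"    loss = criterion(output, yb)  # loss\n",
"    optimizer.zero_grad()  # clear gradients left from the previous step\n",
"    loss.backward()  # compute gradients\n",
"    optimizer.step()  # update the weights\n",
"    losses.append(loss.item())\n",
"\n",
"print(losses[0] > losses[-1])  # True: the loss on this fixed batch goes down\n",
"```"
]
},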
{
"cell_type": "code",
"metadata": {
"id": "04-4qarbJ2N5"
},
"source": [
"'''\n",
"Exercise 4: Complete the code of the following function so that it implements\n",
"a training flow (description above).\n",
"'''\n",
"\n",
"def train(train_loader, net, epoch):\n",
"\n",
"    # Training mode\n",
"    net.train()\n",
"\n",
"    start = time.time()\n",
"\n",
"    epoch_loss = []\n",
"\n",
"    # For computing the accuracy\n",
"    pred_list, label_list = [], []\n",
"\n",
"    ## Iterating over the batches\n",
"    for batch in train_loader:\n",
"\n",
"        dado, rotulo = batch\n",
"\n",
"        ## TO DO: Move the data to the GPU (variable device)\n",
"        dado = dado.to(device)\n",
"        rotulo = rotulo.to(device)\n",
"\n",
"        ## TO DO: run the forward pass of the data through the network\n",
"        output = net(dado)\n",
"\n",
"        ## TO DO: compute the loss and append it to the epoch_loss list\n",
"        loss = criterion(output, rotulo)\n",
"        epoch_loss.append(loss.item())\n",
"\n",
"        ## TO DO: zero the gradients\n",
"        optimizer.zero_grad()\n",
"\n",
"        ## TO DO: compute the gradient of the loss\n",
"        loss.backward()\n",
"\n",
"        ## TO DO: take an optimization step\n",
"        optimizer.step()\n",
"\n",
"        '''\n",
"        Exercise 6: Once everything is done and you see that your loss is\n",
"        decreasing during training, uncomment all the following lines,\n",
"        replacing ypred with your model's output variable, and track\n",
"        the accuracy of your proposed neural network.\n",
"        Do the same in the validate() function.\n",
"        '''\n",
"        _, pred = torch.max(output, dim=1)\n",
"        pred_list.append(pred.cpu().numpy())\n",
"        label_list.append(rotulo.cpu().numpy())\n",
"\n",
"    epoch_loss = np.asarray(epoch_loss)\n",
"\n",
"    # np.concatenate also handles a smaller final batch\n",
"    acc = metrics.accuracy_score(np.concatenate(label_list),\n",
"                                 np.concatenate(pred_list))\n",
"\n",
"    end = time.time()\n",
"    print('#################### Train ####################')\n",
"    print('Epoch %d, Loss: %.4f +/- %.4f, Time: %.2f' % (epoch, epoch_loss.mean(), epoch_loss.std(), end-start))\n",
"    print('------- Acc: %.2f' % (acc*100))\n",
"\n",
"    return epoch_loss.mean()"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "fg4O4lArAMfR"
},
"source": [
"## Validation\n",
"\n",
"Step-by-step validation flow (one epoch):\n",
"* Iterate over the batches\n",
"* Move the data to the hardware device\n",
"* Run the forward pass through the network and compute the loss\n",
"* ~~Zero the gradients through the optimizer~~\n",
"* ~~Compute the gradient of the loss~~\n",
"* ~~Update the model weights with the optimizer~~\n",
"\n",
"For this stage, PyTorch offers two tools:\n",
"* ```model.eval()```: Affects the network's *forward* pass. It tells the layers whose behavior differs between training and evaluation to switch modes (e.g. dropout, batch normalization).\n",
"* ```with torch.no_grad()```: A context manager that disables gradient computation and storage, saving time and memory. All validation code should run inside this context.\n",
"\n",
"Example validation code\n",
"\n",
"```python\n",
"net.eval()\n",
"with torch.no_grad():\n",
"    for batch in test_loader:\n",
"        # Validation code\n",
"        ...\n",
"```\n",
"\n",
"The counterpart of ```model.eval()``` is ```model.train()```, which makes explicit that your network should be in training mode. Training mode is the default for models, but it is good practice to set it explicitly, especially since ```model.eval()``` stays in effect until ```model.train()``` is called again."
]
},
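{
"cell_type": "markdown",
"metadata": {},
"source": [
"The notebook's network has no dropout, so ```model.eval()``` does not change its outputs. The effect is easier to see on a standalone ```nn.Dropout``` layer, used here purely for illustration. The sketch also shows that tensors computed inside ```torch.no_grad()``` do not track gradients:\n",
"\n",
"```python\n",
"import torch\n",
"from torch import nn\n",
"\n",
"torch.manual_seed(0)\n",
"dropout = nn.Dropout(p=0.5)\n",
"x = torch.ones(1, 8)\n",
"\n",
"dropout.train()\n",
"train_out = dropout(x)  # random entries zeroed, survivors scaled by 1 / (1 - p) = 2\n",
"dropout.eval()\n",
"eval_out = dropout(x)  # identity in evaluation mode\n",
"print(train_out, eval_out)\n",
"\n",
"w = torch.ones(3, requires_grad=True)\n",
"with torch.no_grad():\n",
"    y = (w * 2).sum()\n",
"print(y.requires_grad)  # False: no graph is recorded inside no_grad()\n",
"```"
]
},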
{
"cell_type": "code",
"metadata": {
"id": "g6O62-6cAOP_"
},
"source": [
"'''\n",
"Exercise 5: Complete the code of the following function so that it implements\n",
"a validation flow (description above).\n",
"'''\n",
"\n",
"def validate(test_loader, net, epoch):\n",
"\n",
"    # Evaluation mode\n",
"    net.eval()\n",
"\n",
"    start = time.time()\n",
"\n",
"    epoch_loss = []\n",
"\n",
"    # For computing the accuracy\n",
"    pred_list, label_list = [], []\n",
"\n",
"    with torch.no_grad():\n",
"        ## Iterating over the batches\n",
"        for batch in test_loader:\n",
"\n",
"            dado, rotulo = batch\n",
"\n",
"            ## TO DO: Move the data to the GPU (variable device)\n",
"            dado = dado.to(device)\n",
"            rotulo = rotulo.to(device)\n",
"\n",
"            ## TO DO: run the forward pass of the data through the network\n",
"            output = net(dado)\n",
"\n",
"            ## TO DO: compute the loss and append it to the epoch_loss list\n",
"            loss = criterion(output, rotulo)\n",
"            epoch_loss.append(loss.item())\n",
"\n",
"            # Predicted class = index of the largest logit\n",
"            _, pred = torch.max(output, dim=1)\n",
"            pred_list.append(pred.cpu().numpy())\n",
"            label_list.append(rotulo.cpu().numpy())\n",
"\n",
"    epoch_loss = np.asarray(epoch_loss)\n",
"\n",
"    # np.concatenate handles a smaller final batch (e.g. 10,000 test images / 32 leaves 16)\n",
"    acc = metrics.accuracy_score(np.concatenate(label_list),\n",
"                                 np.concatenate(pred_list))\n",
"\n",
"    end = time.time()\n",
"    print('********** Validate **********')\n",
"    print('Epoch %d, Loss: %.4f +/- %.4f, Time: %.2f' % (epoch, epoch_loss.mean(), epoch_loss.std(), end-start))\n",
"    print('------- Acc: %.2f\\n' % (acc*100))\n",
"\n",
"    return epoch_loss.mean()\n"
],
"execution_count": null,
"outputs": []
},
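{
"cell_type": "markdown",
"metadata": {},
"source": [
"Combining the per-batch arrays before computing the accuracy needs some care. The arrays collected in ```pred_list``` and ```label_list``` do not all have the same length when the dataset size is not a multiple of ```batch_size```; the 10,000 MNIST test images, for example, leave a final batch of 16. ```np.concatenate``` joins them into one flat array. ```np.asarray``` on such a ragged list instead yields an object array or raises an error, depending on the NumPy version. Here is a small sketch with made-up predictions:\n",
"\n",
"```python\n",
"import numpy as np\n",
"from sklearn import metrics\n",
"\n",
"# Made-up per-batch predictions and labels; the last batch is shorter\n",
"pred_list = [np.array([1, 2, 3]), np.array([4, 5, 6]), np.array([7])]\n",
"label_list = [np.array([1, 2, 0]), np.array([4, 5, 6]), np.array([7])]\n",
"\n",
"y_pred = np.concatenate(pred_list)  # shape (7,)\n",
"y_true = np.concatenate(label_list)\n",
"acc = metrics.accuracy_score(y_true, y_pred)\n",
"print(y_pred.shape, round(acc, 4))  # (7,) 0.8571\n",
"```"
]
},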
{
"cell_type": "code",
"metadata": {
"id": "q7uESV9lnUSn",
"outputId": "27ed2c02-b98a-4032-f8ce-057d3847460a",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"source": [
"train_losses, test_losses = [], []\n",
"num_epochs = 20\n",
"for epoch in range(num_epochs):\n",
" \n",
" # Train\n",
" train_losses.append(train(train_loader, model, epoch))\n",
" \n",
" # Validate\n",
" test_losses.append(validate(test_loader, model, epoch))"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"#################### Train ####################\n",
"Epoch 0, Loss: 0.0929 +/- 0.0833, Time: 8.57\n",
"********** Validate **********\n",
"Epoch 0, Loss: 0.0893 +/- 0.0826, Time: 5.02\n",
"------- Acc: 97.41\n",
"\n",
"#################### Train ####################\n",
"Epoch 1, Loss: 0.0896 +/- 0.0805, Time: 8.54\n",
"********** Validate **********\n",
"Epoch 1, Loss: 0.0840 +/- 0.0781, Time: 4.96\n",
"------- Acc: 97.59\n",
"\n",
"#################### Train ####################\n",
"Epoch 2, Loss: 0.0856 +/- 0.0760, Time: 8.52\n",
"********** Validate **********\n",
"Epoch 2, Loss: 0.0800 +/- 0.0759, Time: 5.27\n",
"------- Acc: 97.71\n",
"\n",
"#################### Train ####################\n",
"Epoch 3, Loss: 0.0827 +/- 0.0760, Time: 8.74\n",
"********** Validate **********\n",
"Epoch 3, Loss: 0.0767 +/- 0.0739, Time: 5.10\n",
"------- Acc: 97.86\n",
"\n",
"#################### Train ####################\n",
"Epoch 4, Loss: 0.0794 +/- 0.0739, Time: 8.28\n",
"********** Validate **********\n",
"Epoch 4, Loss: 0.0749 +/- 0.0736, Time: 4.96\n",
"------- Acc: 97.88\n",
"\n",
"#################### Train ####################\n",
"Epoch 5, Loss: 0.0762 +/- 0.0728, Time: 8.41\n",
"********** Validate **********\n",
"Epoch 5, Loss: 0.0732 +/- 0.0722, Time: 4.93\n",
"------- Acc: 97.89\n",
"\n",
"#################### Train ####################\n",
"Epoch 6, Loss: 0.0737 +/- 0.0723, Time: 8.69\n",
"********** Validate **********\n",
"Epoch 6, Loss: 0.0700 +/- 0.0708, Time: 5.06\n",
"------- Acc: 98.00\n",
"\n",
"#################### Train ####################\n",
"Epoch 7, Loss: 0.0711 +/- 0.0715, Time: 8.35\n",
"********** Validate **********\n",
"Epoch 7, Loss: 0.0672 +/- 0.0691, Time: 5.10\n",
"------- Acc: 98.06\n",
"\n",
"#################### Train ####################\n",
"Epoch 8, Loss: 0.0684 +/- 0.0690, Time: 8.29\n",
"********** Validate **********\n",
"Epoch 8, Loss: 0.0640 +/- 0.0663, Time: 4.93\n",
"------- Acc: 98.20\n",
"\n",
"#################### Train ####################\n",
"Epoch 9, Loss: 0.0662 +/- 0.0672, Time: 8.68\n",
"********** Validate **********\n",
"Epoch 9, Loss: 0.0620 +/- 0.0652, Time: 5.21\n",
"------- Acc: 98.27\n",
"\n",
"#################### Train ####################\n",
"Epoch 10, Loss: 0.0640 +/- 0.0649, Time: 8.41\n",
"********** Validate **********\n",
"Epoch 10, Loss: 0.0610 +/- 0.0648, Time: 4.92\n",
"------- Acc: 98.28\n",
"\n",
"#################### Train ####################\n",
"Epoch 11, Loss: 0.0622 +/- 0.0654, Time: 8.44\n",
"********** Validate **********\n",
"Epoch 11, Loss: 0.0564 +/- 0.0618, Time: 4.85\n",
"------- Acc: 98.44\n",
"\n",
"#################### Train ####################\n",
"Epoch 12, Loss: 0.0598 +/- 0.0643, Time: 8.77\n",
"********** Validate **********\n",
"Epoch 12, Loss: 0.0554 +/- 0.0613, Time: 5.04\n",
"------- Acc: 98.49\n",
"\n",
"#################### Train ####################\n",
"Epoch 13, Loss: 0.0582 +/- 0.0625, Time: 8.73\n",
"********** Validate **********\n",
"Epoch 13, Loss: 0.0536 +/- 0.0599, Time: 4.96\n",
"------- Acc: 98.50\n",
"\n",
"#################### Train ####################\n",
"Epoch 14, Loss: 0.0559 +/- 0.0621, Time: 8.35\n",
"********** Validate **********\n",
"Epoch 14, Loss: 0.0512 +/- 0.0588, Time: 5.13\n",
"------- Acc: 98.62\n",
"\n",
"#################### Train ####################\n",
"Epoch 15, Loss: 0.0541 +/- 0.0592, Time: 8.72\n",
"********** Validate **********\n",
"Epoch 15, Loss: 0.0504 +/- 0.0583, Time: 5.20\n",
"------- Acc: 98.62\n",
"\n",
"#################### Train ####################\n",
"Epoch 16, Loss: 0.0529 +/- 0.0578, Time: 9.12\n",
"********** Validate **********\n",
"Epoch 16, Loss: 0.0487 +/- 0.0570, Time: 5.03\n",
"------- Acc: 98.66\n",
"\n",
"#################### Train ####################\n",
"Epoch 17, Loss: 0.0507 +/- 0.0565, Time: 8.45\n",
"********** Validate **********\n",
"Epoch 17, Loss: 0.0487 +/- 0.0570, Time: 5.31\n",
"------- Acc: 98.63\n",
"\n",
"#################### Train ####################\n",
"Epoch 18, Loss: 0.0496 +/- 0.0584, Time: 8.84\n",
"********** Validate **********\n",
"Epoch 18, Loss: 0.0463 +/- 0.0550, Time: 4.86\n",
"------- Acc: 98.70\n",
"\n",
"#################### Train ####################\n",
"Epoch 19, Loss: 0.0478 +/- 0.0553, Time: 8.47\n",
"********** Validate **********\n",
"Epoch 19, Loss: 0.0423 +/- 0.0525, Time: 5.01\n",
"------- Acc: 98.91\n",
"\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "t8iWQH_grcP-",
"outputId": "ee1b90a3-e07f-4ac8-ef5a-d34f409fa68f",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"source": [
"train_losses, test_losses = [], []\n",
"num_epochs = 20\n",
"for epoch in range(num_epochs):\n",
" \n",
" # Train\n",
" train_losses.append(train(train_loader, model, epoch))\n",
" \n",
" # Validate\n",
" test_losses.append(validate(test_loader, model, epoch))"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"#################### Train ####################\n",
"Epoch 0, Loss: 0.9757 +/- 0.6015, Time: 8.81\n",
"********** Validate **********\n",
"Epoch 0, Loss: 0.4648 +/- 0.1691, Time: 5.05\n",
"#################### Train ####################\n",
"Epoch 1, Loss: 0.4082 +/- 0.1502, Time: 8.83\n",
"********** Validate **********\n",
"Epoch 1, Loss: 0.3634 +/- 0.1734, Time: 5.09\n",
"#################### Train ####################\n",
"Epoch 2, Loss: 0.3403 +/- 0.1501, Time: 8.92\n",
"********** Validate **********\n",
"Epoch 2, Loss: 0.3127 +/- 0.1656, Time: 4.94\n",
"#################### Train ####################\n",
"Epoch 3, Loss: 0.2976 +/- 0.1425, Time: 8.86\n",
"********** Validate **********\n",
"Epoch 3, Loss: 0.2771 +/- 0.1542, Time: 5.10\n",
"#################### Train ####################\n",
"Epoch 4, Loss: 0.2658 +/- 0.1345, Time: 8.79\n",
"********** Validate **********\n",
"Epoch 4, Loss: 0.2476 +/- 0.1467, Time: 4.77\n",
"#################### Train ####################\n",
"Epoch 5, Loss: 0.2392 +/- 0.1267, Time: 8.71\n",
"********** Validate **********\n",
"Epoch 5, Loss: 0.2287 +/- 0.1383, Time: 4.81\n",
"#################### Train ####################\n",
"Epoch 6, Loss: 0.2176 +/- 0.1233, Time: 8.59\n",
"********** Validate **********\n",
"Epoch 6, Loss: 0.2037 +/- 0.1300, Time: 4.84\n",
"#################### Train ####################\n",
"Epoch 7, Loss: 0.1985 +/- 0.1140, Time: 8.54\n",
"********** Validate **********\n",
"Epoch 7, Loss: 0.1862 +/- 0.1252, Time: 4.98\n",
"#################### Train ####################\n",
"Epoch 8, Loss: 0.1823 +/- 0.1122, Time: 8.44\n",
"********** Validate **********\n",
"Epoch 8, Loss: 0.1726 +/- 0.1199, Time: 4.89\n",
"#################### Train ####################\n",
"Epoch 9, Loss: 0.1689 +/- 0.1100, Time: 8.54\n",
"********** Validate **********\n",
"Epoch 9, Loss: 0.1597 +/- 0.1125, Time: 4.84\n",
"#################### Train ####################\n",
"Epoch 10, Loss: 0.1573 +/- 0.1045, Time: 8.71\n",
"********** Validate **********\n",
"Epoch 10, Loss: 0.1488 +/- 0.1102, Time: 4.75\n",
"#################### Train ####################\n",
"Epoch 11, Loss: 0.1470 +/- 0.1018, Time: 8.50\n",
"********** Validate **********\n",
"Epoch 11, Loss: 0.1411 +/- 0.1074, Time: 5.09\n",
"#################### Train ####################\n",
"Epoch 12, Loss: 0.1384 +/- 0.0982, Time: 8.73\n",
"********** Validate **********\n",
"Epoch 12, Loss: 0.1313 +/- 0.1019, Time: 4.71\n",
"#################### Train ####################\n",
"Epoch 13, Loss: 0.1305 +/- 0.0950, Time: 8.43\n",
"********** Validate **********\n",
"Epoch 13, Loss: 0.1250 +/- 0.0981, Time: 4.70\n",
"#################### Train ####################\n",
"Epoch 14, Loss: 0.1237 +/- 0.0945, Time: 8.70\n",
"********** Validate **********\n",
"Epoch 14, Loss: 0.1190 +/- 0.0944, Time: 5.21\n",
"#################### Train ####################\n",
"Epoch 15, Loss: 0.1172 +/- 0.0895, Time: 8.35\n",
"********** Validate **********\n",
"Epoch 15, Loss: 0.1106 +/- 0.0923, Time: 4.90\n",
"#################### Train ####################\n",
"Epoch 16, Loss: 0.1115 +/- 0.0887, Time: 8.47\n",
"********** Validate **********\n",
"Epoch 16, Loss: 0.1048 +/- 0.0892, Time: 4.90\n",
"#################### Train ####################\n",
"Epoch 17, Loss: 0.1062 +/- 0.0877, Time: 8.80\n",
"********** Validate **********\n",
"Epoch 17, Loss: 0.0993 +/- 0.0866, Time: 4.78\n",
"#################### Train ####################\n",
"Epoch 18, Loss: 0.1018 +/- 0.0858, Time: 8.30\n",
"********** Validate **********\n",
"Epoch 18, Loss: 0.0953 +/- 0.0840, Time: 4.87\n",
"#################### Train ####################\n",
"Epoch 19, Loss: 0.0972 +/- 0.0842, Time: 8.76\n",
"********** Validate **********\n",
"Epoch 19, Loss: 0.0911 +/- 0.0822, Time: 4.84\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "fI8hW3CrkepT"
},
"source": [
""
],
"execution_count": null,
"outputs": []
}
]
}