Last active
August 23, 2018 13:18
| { | |
| "cells": [ | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "_uuid": "e6e64edb37cf58e6caca8b7d98252e72fb6978f5" | |
| }, | |
| "source": [ | |
| "# Fruits 360 dataset with pytorch/ignite\n", | |
| "\n", | |
| "<img src=\"https://pytorch.org/docs/stable/_static/pytorch-logo-dark.svg\" width=\"120\"> & <img src=\"https://pytorch.org/ignite/_static/ignite-logo-dark.svg\" width=\"80\">\n", | |
| "\n", | |
| "\n", | |
| "In this kernel I would like to present the recently released first version of [*ignite*](https://github.com/pytorch/ignite), a high-level library that helps with training neural networks in PyTorch.\n", | |
| "\n", | |
| "\n", | |
| "## Why use *ignite*?\n", | |
| "\n", | |
| "- ignite helps you write compact but full-featured training loops in a few lines of code\n", | |
| "- you get a training loop with metrics, early-stopping, model checkpointing and other features without the boilerplate\n", | |
| "\n", | |
| "\n", | |
| "## Installation\n", | |
| "\n", | |
| "Just run the following command:\n", | |
| "```bash\n", | |
| "pip install pytorch-ignite\n", | |
| "```\n", | |
| "or with conda\n", | |
| "```bash\n", | |
| "conda install ignite -c pytorch\n", | |
| "```" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "_uuid": "85f35e858fc5884886611c5ecce8e2f3bab8eaac" | |
| }, | |
| "source": [ | |
| "Before we start with *ignite*, let's define the essentials:\n", | |
| "- dataflow :\n", | |
| " - train data loader\n", | |
| " - validation data loader\n", | |
| "- model :\n", | |
| " - let's take a small network SqueezeNet \n", | |
| "- optimizer : \n", | |
| " - let's take SGD\n", | |
| "- loss function :\n", | |
| " - Cross-Entropy" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "The dataset can be downloaded from [Kaggle datasets](https://www.kaggle.com/moltean/fruits)." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 1, | |
| "metadata": { | |
| "_kg_hide-input": true, | |
| "_uuid": "085338e05b35364bdfad806f8554810366b9588b" | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "from pathlib import Path\n", | |
| "\n", | |
| "import numpy as np\n", | |
| "\n", | |
| "import torch\n", | |
| "from torch.utils.data import Dataset, DataLoader\n", | |
| "from torch.utils.data.dataset import Subset\n", | |
| "from torchvision.datasets import ImageFolder\n", | |
| "from torchvision.transforms import Compose, RandomResizedCrop, RandomVerticalFlip, RandomHorizontalFlip\n", | |
| "from torchvision.transforms import ColorJitter, ToTensor, Normalize\n", | |
| "\n", | |
| "\n", | |
| "FRUIT360_PATH = Path(\".\").resolve().parent / \"input\" / \"fruits-360_dataset_2018_05_26\" / \"fruits-360\"\n", | |
| "\n", | |
| "device = \"cuda\"\n", | |
| "if not torch.cuda.is_available():\n", | |
| " device = \"cpu\"\n", | |
| "\n", | |
| "train_transform = Compose([\n", | |
| " RandomHorizontalFlip(), \n", | |
| " RandomResizedCrop(size=32),\n", | |
| " ColorJitter(brightness=0.12),\n", | |
| " ToTensor(),\n", | |
| " Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])\n", | |
| "])\n", | |
| "\n", | |
| "val_transform = Compose([\n", | |
| " RandomResizedCrop(size=32),\n", | |
| " ToTensor(),\n", | |
| " Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])\n", | |
| "])\n", | |
| "\n", | |
| "batch_size = 16\n", | |
| "num_workers = 4\n", | |
| "\n", | |
| "train_dataset = ImageFolder((FRUIT360_PATH /\"Training\").as_posix(), transform=train_transform, target_transform=None)\n", | |
| "val_dataset = ImageFolder((FRUIT360_PATH /\"Validation\").as_posix(), transform=val_transform, target_transform=None)\n", | |
| "\n", | |
| "# For demo purposes and due to slow CPU computation, reduce the size of the training and validation datasets\n", | |
| "warning_msg = \"\"\n", | |
| "if \"cpu\" in device:\n", | |
| "    warning_msg = \"Datasets are reduced as we use CPU\"\n", | |
| "    train_dataset = Subset(train_dataset, np.arange(100))\n", | |
| "    # Take the validation subset from the validation data, not from the training data\n", | |
| "    val_dataset = Subset(val_dataset, np.arange(30))\n", | |
| "\n", | |
| "\n", | |
| "train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, \n", | |
| " drop_last=True, pin_memory=\"cuda\" in device)\n", | |
| "val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, \n", | |
| " drop_last=False, pin_memory=\"cuda\" in device)\n" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 2, | |
| "metadata": { | |
| "_uuid": "c63174435b431d5cda17a8ea49f0f0f5f653e2e0" | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "PyTorch version: 0.4.0 | Device: cuda | \n", | |
| "Train loader: num_batches=1980 | num_samples=31688\n", | |
| "Validation loader: num_batches=667 | num_samples=10657\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "print(\"PyTorch version: {} | Device: {} | {}\".format(torch.__version__, device, warning_msg))\n", | |
| "print(\"Train loader: num_batches={} | num_samples={}\".format(len(train_loader), len(train_loader.sampler)))\n", | |
| "print(\"Validation loader: num_batches={} | num_samples={}\".format(len(val_loader), len(val_loader.sampler)))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 3, | |
| "metadata": { | |
| "_kg_hide-input": true, | |
| "_uuid": "59cad824744cdc334b0074535d053cda6a0bdd9e", | |
| "collapsed": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "import torch.nn as nn\n", | |
| "from torchvision.models.squeezenet import squeezenet1_1\n", | |
| "from torch.optim import SGD" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 4, | |
| "metadata": { | |
| "_uuid": "96b06dffc097196718ec7ea799146b2fed6ef829" | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "model = squeezenet1_1(pretrained=False, num_classes=64)\n", | |
| "model.classifier[-1] = nn.AdaptiveAvgPool2d(1) # Adapt the last average pooling to our data\n", | |
| "model = model.to(device)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 5, | |
| "metadata": { | |
| "_uuid": "d0e3456ae912799f526d2076a48076b3b8b871a2", | |
| "collapsed": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "optimizer = SGD(model.parameters(), lr=0.01, momentum=0.5)\n", | |
| "criterion = nn.CrossEntropyLoss().to(device)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "_uuid": "887968154e1949db79cadc6f2b7de37518e33c4d" | |
| }, | |
| "source": [ | |
| "And let us begin\n", | |
| "\n", | |
| "## Ignite quickstart with Fruits 360 dataset\n", | |
| "\n", | |
| "### Engine\n", | |
| "\n", | |
| "The base of the framework is `ignite.engine.Engine`, an object that loops a given number of times over provided data, executes a processing function and returns a result:\n", | |
| "```python\n", | |
| "while epoch < max_epochs:\n", | |
| "    # run once over the data\n", | |
| "    for batch in data:\n", | |
| "        output = process_function(batch)\n", | |
| "    epoch += 1\n", | |
| "```\n", | |
| "\n", | |
| "So, a model trainer is simply an engine that loops multiple times over the training dataset and updates model parameters. \n", | |
| "Similarly, model evaluation can be done with an engine that runs a single time over the validation dataset and computes metrics." | |
| ] | |
| }, | |
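| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "To make the loop above concrete, here is a minimal, library-free sketch of the idea. `ToyEngine` and `ToyState` are hypothetical names invented for this illustration; they are not ignite classes and leave out events, metrics and error handling:\n", | |
| "```python\n", | |
| "class ToyState:\n", | |
| "    # Bookkeeping shared across the run (illustrative only)\n", | |
| "    def __init__(self):\n", | |
| "        self.epoch = 0\n", | |
| "        self.iteration = 0\n", | |
| "        self.output = None\n", | |
| "\n", | |
| "\n", | |
| "class ToyEngine:\n", | |
| "    # Calls process_function(engine, batch) on every batch, for max_epochs passes\n", | |
| "    def __init__(self, process_function):\n", | |
| "        self.process_function = process_function\n", | |
| "        self.state = ToyState()\n", | |
| "\n", | |
| "    def run(self, data, max_epochs=1):\n", | |
| "        self.state = ToyState()\n", | |
| "        while self.state.epoch < max_epochs:\n", | |
| "            self.state.epoch += 1\n", | |
| "            for batch in data:\n", | |
| "                self.state.iteration += 1\n", | |
| "                self.state.output = self.process_function(self, batch)\n", | |
| "        return self.state\n", | |
| "\n", | |
| "\n", | |
| "# A stand-in for a training step: just sum the batch\n", | |
| "toy = ToyEngine(lambda engine, batch: sum(batch))\n", | |
| "state = toy.run([[1, 2], [3, 4], [5, 6]], max_epochs=2)\n", | |
| "print(state.epoch, state.iteration, state.output)  # 2 6 11\n", | |
| "```\n", | |
| "A real trainer follows the same shape, with the summing lambda replaced by a function that performs one optimization step." | |
| ] | |
| }, | |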
| { | |
| "cell_type": "code", | |
| "execution_count": 6, | |
| "metadata": { | |
| "_uuid": "81769c0904de73b324ffa616fd6e66bdfbaf78be", | |
| "collapsed": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "from ignite.engine import Engine, _prepare_batch, create_supervised_trainer\n", | |
| "\n", | |
| "def model_update(engine, batch):\n", | |
| " model.train()\n", | |
| " optimizer.zero_grad()\n", | |
| " x, y = _prepare_batch(batch, device=device)\n", | |
| " y_pred = model(x)\n", | |
| " loss = criterion(y_pred, y)\n", | |
| " loss.backward()\n", | |
| " optimizer.step()\n", | |
| " return loss.item()\n", | |
| "\n", | |
| "trainer = Engine(model_update)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "_uuid": "26c84d93884195bde229671f9187e39977b059bb" | |
| }, | |
| "source": [ | |
| "And that's it. The trainer is set up, so we can simply call its `run` method and the model will be trained silently. We could also use the helper function `ignite.engine.create_supervised_trainer` to create a trainer without explicitly writing the `model_update` function:\n", | |
| "```python\n", | |
| "from ignite.engine import create_supervised_trainer\n", | |
| "\n", | |
| "trainer = create_supervised_trainer(model, optimizer, criterion, device)\n", | |
| "```\n", | |
| "\n", | |
| "\n", | |
| "> **Note:** the update function must accept two arguments: `engine` and `batch`\n", | |
| "\n", | |
| "\n", | |
| "\n", | |
| "Let's add more interaction with our created trainer:\n", | |
| "- add logging of loss function value every 50 iterations\n", | |
| "- run offline metrics computation on a subset of the training dataset\n", | |
| "- run metrics computation on the validation dataset once epoch is finished\n", | |
| "- checkpoint trained model every epoch\n", | |
| "- save 3 best models\n", | |
| "- add LR scheduling\n", | |
| "- add early stopping" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "_cell_guid": "79c7e3d0-c299-4dcb-8224-4455121ee9b0", | |
| "_uuid": "d629ff2d2480ee46fbb7e2d37f6b5fab8052498a" | |
| }, | |
| "source": [ | |
| "### Events and Handlers\n", | |
| "\n", | |
| "To accomplish the to-do list above, *ignite* provides an event system that lets us interact with the run at each step:\n", | |
| "- *engine is started/completed*\n", | |
| "- *epoch is started/completed*\n", | |
| "- *batch iteration is started/completed*\n", | |
| "\n", | |
| "This lets the user execute custom code as an event handler.\n", | |
| "\n", | |
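| "As a rough mental model (a toy sketch, not ignite's implementation), an event system can be as small as a dictionary that maps event names to lists of handlers. The class `ToyEvents` and the event name `'iteration_completed'` below are made up for this illustration:\n", | |
| "```python\n", | |
| "from collections import defaultdict\n", | |
| "\n", | |
| "\n", | |
| "class ToyEvents:\n", | |
| "    def __init__(self):\n", | |
| "        # event name -> list of (handler, extra args, extra kwargs)\n", | |
| "        self.handlers = defaultdict(list)\n", | |
| "\n", | |
| "    def add_event_handler(self, event, handler, *args, **kwargs):\n", | |
| "        self.handlers[event].append((handler, args, kwargs))\n", | |
| "\n", | |
| "    def on(self, event, *args, **kwargs):\n", | |
| "        # Decorator form: registers the function and returns it unchanged\n", | |
| "        def decorator(handler):\n", | |
| "            self.add_event_handler(event, handler, *args, **kwargs)\n", | |
| "            return handler\n", | |
| "        return decorator\n", | |
| "\n", | |
| "    def fire(self, event):\n", | |
| "        # Call every handler registered for this event, in registration order\n", | |
| "        for handler, args, kwargs in self.handlers[event]:\n", | |
| "            handler(self, *args, **kwargs)\n", | |
| "\n", | |
| "\n", | |
| "events = ToyEvents()\n", | |
| "calls = []\n", | |
| "\n", | |
| "@events.on('iteration_completed', 'via decorator')\n", | |
| "def record(engine, tag):\n", | |
| "    calls.append(tag)\n", | |
| "\n", | |
| "events.add_event_handler('iteration_completed', record, 'via method')\n", | |
| "events.fire('iteration_completed')\n", | |
| "print(calls)  # ['via decorator', 'via method']\n", | |
| "```\n", | |
| "\n", | |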
| "#### Training batch loss logging\n", | |
| "\n", | |
| "We just define a function and add it as a handler to the trainer. There are two ways to add a handler: via `add_event_handler` or via the `on` decorator:" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 7, | |
| "metadata": { | |
| "_uuid": "1b38f21bcca324188747acc2cfba5d8ad6b42c9b", | |
| "collapsed": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "from ignite.engine import Events\n", | |
| "\n", | |
| "log_interval = 50\n", | |
| "\n", | |
| "@trainer.on(Events.ITERATION_COMPLETED)\n", | |
| "def log_training_loss(engine):\n", | |
| "    iteration = (engine.state.iteration - 1) % len(train_loader) + 1\n", | |
| "    if iteration % log_interval == 0:\n", | |
| "        print(\"Epoch[{}] Iteration[{}/{}] Loss: {:.4f}\".format(engine.state.epoch, iteration, len(train_loader), engine.state.output))\n" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "_uuid": "0408630fbed51cb94f74100b170c51d3dc1b0dc2" | |
| }, | |
| "source": [ | |
| "The same can be done with `add_event_handler` like this:\n", | |
| "```python\n", | |
| "trainer.add_event_handler(Events.ITERATION_COMPLETED, log_training_loss)\n", | |
| "```\n", | |
| "\n", | |
| "\n", | |
| "> **Note:** handlers can also receive extra `args` and `kwargs`, so in general a handler can be defined as:\n", | |
| "\n", | |
| "```python\n", | |
| "def custom_handler(engine, *args, **kwargs):\n", | |
| "    pass\n", | |
| "\n", | |
| "trainer.add_event_handler(Events.ITERATION_COMPLETED, custom_handler, *args, **kwargs)\n", | |
| "# or\n", | |
| "@trainer.on(Events.ITERATION_COMPLETED, *args, **kwargs)\n", | |
| "def custom_handler(engine, *args, **kwargs):\n", | |
| "    pass\n", | |
| "```\n", | |
| "\n", | |
| "Let's see what happens if we run the trainer for a single epoch:" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 8, | |
| "metadata": { | |
| "_uuid": "e8e38ea4e67cf0bd1dcca5420eef885358940e9e", | |
| "collapsed": true | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Epoch[1] Iteration[50/1980] Loss: 4.1834\n", | |
| "Epoch[1] Iteration[100/1980] Loss: 4.1233\n", | |
| "Epoch[1] Iteration[150/1980] Loss: 4.1541\n", | |
| "Epoch[1] Iteration[200/1980] Loss: 4.1283\n", | |
| "Epoch[1] Iteration[250/1980] Loss: 4.2098\n", | |
| "Epoch[1] Iteration[300/1980] Loss: 4.0575\n", | |
| "Epoch[1] Iteration[350/1980] Loss: 3.9282\n", | |
| "Epoch[1] Iteration[400/1980] Loss: 3.9719\n", | |
| "Epoch[1] Iteration[450/1980] Loss: 4.1250\n", | |
| "Epoch[1] Iteration[500/1980] Loss: 3.7604\n", | |
| "Epoch[1] Iteration[550/1980] Loss: 4.1007\n", | |
| "Epoch[1] Iteration[600/1980] Loss: 3.8888\n", | |
| "Epoch[1] Iteration[650/1980] Loss: 3.8531\n", | |
| "Epoch[1] Iteration[700/1980] Loss: 3.8516\n", | |
| "Epoch[1] Iteration[750/1980] Loss: 3.7490\n", | |
| "Epoch[1] Iteration[800/1980] Loss: 3.3842\n", | |
| "Epoch[1] Iteration[850/1980] Loss: 3.1763\n", | |
| "Epoch[1] Iteration[900/1980] Loss: 3.9261\n", | |
| "Epoch[1] Iteration[950/1980] Loss: 3.5236\n", | |
| "Epoch[1] Iteration[1000/1980] Loss: 3.3979\n", | |
| "Epoch[1] Iteration[1050/1980] Loss: 3.4752\n", | |
| "Epoch[1] Iteration[1100/1980] Loss: 3.3541\n", | |
| "Epoch[1] Iteration[1150/1980] Loss: 3.4190\n", | |
| "Epoch[1] Iteration[1200/1980] Loss: 3.2390\n", | |
| "Epoch[1] Iteration[1250/1980] Loss: 2.9949\n", | |
| "Epoch[1] Iteration[1300/1980] Loss: 3.0091\n", | |
| "Epoch[1] Iteration[1350/1980] Loss: 3.3059\n", | |
| "Epoch[1] Iteration[1400/1980] Loss: 3.2347\n", | |
| "Epoch[1] Iteration[1450/1980] Loss: 2.8266\n", | |
| "Epoch[1] Iteration[1500/1980] Loss: 3.0389\n", | |
| "Epoch[1] Iteration[1550/1980] Loss: 2.7100\n", | |
| "Epoch[1] Iteration[1600/1980] Loss: 3.0158\n", | |
| "Epoch[1] Iteration[1650/1980] Loss: 2.8006\n", | |
| "Epoch[1] Iteration[1700/1980] Loss: 2.6677\n", | |
| "Epoch[1] Iteration[1750/1980] Loss: 3.6990\n", | |
| "Epoch[1] Iteration[1800/1980] Loss: 2.9626\n", | |
| "Epoch[1] Iteration[1850/1980] Loss: 3.2499\n", | |
| "Epoch[1] Iteration[1900/1980] Loss: 2.7257\n", | |
| "Epoch[1] Iteration[1950/1980] Loss: 2.7282\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "output = trainer.run(train_loader, max_epochs=1)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "_uuid": "87f527b6b726e138b49267a6aeb3ffc8a9f679c2" | |
| }, | |
| "source": [ | |
| "Looks good!\n", | |
| "\n", | |
| "> add logging of loss function value every 50 iterations\n", | |
| "\n", | |
| "Done!\n", | |
| "\n", | |
| "#### Offline training metrics and validation metrics\n", | |
| "\n", | |
| "Now let's add some code to compute metrics: average accuracy, precision and recall over a subset of the training dataset and over the validation dataset. What are *offline* training metrics, and why use them? By offline, I mean that we compute training metrics with a fixed model, as opposed to online metrics, which are computed batchwise while the model keeps changing at every iteration.\n", | |
| "\n", | |
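| "To pin down what these numbers mean, here is a small library-free sketch that computes accuracy and class-averaged precision and recall from integer labels. It assumes that `average=True` means an unweighted mean over classes (macro averaging), and the helper name `macro_metrics` is made up for this illustration:\n", | |
| "```python\n", | |
| "def macro_metrics(y_true, y_pred, num_classes):\n", | |
| "    correct = sum(t == p for t, p in zip(y_true, y_pred))\n", | |
| "    accuracy = correct / len(y_true)\n", | |
| "    precisions, recalls = [], []\n", | |
| "    for c in range(num_classes):\n", | |
| "        tp = sum(t == c and p == c for t, p in zip(y_true, y_pred))\n", | |
| "        n_pred = sum(p == c for p in y_pred)  # predicted as class c\n", | |
| "        n_true = sum(t == c for t in y_true)  # actually class c\n", | |
| "        # Convention chosen here: a class with no predictions or no samples scores 0\n", | |
| "        precisions.append(tp / n_pred if n_pred else 0.0)\n", | |
| "        recalls.append(tp / n_true if n_true else 0.0)\n", | |
| "    return accuracy, sum(precisions) / num_classes, sum(recalls) / num_classes\n", | |
| "\n", | |
| "\n", | |
| "acc, prec, rec = macro_metrics([0, 0, 1, 1, 2, 2], [0, 1, 1, 1, 2, 0], num_classes=3)\n", | |
| "print('{:.4f} {:.4f} {:.4f}'.format(acc, prec, rec))  # 0.6667 0.7222 0.6667\n", | |
| "```\n", | |
| "\n", | |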
| "First, we define the metrics we want to compute:" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 9, | |
| "metadata": { | |
| "_uuid": "5ffbc59fca196242f32e38827b7f548a51a4754e", | |
| "collapsed": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "from ignite.metrics import Loss, CategoricalAccuracy, Precision, Recall\n", | |
| "\n", | |
| "\n", | |
| "metrics = {\n", | |
| " 'avg_loss': Loss(criterion),\n", | |
| " 'avg_accuracy': CategoricalAccuracy(),\n", | |
| " 'avg_precision': Precision(average=True), \n", | |
| " 'avg_recall': Recall(average=True)\n", | |
| "}" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "_uuid": "a903a90ce75b7d537d45cf789569e1b1b278c8b1" | |
| }, | |
| "source": [ | |
| "Next we can define engines using a helper method `ignite.engine.create_supervised_evaluator`:" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 10, | |
| "metadata": { | |
| "_uuid": "161d95906d2bb002365409564e7900f054dee28f", | |
| "collapsed": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "from ignite.engine import create_supervised_evaluator\n", | |
| "\n", | |
| "train_evaluator = create_supervised_evaluator(model, metrics=metrics, device=device)\n", | |
| "val_evaluator = create_supervised_evaluator(model, metrics=metrics, device=device)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "_uuid": "96d190708d82bb588dcde135cd7736a47fdd76aa" | |
| }, | |
| "source": [ | |
| "We also need to define a training subset and its data loader:" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 11, | |
| "metadata": { | |
| "_uuid": "5a2d3f6d2687923585099ea643483c6be57d24bf", | |
| "collapsed": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "import numpy as np\n", | |
| "from torch.utils.data.dataset import Subset\n", | |
| "\n", | |
| "random_indices = np.random.permutation(np.arange(len(train_dataset)))[:len(val_dataset)]\n", | |
| "train_subset = Subset(train_dataset, indices=random_indices)\n", | |
| "\n", | |
| "train_eval_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True, num_workers=num_workers, \n", | |
| " drop_last=True, pin_memory=\"cuda\" in device)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "_uuid": "3e59eec66fc083cbbf3ee601f7b16a3d90630d4e" | |
| }, | |
| "source": [ | |
| "Now let's define when to execute metrics computation and display results" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 12, | |
| "metadata": { | |
| "_uuid": "ea523100429cf28ed88d29c323ca1c6255b454f4", | |
| "collapsed": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "@trainer.on(Events.EPOCH_COMPLETED)\n", | |
| "def compute_and_display_offline_train_metrics(engine):\n", | |
| " epoch = engine.state.epoch\n", | |
| " print(\"Compute train metrics...\")\n", | |
| " metrics = train_evaluator.run(train_eval_loader).metrics\n", | |
| " print(\"Training Results - Epoch: {} Average Loss: {:.4f} | Accuracy: {:.4f} | Precision: {:.4f} | Recall: {:.4f}\"\n", | |
| " .format(engine.state.epoch, metrics['avg_loss'], metrics['avg_accuracy'], metrics['avg_precision'], metrics['avg_recall']))\n", | |
| " \n", | |
| " \n", | |
| "@trainer.on(Events.EPOCH_COMPLETED)\n", | |
| "def compute_and_display_val_metrics(engine):\n", | |
| " epoch = engine.state.epoch\n", | |
| " print(\"Compute validation metrics...\")\n", | |
| " metrics = val_evaluator.run(val_loader).metrics\n", | |
| " print(\"Validation Results - Epoch: {} Average Loss: {:.4f} | Accuracy: {:.4f} | Precision: {:.4f} | Recall: {:.4f}\"\n", | |
| " .format(engine.state.epoch, metrics['avg_loss'], metrics['avg_accuracy'], metrics['avg_precision'], metrics['avg_recall'])) " | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "_uuid": "9dcd01508d6925b11ab77d520955a7c94ef59dc4" | |
| }, | |
| "source": [ | |
| "Let's check it again" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 13, | |
| "metadata": { | |
| "_uuid": "0acaa445bcabc608b8139a1b8a6b132608b69026", | |
| "collapsed": true | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Epoch[1] Iteration[50/1980] Loss: 2.6200\n", | |
| "Epoch[1] Iteration[100/1980] Loss: 2.7961\n", | |
| "Epoch[1] Iteration[150/1980] Loss: 2.6174\n", | |
| "Epoch[1] Iteration[200/1980] Loss: 2.5267\n", | |
| "Epoch[1] Iteration[250/1980] Loss: 2.5079\n", | |
| "Epoch[1] Iteration[300/1980] Loss: 2.6094\n", | |
| "Epoch[1] Iteration[350/1980] Loss: 3.4147\n", | |
| "Epoch[1] Iteration[400/1980] Loss: 2.5729\n", | |
| "Epoch[1] Iteration[450/1980] Loss: 2.3938\n", | |
| "Epoch[1] Iteration[500/1980] Loss: 2.4503\n", | |
| "Epoch[1] Iteration[550/1980] Loss: 1.8824\n", | |
| "Epoch[1] Iteration[600/1980] Loss: 2.6311\n", | |
| "Epoch[1] Iteration[650/1980] Loss: 2.6125\n", | |
| "Epoch[1] Iteration[700/1980] Loss: 2.2534\n", | |
| "Epoch[1] Iteration[750/1980] Loss: 2.7049\n", | |
| "Epoch[1] Iteration[800/1980] Loss: 2.5136\n", | |
| "Epoch[1] Iteration[850/1980] Loss: 2.4184\n", | |
| "Epoch[1] Iteration[900/1980] Loss: 2.4033\n", | |
| "Epoch[1] Iteration[950/1980] Loss: 2.4063\n", | |
| "Epoch[1] Iteration[1000/1980] Loss: 2.5372\n", | |
| "Epoch[1] Iteration[1050/1980] Loss: 1.8269\n", | |
| "Epoch[1] Iteration[1100/1980] Loss: 2.1790\n", | |
| "Epoch[1] Iteration[1150/1980] Loss: 2.0992\n", | |
| "Epoch[1] Iteration[1200/1980] Loss: 2.7634\n", | |
| "Epoch[1] Iteration[1250/1980] Loss: 2.3822\n", | |
| "Epoch[1] Iteration[1300/1980] Loss: 2.0993\n", | |
| "Epoch[1] Iteration[1350/1980] Loss: 2.1972\n", | |
| "Epoch[1] Iteration[1400/1980] Loss: 2.2227\n", | |
| "Epoch[1] Iteration[1450/1980] Loss: 1.9222\n", | |
| "Epoch[1] Iteration[1500/1980] Loss: 2.1760\n", | |
| "Epoch[1] Iteration[1550/1980] Loss: 1.8377\n", | |
| "Epoch[1] Iteration[1600/1980] Loss: 2.5207\n", | |
| "Epoch[1] Iteration[1650/1980] Loss: 2.0970\n", | |
| "Epoch[1] Iteration[1700/1980] Loss: 1.5135\n", | |
| "Epoch[1] Iteration[1750/1980] Loss: 2.0837\n", | |
| "Epoch[1] Iteration[1800/1980] Loss: 2.0120\n", | |
| "Epoch[1] Iteration[1850/1980] Loss: 1.8728\n", | |
| "Epoch[1] Iteration[1900/1980] Loss: 1.5219\n", | |
| "Epoch[1] Iteration[1950/1980] Loss: 1.8386\n", | |
| "Compute train metrics...\n", | |
| "Training Results - Epoch: 1 Average Loss: 1.6411 | Accuracy: 0.4146 | Precision: 0.3932 | Recall: 0.4129\n", | |
| "Compute validation metrics...\n", | |
| "Validation Results - Epoch: 1 Average Loss: 1.6547 | Accuracy: 0.4147 | Precision: 0.4161 | Recall: 0.4185\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "output = trainer.run(train_loader, max_epochs=1)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "_uuid": "6de81a2fd6a267bdb2e0bbdd44ae94a1dcb92742" | |
| }, | |
| "source": [ | |
| "Nice !\n", | |
| "\n", | |
| "> run offline metrics computation on a subset of the training dataset\n", | |
| "\n", | |
| "> run metrics computation on the validation dataset once epoch is finished\n", | |
| "\n", | |
| "Done !" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "_uuid": "624c285ac6e756f3e264cbed2deb4e9b912977e3" | |
| }, | |
| "source": [ | |
| "----\n", | |
| "\n", | |
| "##### More details\n", | |
| "\n", | |
| "\n", | |
| "\n", | |
| "Let's explain some details of the code above. You may have noticed the following line:\n", | |
| "```python\n", | |
| "metrics = train_evaluator.run(train_eval_loader).metrics\n", | |
| "```\n", | |
| "and wondered what object `train_evaluator.run(train_eval_loader)` returns that has a `metrics` attribute.\n", | |
| "\n", | |
| "`Engine` holds a structure called `State` that passes data between handlers. `State` contains information on the current \n", | |
| "epoch, iteration, max epochs, etc., and can also carry custom data such as metrics. Thus, the above code can be rewritten as:\n", | |
| "```python\n", | |
| "state = train_evaluator.run(train_eval_loader)\n", | |
| "metrics = state.metrics\n", | |
| "# or just\n", | |
| "train_evaluator.run(train_eval_loader)\n", | |
| "metrics = train_evaluator.state.metrics\n", | |
| "```\n", | |
| "\n", | |
| "-----" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "_uuid": "4eda5ec09e175404ee59a89706635413ed923570" | |
| }, | |
| "source": [ | |
| "#### Learning rate scheduling\n", | |
| "\n", | |
| "There are several ways to perform learning rate scheduling with *ignite*. Here we use the simplest one: calling `lr_scheduler.step()` at the start of every epoch:" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 14, | |
| "metadata": { | |
| "_uuid": "7217a70aa70774308278d651de4c95f38dbb120e", | |
| "collapsed": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "from torch.optim.lr_scheduler import ExponentialLR\n", | |
| "\n", | |
| "\n", | |
| "lr_scheduler = ExponentialLR(optimizer, gamma=0.8)\n", | |
| "\n", | |
| "\n", | |
| "@trainer.on(Events.EPOCH_STARTED)\n", | |
| "def update_lr_scheduler(engine):\n", | |
| "    lr_scheduler.step()\n", | |
| "    # Display learning rate:\n", | |
| "    if len(optimizer.param_groups) == 1:\n", | |
| "        lr = float(optimizer.param_groups[0]['lr'])\n", | |
| "        print(\"Learning rate: {}\".format(lr))\n", | |
| "    else:\n", | |
| "        for i, param_group in enumerate(optimizer.param_groups):\n", | |
| "            lr = float(param_group['lr'])\n", | |
| "            print(\"Learning rate (group {}): {}\".format(i, lr))" | |
| ] | |
| }, | |
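| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "With `gamma=0.8` the learning rate shrinks by 20% per epoch. Here is a quick library-free check of the values to expect, assuming the rate at epoch `e` (counting from 1) is `base_lr * gamma ** (e - 1)`; `expected_lr` is a helper invented for this illustration:\n", | |
| "```python\n", | |
| "def expected_lr(base_lr, gamma, epoch):\n", | |
| "    # epoch counts from 1, so the first epoch uses the base rate unchanged\n", | |
| "    return base_lr * gamma ** (epoch - 1)\n", | |
| "\n", | |
| "\n", | |
| "for epoch in range(1, 6):\n", | |
| "    print('Epoch {}: lr = {:.6f}'.format(epoch, expected_lr(0.01, 0.8, epoch)))\n", | |
| "```\n", | |
| "Under this assumption the first two epochs use 0.01 and 0.008." | |
| ] | |
| }, | |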
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "_uuid": "57b2433b79e1baa345376828bde594e57f736032" | |
| }, | |
| "source": [ | |
| "#### Training checkpointing\n", | |
| "\n", | |
| "As training progresses, we would like to store the best model, the last trained model, the optimizer and the learning rate scheduler. With *ignite* this is not a problem: the `ModelCheckpoint` class exists for exactly this purpose.\n", | |
| "\n", | |
| "Let's use a `ModelCheckpoint` handler to store the best model as measured by validation accuracy. We define a `score_function` that returns the validation accuracy to the handler, and the handler decides (higher is better) whether or not to save the model." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 19, | |
| "metadata": { | |
| "_uuid": "2a0ee3f0b7d7387d9914c274607760fc28b19655", | |
| "collapsed": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "from ignite.handlers import ModelCheckpoint\n", | |
| "\n", | |
| "\n", | |
| "def score_function(engine):\n", | |
| " val_avg_accuracy = engine.state.metrics['avg_accuracy']\n", | |
| " # Objects with highest scores will be retained.\n", | |
| " return val_avg_accuracy\n", | |
| "\n", | |
| "\n", | |
| "best_model_saver = ModelCheckpoint(\"best_models\", # folder where to save the best model(s)\n", | |
| " filename_prefix=\"model\", # filename prefix -> {filename_prefix}_{name}_{step_number}_{score_name}={abs(score_function_result)}.pth\n", | |
| " score_name=\"val_accuracy\", \n", | |
| " score_function=score_function,\n", | |
| " n_saved=3,\n", | |
| " atomic=True, # objects are saved to a temporary file and then moved to final destination, so that files are guaranteed to not be damaged\n", | |
| " save_as_state_dict=True, # Save object as state_dict\n", | |
| " create_dir=True)\n", | |
| "val_evaluator.add_event_handler(Events.COMPLETED, best_model_saver, {\"best_model\": model})" | |
| ] | |
| }, | |
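| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "The retention rule can be pictured with a small library-free sketch: keep at most `n_saved` entries and evict the lowest score whenever a better one arrives. `ToyBestKeeper` is a hypothetical class written for this illustration and does not touch the file system; it is not how `ModelCheckpoint` is implemented:\n", | |
| "```python\n", | |
| "class ToyBestKeeper:\n", | |
| "    def __init__(self, n_saved):\n", | |
| "        self.n_saved = n_saved\n", | |
| "        self.saved = []  # list of (score, name), kept sorted by ascending score\n", | |
| "\n", | |
| "    def __call__(self, score, name):\n", | |
| "        if len(self.saved) < self.n_saved:\n", | |
| "            self.saved.append((score, name))\n", | |
| "        elif score > self.saved[0][0]:\n", | |
| "            # Better than the worst retained entry: replace it\n", | |
| "            self.saved[0] = (score, name)\n", | |
| "        else:\n", | |
| "            return False\n", | |
| "        self.saved.sort()\n", | |
| "        return True\n", | |
| "\n", | |
| "\n", | |
| "keeper = ToyBestKeeper(n_saved=3)\n", | |
| "for epoch, acc in enumerate([0.41, 0.59, 0.55, 0.72, 0.50], start=1):\n", | |
| "    keeper(acc, 'model_epoch_{}'.format(epoch))\n", | |
| "print([name for _, name in keeper.saved])\n", | |
| "# ['model_epoch_3', 'model_epoch_2', 'model_epoch_4']\n", | |
| "```" | |
| ] | |
| }, | |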
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "_uuid": "d61be4bfcbba65becf9a56033cdc3e55108d963c" | |
| }, | |
| "source": [ | |
| "Now let's define another `ModelCheckpoint` handler to store the trained model, optimizer and LR scheduler every 1000 iterations (100 iterations when on CPU):" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 20, | |
| "metadata": { | |
| "_uuid": "790337775b67543f3df264c104ce8a9a2483cdbf", | |
| "collapsed": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "save_interval = 1000 if \"cuda\" in device else 100\n", | |
| "\n", | |
| "\n", | |
| "training_saver = ModelCheckpoint(\"checkpoint\",\n", | |
| " filename_prefix=\"checkpoint\",\n", | |
| " save_interval=save_interval, # Save every 1000 iterations when on GPU and 100 when on CPU\n", | |
| " n_saved=1,\n", | |
| " atomic=True,\n", | |
| " # save_as_state_dict=True,\n", | |
| " create_dir=True)\n", | |
| "\n", | |
| "trainer.add_event_handler(Events.ITERATION_COMPLETED, \n", | |
| " training_saver, \n", | |
| " {\n", | |
| " \"model\": model,\n", | |
| " \"optimizer\": optimizer,\n", | |
| " \"lr_scheduler\": lr_scheduler\n", | |
| " })" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "_uuid": "dbdcb311713b3e3ad05f936d7a9d9525055b6ae6" | |
| }, | |
| "source": [ | |
| "We are almost done with the preparations. One last thing, the cherry on top:\n", | |
| "\n", | |
| "#### Early stopping\n", | |
| "\n", | |
| "Let's add another handler that stops training if the model fails to improve the score returned by `score_function` for 10 consecutive epochs:" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 21, | |
| "metadata": { | |
| "_uuid": "60c08fbae224364b4b089d8ab34e4860d57c3d7e", | |
| "collapsed": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "from ignite.handlers import EarlyStopping\n", | |
| "\n", | |
| "early_stopping = EarlyStopping(patience=10, score_function=score_function, trainer=trainer)\n", | |
| "\n", | |
| "val_evaluator.add_event_handler(Events.EPOCH_COMPLETED, early_stopping)" | |
| ] | |
| }, | |
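| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "The patience rule itself fits in a few lines. The sketch below is library-free and uses a made-up class `ToyEarlyStopping`; it assumes that only a strictly higher score counts as an improvement, which may differ in detail from the comparison ignite uses:\n", | |
| "```python\n", | |
| "class ToyEarlyStopping:\n", | |
| "    def __init__(self, patience):\n", | |
| "        self.patience = patience\n", | |
| "        self.best_score = None\n", | |
| "        self.counter = 0  # evaluations since the last improvement\n", | |
| "\n", | |
| "    def should_stop(self, score):\n", | |
| "        if self.best_score is None or score > self.best_score:\n", | |
| "            self.best_score = score\n", | |
| "            self.counter = 0\n", | |
| "        else:\n", | |
| "            self.counter += 1\n", | |
| "        return self.counter >= self.patience\n", | |
| "\n", | |
| "\n", | |
| "stopper = ToyEarlyStopping(patience=2)\n", | |
| "scores = [0.40, 0.55, 0.54, 0.55, 0.60]\n", | |
| "stopped_at = None\n", | |
| "for epoch, score in enumerate(scores, start=1):\n", | |
| "    if stopper.should_stop(score):\n", | |
| "        stopped_at = epoch\n", | |
| "        break\n", | |
| "print(stopped_at)  # 4\n", | |
| "```\n", | |
| "Note that the run stops at epoch 4 and never sees the later improvement to 0.60, which is the usual trade-off of a small patience value." | |
| ] | |
| }, | |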
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "_uuid": "cc3acbcc0f3ed150a978067dfa23358646fa5bb0" | |
| }, | |
| "source": [ | |
| "## Run training\n", | |
| "\n", | |
| "Now we can just call the `run` method and train the model for a number of epochs (e.g. 5 epochs, so that this kernel terminates correctly before the timeout):" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 22, | |
| "metadata": { | |
| "_uuid": "ecd0e705a55519e0e4fe700a4e729bdb5ec5da86", | |
| "collapsed": true | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Learning rate: 0.01\n", | |
| "Epoch[1] Iteration[50/1980] Loss: 2.0297\n", | |
| "Epoch[1] Iteration[100/1980] Loss: 2.5002\n", | |
| "Epoch[1] Iteration[150/1980] Loss: 1.4305\n", | |
| "Epoch[1] Iteration[200/1980] Loss: 1.9010\n", | |
| "Epoch[1] Iteration[250/1980] Loss: 1.8674\n", | |
| "Epoch[1] Iteration[300/1980] Loss: 1.7257\n", | |
| "Epoch[1] Iteration[350/1980] Loss: 1.7534\n", | |
| "Epoch[1] Iteration[400/1980] Loss: 1.3426\n", | |
| "Epoch[1] Iteration[450/1980] Loss: 1.3853\n", | |
| "Epoch[1] Iteration[500/1980] Loss: 1.8335\n", | |
| "Epoch[1] Iteration[550/1980] Loss: 1.6262\n", | |
| "Epoch[1] Iteration[600/1980] Loss: 2.3171\n", | |
| "Epoch[1] Iteration[650/1980] Loss: 2.4359\n", | |
| "Epoch[1] Iteration[700/1980] Loss: 1.8466\n", | |
| "Epoch[1] Iteration[750/1980] Loss: 1.9876\n", | |
| "Epoch[1] Iteration[800/1980] Loss: 1.6165\n", | |
| "Epoch[1] Iteration[850/1980] Loss: 1.7135\n", | |
| "Epoch[1] Iteration[900/1980] Loss: 1.6624\n", | |
| "Epoch[1] Iteration[950/1980] Loss: 1.9795\n", | |
| "Epoch[1] Iteration[1000/1980] Loss: 0.9423\n", | |
| "Epoch[1] Iteration[1050/1980] Loss: 2.3166\n", | |
| "Epoch[1] Iteration[1100/1980] Loss: 1.6194\n", | |
| "Epoch[1] Iteration[1150/1980] Loss: 2.1837\n", | |
| "Epoch[1] Iteration[1200/1980] Loss: 1.6516\n", | |
| "Epoch[1] Iteration[1250/1980] Loss: 2.0080\n", | |
| "Epoch[1] Iteration[1300/1980] Loss: 1.5937\n", | |
| "Epoch[1] Iteration[1350/1980] Loss: 1.3191\n", | |
| "Epoch[1] Iteration[1400/1980] Loss: 1.8605\n", | |
| "Epoch[1] Iteration[1450/1980] Loss: 1.5833\n", | |
| "Epoch[1] Iteration[1500/1980] Loss: 1.1223\n", | |
| "Epoch[1] Iteration[1550/1980] Loss: 1.2670\n", | |
| "Epoch[1] Iteration[1600/1980] Loss: 2.1860\n", | |
| "Epoch[1] Iteration[1650/1980] Loss: 1.5690\n", | |
| "Epoch[1] Iteration[1700/1980] Loss: 1.0363\n", | |
| "Epoch[1] Iteration[1750/1980] Loss: 1.7205\n", | |
| "Epoch[1] Iteration[1800/1980] Loss: 1.1404\n", | |
| "Epoch[1] Iteration[1850/1980] Loss: 1.1508\n", | |
| "Epoch[1] Iteration[1900/1980] Loss: 1.3621\n", | |
| "Epoch[1] Iteration[1950/1980] Loss: 1.7613\n", | |
| "Compute train metrics...\n", | |
| "Training Results - Epoch: 1 Average Loss: 1.1558 | Accuracy: 0.5918 | Precision: 0.5966 | Recall: 0.5990\n", | |
| "Compute validation metrics...\n", | |
| "Validation Results - Epoch: 1 Average Loss: 1.1876 | Accuracy: 0.5929 | Precision: 0.5788 | Recall: 0.5949\n", | |
| "Learning rate: 0.008\n", | |
| "Epoch[2] Iteration[50/1980] Loss: 1.2283\n", | |
| "Epoch[2] Iteration[100/1980] Loss: 1.0316\n", | |
| "Epoch[2] Iteration[150/1980] Loss: 0.4446\n", | |
| "Epoch[2] Iteration[200/1980] Loss: 0.8731\n", | |
| "Epoch[2] Iteration[250/1980] Loss: 1.0064\n", | |
| "Epoch[2] Iteration[300/1980] Loss: 1.4528\n", | |
| "Epoch[2] Iteration[350/1980] Loss: 1.2250\n", | |
| "Epoch[2] Iteration[400/1980] Loss: 1.0561\n", | |
| "Epoch[2] Iteration[450/1980] Loss: 1.2613\n", | |
| "Epoch[2] Iteration[500/1980] Loss: 0.9631\n", | |
| "Epoch[2] Iteration[550/1980] Loss: 1.6906\n", | |
| "Epoch[2] Iteration[600/1980] Loss: 1.5475\n", | |
| "Epoch[2] Iteration[650/1980] Loss: 0.6657\n", | |
| "Epoch[2] Iteration[700/1980] Loss: 1.0129\n", | |
| "Epoch[2] Iteration[750/1980] Loss: 0.8725\n", | |
| "Epoch[2] Iteration[800/1980] Loss: 0.8801\n", | |
| "Epoch[2] Iteration[850/1980] Loss: 1.0811\n", | |
| "Epoch[2] Iteration[900/1980] Loss: 1.5743\n", | |
| "Epoch[2] Iteration[950/1980] Loss: 0.8041\n", | |
| "Epoch[2] Iteration[1000/1980] Loss: 1.2757\n", | |
| "Epoch[2] Iteration[1050/1980] Loss: 1.3404\n", | |
| "Epoch[2] Iteration[1100/1980] Loss: 2.0413\n", | |
| "Epoch[2] Iteration[1150/1980] Loss: 1.1338\n", | |
| "Epoch[2] Iteration[1200/1980] Loss: 1.0582\n", | |
| "Epoch[2] Iteration[1250/1980] Loss: 0.6886\n", | |
| "Epoch[2] Iteration[1300/1980] Loss: 1.2802\n", | |
| "Epoch[2] Iteration[1350/1980] Loss: 0.8355\n", | |
| "Epoch[2] Iteration[1400/1980] Loss: 1.1353\n", | |
| "Epoch[2] Iteration[1450/1980] Loss: 0.9479\n", | |
| "Epoch[2] Iteration[1500/1980] Loss: 1.0295\n", | |
| "Epoch[2] Iteration[1550/1980] Loss: 0.7993\n", | |
| "Epoch[2] Iteration[1600/1980] Loss: 0.9488\n", | |
| "Epoch[2] Iteration[1650/1980] Loss: 0.9145\n", | |
| "Epoch[2] Iteration[1700/1980] Loss: 1.1164\n", | |
| "Epoch[2] Iteration[1750/1980] Loss: 0.8432\n", | |
| "Epoch[2] Iteration[1800/1980] Loss: 1.1087\n", | |
| "Epoch[2] Iteration[1850/1980] Loss: 0.2490\n", | |
| "Epoch[2] Iteration[1900/1980] Loss: 0.5431\n", | |
| "Epoch[2] Iteration[1950/1980] Loss: 1.6354\n", | |
| "Compute train metrics...\n", | |
| "Training Results - Epoch: 2 Average Loss: 0.9183 | Accuracy: 0.6916 | Precision: 0.7316 | Recall: 0.6955\n", | |
| "Compute validation metrics...\n", | |
| "Validation Results - Epoch: 2 Average Loss: 0.9143 | Accuracy: 0.6892 | Precision: 0.7399 | Recall: 0.6934\n", | |
| "Learning rate: 0.006400000000000001\n", | |
| "Epoch[3] Iteration[50/1980] Loss: 0.5084\n", | |
| "Epoch[3] Iteration[100/1980] Loss: 0.5710\n", | |
| "Epoch[3] Iteration[150/1980] Loss: 0.6886\n", | |
| "Epoch[3] Iteration[200/1980] Loss: 0.8733\n", | |
| "Epoch[3] Iteration[250/1980] Loss: 0.5331\n", | |
| "Epoch[3] Iteration[300/1980] Loss: 1.4217\n", | |
| "Epoch[3] Iteration[350/1980] Loss: 1.3065\n", | |
| "Epoch[3] Iteration[400/1980] Loss: 0.5387\n", | |
| "Epoch[3] Iteration[450/1980] Loss: 0.6784\n", | |
| "Epoch[3] Iteration[500/1980] Loss: 0.5305\n", | |
| "Epoch[3] Iteration[550/1980] Loss: 0.3985\n", | |
| "Epoch[3] Iteration[600/1980] Loss: 0.7405\n", | |
| "Epoch[3] Iteration[650/1980] Loss: 1.5790\n", | |
| "Epoch[3] Iteration[700/1980] Loss: 0.4117\n", | |
| "Epoch[3] Iteration[750/1980] Loss: 1.2748\n", | |
| "Epoch[3] Iteration[800/1980] Loss: 0.2740\n", | |
| "Epoch[3] Iteration[850/1980] Loss: 0.5408\n", | |
| "Epoch[3] Iteration[900/1980] Loss: 2.1083\n", | |
| "Epoch[3] Iteration[950/1980] Loss: 1.2276\n", | |
| "Epoch[3] Iteration[1000/1980] Loss: 0.8733\n", | |
| "Epoch[3] Iteration[1050/1980] Loss: 0.4601\n", | |
| "Epoch[3] Iteration[1100/1980] Loss: 0.7347\n", | |
| "Epoch[3] Iteration[1150/1980] Loss: 0.8986\n", | |
| "Epoch[3] Iteration[1200/1980] Loss: 1.3997\n", | |
| "Epoch[3] Iteration[1250/1980] Loss: 0.5788\n", | |
| "Epoch[3] Iteration[1300/1980] Loss: 0.4942\n", | |
| "Epoch[3] Iteration[1350/1980] Loss: 0.4114\n", | |
| "Epoch[3] Iteration[1400/1980] Loss: 1.3255\n", | |
| "Epoch[3] Iteration[1450/1980] Loss: 1.1549\n", | |
| "Epoch[3] Iteration[1500/1980] Loss: 1.0794\n", | |
| "Epoch[3] Iteration[1550/1980] Loss: 1.0410\n", | |
| "Epoch[3] Iteration[1600/1980] Loss: 0.6442\n", | |
| "Epoch[3] Iteration[1650/1980] Loss: 1.2836\n", | |
| "Epoch[3] Iteration[1700/1980] Loss: 0.8266\n", | |
| "Epoch[3] Iteration[1750/1980] Loss: 0.7842\n", | |
| "Epoch[3] Iteration[1800/1980] Loss: 1.0421\n", | |
| "Epoch[3] Iteration[1850/1980] Loss: 0.7566\n", | |
| "Epoch[3] Iteration[1900/1980] Loss: 1.8921\n", | |
| "Epoch[3] Iteration[1950/1980] Loss: 0.4822\n", | |
| "Compute train metrics...\n", | |
| "Training Results - Epoch: 3 Average Loss: 0.7132 | Accuracy: 0.7624 | Precision: 0.7997 | Recall: 0.7678\n", | |
| "Compute validation metrics...\n", | |
| "Validation Results - Epoch: 3 Average Loss: 0.7172 | Accuracy: 0.7689 | Precision: 0.8002 | Recall: 0.7717\n", | |
| "Learning rate: 0.005120000000000001\n", | |
| "Epoch[4] Iteration[50/1980] Loss: 0.3512\n", | |
| "Epoch[4] Iteration[100/1980] Loss: 0.7987\n", | |
| "Epoch[4] Iteration[150/1980] Loss: 0.2809\n", | |
| "Epoch[4] Iteration[200/1980] Loss: 0.4992\n", | |
| "Epoch[4] Iteration[250/1980] Loss: 0.8425\n", | |
| "Epoch[4] Iteration[300/1980] Loss: 0.8953\n", | |
| "Epoch[4] Iteration[350/1980] Loss: 0.7096\n", | |
| "Epoch[4] Iteration[400/1980] Loss: 0.4830\n", | |
| "Epoch[4] Iteration[450/1980] Loss: 0.6284\n", | |
| "Epoch[4] Iteration[500/1980] Loss: 0.3292\n", | |
| "Epoch[4] Iteration[550/1980] Loss: 0.2824\n", | |
| "Epoch[4] Iteration[600/1980] Loss: 0.7042\n", | |
| "Epoch[4] Iteration[650/1980] Loss: 0.4519\n", | |
| "Epoch[4] Iteration[700/1980] Loss: 0.4058\n", | |
| "Epoch[4] Iteration[750/1980] Loss: 0.9855\n", | |
| "Epoch[4] Iteration[800/1980] Loss: 0.6634\n", | |
| "Epoch[4] Iteration[850/1980] Loss: 0.5028\n", | |
| "Epoch[4] Iteration[900/1980] Loss: 0.4286\n", | |
| "Epoch[4] Iteration[950/1980] Loss: 0.3905\n", | |
| "Epoch[4] Iteration[1000/1980] Loss: 0.4156\n", | |
| "Epoch[4] Iteration[1050/1980] Loss: 0.5205\n", | |
| "Epoch[4] Iteration[1100/1980] Loss: 0.2619\n", | |
| "Epoch[4] Iteration[1150/1980] Loss: 0.5113\n", | |
| "Epoch[4] Iteration[1200/1980] Loss: 0.4677\n", | |
| "Epoch[4] Iteration[1250/1980] Loss: 0.7906\n", | |
| "Epoch[4] Iteration[1300/1980] Loss: 0.5569\n", | |
| "Epoch[4] Iteration[1350/1980] Loss: 0.1790\n", | |
| "Epoch[4] Iteration[1400/1980] Loss: 0.3694\n", | |
| "Epoch[4] Iteration[1450/1980] Loss: 0.4852\n", | |
| "Epoch[4] Iteration[1500/1980] Loss: 0.5866\n", | |
| "Epoch[4] Iteration[1550/1980] Loss: 0.2350\n", | |
| "Epoch[4] Iteration[1600/1980] Loss: 0.7430\n", | |
| "Epoch[4] Iteration[1650/1980] Loss: 0.9267\n", | |
| "Epoch[4] Iteration[1700/1980] Loss: 0.2773\n", | |
| "Epoch[4] Iteration[1750/1980] Loss: 0.4097\n", | |
| "Epoch[4] Iteration[1800/1980] Loss: 0.5146\n", | |
| "Epoch[4] Iteration[1850/1980] Loss: 0.5009\n", | |
| "Epoch[4] Iteration[1900/1980] Loss: 0.6576\n", | |
| "Epoch[4] Iteration[1950/1980] Loss: 0.5659\n", | |
| "Compute train metrics...\n", | |
| "Training Results - Epoch: 4 Average Loss: 0.4388 | Accuracy: 0.8529 | Precision: 0.8623 | Recall: 0.8573\n", | |
| "Compute validation metrics...\n", | |
| "Validation Results - Epoch: 4 Average Loss: 0.4554 | Accuracy: 0.8513 | Precision: 0.8619 | Recall: 0.8545\n", | |
| "Learning rate: 0.004096000000000001\n", | |
| "Epoch[5] Iteration[50/1980] Loss: 0.9690\n", | |
| "Epoch[5] Iteration[100/1980] Loss: 0.5726\n", | |
| "Epoch[5] Iteration[150/1980] Loss: 0.3923\n", | |
| "Epoch[5] Iteration[200/1980] Loss: 0.4984\n", | |
| "Epoch[5] Iteration[250/1980] Loss: 0.2878\n", | |
| "Epoch[5] Iteration[300/1980] Loss: 0.5360\n", | |
| "Epoch[5] Iteration[350/1980] Loss: 0.1766\n", | |
| "Epoch[5] Iteration[400/1980] Loss: 0.6906\n", | |
| "Epoch[5] Iteration[450/1980] Loss: 0.3585\n", | |
| "Epoch[5] Iteration[500/1980] Loss: 0.2153\n", | |
| "Epoch[5] Iteration[550/1980] Loss: 0.3440\n", | |
| "Epoch[5] Iteration[600/1980] Loss: 0.6730\n", | |
| "Epoch[5] Iteration[650/1980] Loss: 0.5256\n", | |
| "Epoch[5] Iteration[700/1980] Loss: 0.1477\n", | |
| "Epoch[5] Iteration[750/1980] Loss: 0.4275\n", | |
| "Epoch[5] Iteration[800/1980] Loss: 0.4505\n", | |
| "Epoch[5] Iteration[850/1980] Loss: 0.3685\n", | |
| "Epoch[5] Iteration[900/1980] Loss: 0.2548\n", | |
| "Epoch[5] Iteration[950/1980] Loss: 0.2414\n", | |
| "Epoch[5] Iteration[1000/1980] Loss: 0.3050\n", | |
| "Epoch[5] Iteration[1050/1980] Loss: 0.2394\n", | |
| "Epoch[5] Iteration[1100/1980] Loss: 0.4596\n", | |
| "Epoch[5] Iteration[1150/1980] Loss: 0.1614\n", | |
| "Epoch[5] Iteration[1200/1980] Loss: 0.2840\n", | |
| "Epoch[5] Iteration[1250/1980] Loss: 0.5737\n", | |
| "Epoch[5] Iteration[1300/1980] Loss: 0.3448\n", | |
| "Epoch[5] Iteration[1350/1980] Loss: 0.6732\n", | |
| "Epoch[5] Iteration[1400/1980] Loss: 0.1256\n", | |
| "Epoch[5] Iteration[1450/1980] Loss: 0.1550\n", | |
| "Epoch[5] Iteration[1500/1980] Loss: 0.5821\n", | |
| "Epoch[5] Iteration[1550/1980] Loss: 0.0865\n", | |
| "Epoch[5] Iteration[1600/1980] Loss: 0.5322\n", | |
| "Epoch[5] Iteration[1650/1980] Loss: 0.2451\n", | |
| "Epoch[5] Iteration[1700/1980] Loss: 0.6157\n", | |
| "Epoch[5] Iteration[1750/1980] Loss: 0.3351\n", | |
| "Epoch[5] Iteration[1800/1980] Loss: 0.2719\n", | |
| "Epoch[5] Iteration[1850/1980] Loss: 0.2660\n", | |
| "Epoch[5] Iteration[1900/1980] Loss: 0.2685\n", | |
| "Epoch[5] Iteration[1950/1980] Loss: 0.5849\n", | |
| "Compute train metrics...\n", | |
| "Training Results - Epoch: 5 Average Loss: 0.4280 | Accuracy: 0.8557 | Precision: 0.8689 | Recall: 0.8597\n", | |
| "Compute validation metrics...\n", | |
| "Validation Results - Epoch: 5 Average Loss: 0.4325 | Accuracy: 0.8609 | Precision: 0.8732 | Recall: 0.8645\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "output = trainer.run(train_loader, max_epochs=5)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "Let's check the 3 best saved models and the checkpoint:" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 24, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "model_best_model_3_val_accuracy=0.7688843.pth\n", | |
| "model_best_model_4_val_accuracy=0.8512715.pth\n", | |
| "model_best_model_5_val_accuracy=0.8609365.pth\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "!ls best_models/" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 25, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "checkpoint_lr_scheduler_9000.pth checkpoint_optimizer_9000.pth\n", | |
| "checkpoint_model_9000.pth\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "!ls checkpoint/" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "## Inference\n", | |
| "\n", | |
| "Let's first create a test dataloader from the validation dataset such that each batch is composed of `(samples, sample_indices)`:" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 26, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "class TestDataset(Dataset):\n", | |
| "\n", | |
| "    def __init__(self, ds):\n", | |
| "        self.ds = ds\n", | |
| "\n", | |
| "    def __len__(self):\n", | |
| "        return len(self.ds)\n", | |
| "\n", | |
| "    def __getitem__(self, index):\n", | |
| "        return self.ds[index][0], index\n", | |
| "\n", | |
| "\n", | |
| "test_dataset = TestDataset(val_dataset)\n", | |
| "\n", | |
| "test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers,\n", | |
| "                         drop_last=False, pin_memory=\"cuda\" in device)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
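| "Implementing an inference engine with ignite is simple. As with the evaluation engine we created earlier, we write an update function, this time returning the predictions so that they can be stored. We will also perform what is called test time augmentation (TTA): the test data is passed through the model several times, once per epoch of the inference engine, and the predictions are aggregated afterwards." | |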
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 27, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "import torch\n", | |
| "import torch.nn.functional as F\n", | |
| "# Note: recent ignite releases expose this helper as `ignite.utils.convert_tensor`\n", | |
| "from ignite._utils import convert_tensor\n", | |
| "\n", | |
| "\n", | |
| "def _prepare_batch(batch):\n", | |
| "    # Move the images to the target device; sample indices stay on the CPU\n", | |
| "    x, index = batch\n", | |
| "    x = convert_tensor(x, device=device)\n", | |
| "    return x, index\n", | |
| "\n", | |
| "\n", | |
| "def inference_update(engine, batch):\n", | |
| "    x, indices = _prepare_batch(batch)\n", | |
| "    with torch.no_grad():  # gradients are not needed at inference time\n", | |
| "        y_pred = model(x)\n", | |
| "        y_pred = F.softmax(y_pred, dim=1)\n", | |
| "    return {\n", | |
| "        \"y_pred\": convert_tensor(y_pred, device='cpu'),\n", | |
| "        \"indices\": indices\n", | |
| "    }\n", | |
| "\n", | |
| "\n", | |
| "model.eval()\n", | |
| "inferencer = Engine(inference_update)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "Next, let's define a handler to log progress during inference and a handler to store the predictions:" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 28, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "@inferencer.on(Events.EPOCH_COMPLETED)\n", | |
| "def log_tta(engine):\n", | |
| "    print(\"TTA {} / {}\".format(engine.state.epoch, n_tta))\n", | |
| "\n", | |
| "\n", | |
| "n_tta = 3\n", | |
| "num_classes = 64\n", | |
| "n_samples = len(val_dataset)\n", | |
| "\n", | |
| "# Array to store prediction probabilities\n", | |
| "y_probas_tta = np.zeros((n_samples, num_classes, n_tta), dtype=np.float32)\n", | |
| "\n", | |
| "# Array to store sample indices\n", | |
| "indices = np.zeros((n_samples, ), dtype=np.int64)\n", | |
| "\n", | |
| "\n", | |
| "@inferencer.on(Events.ITERATION_COMPLETED)\n", | |
| "def save_results(engine):\n", | |
| "    output = engine.state.output\n", | |
| "    # Each epoch of the inference engine is one TTA pass\n", | |
| "    tta_index = engine.state.epoch - 1\n", | |
| "    # engine.state.iteration keeps increasing across epochs, hence the modulo\n", | |
| "    start_index = ((engine.state.iteration - 1) % len(test_loader)) * batch_size\n", | |
| "    end_index = min(start_index + batch_size, n_samples)\n", | |
| "    batch_y_probas = output['y_pred'].detach().numpy()\n", | |
| "    y_probas_tta[start_index:end_index, :, tta_index] = batch_y_probas\n", | |
| "    # Sample indices are identical in every pass, so store them only once\n", | |
| "    if tta_index == 0:\n", | |
| "        indices[start_index:end_index] = output['indices']" | |
| ] | |
| }, | |
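| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "To make the index arithmetic in `save_results` easier to follow, here is a minimal standalone sketch. The sizes are hypothetical and chosen only for illustration, and `n_batches` stands in for `len(test_loader)`. It assumes, as the handler above does, that the iteration counter starts at 1 and keeps increasing across epochs:\n", | |
| "```python\n", | |
| "import math\n", | |
| "\n", | |
| "# Hypothetical sizes, for illustration only\n", | |
| "n_samples = 10\n", | |
| "batch_size = 4\n", | |
| "n_tta = 2\n", | |
| "n_batches = math.ceil(n_samples / batch_size)  # stands in for len(test_loader)\n", | |
| "\n", | |
| "slices = []\n", | |
| "for iteration in range(1, n_tta * n_batches + 1):\n", | |
| "    # The modulo maps the global iteration back to a batch position within the pass\n", | |
| "    start_index = ((iteration - 1) % n_batches) * batch_size\n", | |
| "    # The last batch may be smaller than batch_size\n", | |
| "    end_index = min(start_index + batch_size, n_samples)\n", | |
| "    slices.append((start_index, end_index))\n", | |
| "\n", | |
| "print(slices)\n", | |
| "```\n", | |
| "Every pass covers the same three slices, which is why the predictions of each TTA pass land in the same rows of `y_probas_tta`." | |
| ] | |
| }, | |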
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "Before running the inference, we may want to load the best model from disk:\n", | |
| "```python\n", | |
| "model = squeezenet1_1(pretrained=False, num_classes=64)\n", | |
| "model.classifier[-1] = nn.AdaptiveAvgPool2d(1) # Adapt the last average pooling to our data\n", | |
| "model = model.to(device)\n", | |
| "\n", | |
| "model_state_dict = torch.load(\"best_models/model_best_model_N_val_accuracy=0.XYZ.pth\")\n", | |
| "model.load_state_dict(model_state_dict)\n", | |
| "```" | |
| ] | |
| }, | |
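| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "The file name to load can also be picked automatically. The sketch below is only one possible approach: the helper `best_checkpoint` is hypothetical (it is not part of ignite) and it assumes that file names end with `=<score>.pth`, as in the listing of `best_models/` above:\n", | |
| "```python\n", | |
| "def best_checkpoint(filenames):\n", | |
| "    # Hypothetical helper: return the file name with the highest score\n", | |
| "    def score(name):\n", | |
| "        # The score is the text between the last '=' and the '.pth' extension\n", | |
| "        return float(name.rsplit('=', 1)[1][:-len('.pth')])\n", | |
| "    return max(filenames, key=score)\n", | |
| "\n", | |
| "\n", | |
| "# File names copied from the listing of best_models/ above\n", | |
| "saved = [\n", | |
| "    'model_best_model_3_val_accuracy=0.7688843.pth',\n", | |
| "    'model_best_model_4_val_accuracy=0.8512715.pth',\n", | |
| "    'model_best_model_5_val_accuracy=0.8609365.pth',\n", | |
| "]\n", | |
| "print(best_checkpoint(saved))\n", | |
| "```\n", | |
| "In practice the list could come from `os.listdir(\"best_models\")`, and the selected name would then be joined with the directory and passed to `torch.load`." | |
| ] | |
| }, | |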
| { | |
| "cell_type": "code", | |
| "execution_count": 29, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "TTA 1 / 3\n", | |
| "TTA 2 / 3\n", | |
| "TTA 3 / 3\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "<ignite.engine.engine.State at 0x7fb709947cf8>" | |
| ] | |
| }, | |
| "execution_count": 29, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "inferencer.run(test_loader, max_epochs=n_tta)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "The final probabilities can be aggregated over the TTA passes using the arithmetic mean or the geometric mean (gmean):" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 30, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "y_probas = np.mean(y_probas_tta, axis=-1)\n", | |
| "y_preds = np.argmax(y_probas, axis=-1)" | |
| ] | |
| }, | |
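| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "For comparison, here is a minimal sketch of both aggregations on a tiny hypothetical array with the same layout as `y_probas_tta`, i.e. `(n_samples, num_classes, n_tta)`. The geometric mean is computed in log space with plain NumPy; the small `eps` and the final renormalisation are choices made for this sketch, not something the notebook prescribes:\n", | |
| "```python\n", | |
| "import numpy as np\n", | |
| "\n", | |
| "# Hypothetical probabilities: 2 samples, 3 classes, 2 TTA passes\n", | |
| "y_probas_tta = np.array([\n", | |
| "    [[0.7, 0.5], [0.2, 0.3], [0.1, 0.2]],\n", | |
| "    [[0.1, 0.2], [0.6, 0.3], [0.3, 0.5]],\n", | |
| "], dtype=np.float32)\n", | |
| "\n", | |
| "# Arithmetic mean over the TTA axis\n", | |
| "y_mean = np.mean(y_probas_tta, axis=-1)\n", | |
| "\n", | |
| "# Geometric mean over the TTA axis, computed in log space\n", | |
| "eps = 1e-12  # guards against log(0)\n", | |
| "y_gmean = np.exp(np.mean(np.log(y_probas_tta + eps), axis=-1))\n", | |
| "# Renormalise so that the class probabilities of each sample sum to 1 again\n", | |
| "y_gmean = y_gmean / y_gmean.sum(axis=-1, keepdims=True)\n", | |
| "\n", | |
| "print(np.argmax(y_mean, axis=-1), np.argmax(y_gmean, axis=-1))\n", | |
| "```\n", | |
| "The geometric mean penalises a class more strongly when a single pass assigns it a low probability, so the two aggregations can disagree on borderline samples." | |
| ] | |
| }, | |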
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "The next step could be to create a submission using `indices` and `y_probas`. Here we will just compute the accuracy on our test (i.e. validation) dataset:" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 31, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "from sklearn.metrics import accuracy_score\n", | |
| "\n", | |
| "y_test_true = [y for _, y in val_dataset]" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 32, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "0.9044759313127522" | |
| ] | |
| }, | |
| "execution_count": 32, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "accuracy_score(y_test_true, y_preds)" | |
| ] | |
| }, | |
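| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "The comparison above relies on the predictions being in dataset order, which holds here because `test_loader` was created with `shuffle=False`. If the order were not guaranteed, the stored `indices` could be used to realign the predictions first. A minimal sketch with hypothetical values:\n", | |
| "```python\n", | |
| "import numpy as np\n", | |
| "\n", | |
| "# Hypothetical inference results for 5 samples, visited out of order\n", | |
| "indices = np.array([3, 0, 4, 1, 2])\n", | |
| "y_preds = np.array([1, 0, 2, 0, 1])  # y_preds[i] belongs to sample indices[i]\n", | |
| "y_true = np.array([0, 1, 1, 1, 2])   # labels in dataset order\n", | |
| "\n", | |
| "# Put the predictions back into dataset order before comparing\n", | |
| "y_preds_sorted = np.empty_like(y_preds)\n", | |
| "y_preds_sorted[indices] = y_preds\n", | |
| "\n", | |
| "# Accuracy is the fraction of matching labels\n", | |
| "accuracy = float(np.mean(y_preds_sorted == y_true))\n", | |
| "print(y_preds_sorted, accuracy)\n", | |
| "```" | |
| ] | |
| }, | |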
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "_uuid": "9ee5b8380dd53ce5afef93cf6df6c80ead3ced1e" | |
| }, | |
| "source": [ | |
| "### Final words\n", | |
| "\n", | |
| "That's all for this kernel. If you liked it - please upvote. \n", | |
| "\n", | |
| "If you liked *ignite*, please visit its [documentation site](https://pytorch.org/ignite/) and [GitHub repository](https://github.com/pytorch/ignite), and check out the [examples](https://github.com/pytorch/ignite/tree/master/examples) showing `tensorboard` and `visdom` integration and how to train a DCGAN. Some other examples can be found [here](https://github.com/vfdev-5/ignite-examples).\n" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [] | |
| } | |
| ], | |
| "metadata": { | |
| "kernelspec": { | |
| "display_name": "Python 3", | |
| "language": "python", | |
| "name": "python3" | |
| }, | |
| "language_info": { | |
| "codemirror_mode": { | |
| "name": "ipython", | |
| "version": 3 | |
| }, | |
| "file_extension": ".py", | |
| "mimetype": "text/x-python", | |
| "name": "python", | |
| "nbconvert_exporter": "python", | |
| "pygments_lexer": "ipython3", | |
| "version": "3.5.2" | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 2 | |
| } |