Created
September 11, 2019 09:18
aaaa
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# Validation of Gluon Models on ImageNet\n", | |
"\n", | |
"We are about to test the performance of gluon pretrained models in the [model zoo](https://github.com/apache/incubator-mxnet/tree/master/python/mxnet/gluon/model_zoo/vision) on ImageNet validation dataset. The validation dataset consists of 50,000 images, 50 per class.\n", | |
"\n", | |
"To begin with, import the necessary modules." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": { | |
"attributes": { | |
"classes": [], | |
"id": "", | |
"n": "1" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"import mxnet as mx\n", | |
"import numpy as np\n", | |
"from mxnet import gluon, nd\n", | |
"from mxnet import autograd as ag\n", | |
"from mxnet.gluon import nn\n", | |
"from mxnet.gluon.model_zoo import vision as models\n", | |
"from mxnet.gluon.data import vision\n", | |
"from mxnet.gluon.data.vision import transforms\n", | |
"\n", | |
"import shutil\n", | |
"import time\n", | |
"import os\n", | |
"import pandas as pd" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Data Download\n", | |
"\n", | |
"First we need to prepare the [validation images](http://www.image-net.org/download-images) and preprocess them. Due to the black magic of Gluon, we do not need to transform the JPEG files to `rec` format.\n", | |
"\n", | |
"The DataLoader in Gluon requires the images arranged in the subdir named after the image label. However, the validation images are stored in one folder, so we need to create a subdir for every label, and put a symbol link to the image in the every subdir." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": { | |
"attributes": { | |
"classes": [], | |
"id": "", | |
"n": "2" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"val_path = 'val'\n", | |
"image_path = '/unsullied/sharefs/luojing/exp/Dataset/ImageNet2012Val/ILSVRC2012_img_val' #use absolute address\n", | |
"synsets_file = open('../basics/synsets.txt', 'r')\n", | |
"val_file = open('../basics/val.txt', 'r')\n", | |
"\n", | |
"if os.path.exists(val_path):\n", | |
" shutil.rmtree(val_path)\n", | |
"\n", | |
"synsets = [line.rstrip('\\n') for line in synsets_file.readlines()]\n", | |
"\n", | |
"for line in val_file.readlines():\n", | |
" fname, idx = line.split()\n", | |
" label_path = '%s/%s' % (val_path, synsets[int(idx)])\n", | |
" if not os.path.exists(label_path):\n", | |
" os.makedirs(label_path)\n", | |
" os.symlink('%s/%s' % (image_path, fname), '%s/%s' % (label_path, fname))" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"The common benchmark of ImageNet validation is Top 1/5 error of `224x` on `256x` Crop. That is how we define the tranform function." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": { | |
"attributes": { | |
"classes": [], | |
"id": "", | |
"n": "3" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"transform_test = transforms.Compose([\n", | |
" transforms.Resize(256),\n", | |
" transforms.CenterCrop(224),\n", | |
" transforms.ToTensor(),\n", | |
" transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n", | |
"])" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"We use `ImageFolderDataset` to read the validation dataset." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": { | |
"attributes": { | |
"classes": [], | |
"id": "", | |
"n": "4" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"batch_size = 25\n", | |
"num_gpus = 8\n", | |
"batch_size *= max(1, num_gpus)\n", | |
"val_total = 50000\n", | |
"ctx = [mx.gpu(i) for i in range(num_gpus)] if num_gpus > 0 else [mx.cpu()]\n", | |
"num_workers = 10\n", | |
"\n", | |
"val_dataset = vision.ImageFolderDataset(val_path)\n", | |
"val_data = gluon.data.DataLoader(\n", | |
" val_dataset.transform_first(transform_test),\n", | |
" batch_size=batch_size, shuffle=False, num_workers=num_workers)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Load Models and Test" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": { | |
"attributes": { | |
"classes": [], | |
"id": "", | |
"n": "5" | |
} | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Model: alexnet \t Top 1 Err: 0.470920 \t Top 5 Err: 0.233480 \n", | |
"Model: densenet121 \t Top 1 Err: 0.254400 \t Top 5 Err: 0.078040 \n", | |
"Model: densenet161 \t Top 1 Err: 0.223880 \t Top 5 Err: 0.060680 \n", | |
"Model: densenet169 \t Top 1 Err: 0.238380 \t Top 5 Err: 0.068240 \n", | |
"Model: densenet201 \t Top 1 Err: 0.231000 \t Top 5 Err: 0.065260 \n", | |
"Model: mobilenet0.25 \t Top 1 Err: 0.499980 \t Top 5 Err: 0.255280 \n", | |
"Model: mobilenet0.5 \t Top 1 Err: 0.380900 \t Top 5 Err: 0.159200 \n", | |
"Model: mobilenet0.75 \t Top 1 Err: 0.334320 \t Top 5 Err: 0.127500 \n", | |
"Model: mobilenet1.0 \t Top 1 Err: 0.297760 \t Top 5 Err: 0.103600 \n", | |
"Model: resnet101_v1 \t Top 1 Err: 0.233680 \t Top 5 Err: 0.066580 \n", | |
"Model: resnet101_v2 \t Top 1 Err: 0.227520 \t Top 5 Err: 0.063480 \n", | |
"Model: resnet152_v1 \t Top 1 Err: 0.230380 \t Top 5 Err: 0.064820 \n", | |
"Model: resnet152_v2 \t Top 1 Err: 0.219760 \t Top 5 Err: 0.060840 \n", | |
"Model: resnet18_v1 \t Top 1 Err: 0.336340 \t Top 5 Err: 0.126860 \n", | |
"Model: resnet18_v2 \t Top 1 Err: 0.310500 \t Top 5 Err: 0.113540 \n", | |
"Model: resnet34_v1 \t Top 1 Err: 0.294600 \t Top 5 Err: 0.098060 \n", | |
"Model: resnet34_v2 \t Top 1 Err: 0.274400 \t Top 5 Err: 0.089000 \n", | |
"Model: resnet50_v1 \t Top 1 Err: 0.249180 \t Top 5 Err: 0.074800 \n", | |
"Model: resnet50_v2 \t Top 1 Err: 0.240360 \t Top 5 Err: 0.071920 \n", | |
"Model: squeezenet1.0 \t Top 1 Err: 0.459460 \t Top 5 Err: 0.226700 \n", | |
"Model: squeezenet1.1 \t Top 1 Err: 0.477840 \t Top 5 Err: 0.237060 \n", | |
"Model: vgg11 \t Top 1 Err: 0.348600 \t Top 5 Err: 0.135040 \n", | |
"Model: vgg11_bn \t Top 1 Err: 0.327040 \t Top 5 Err: 0.121760 \n", | |
"Model: vgg13 \t Top 1 Err: 0.337020 \t Top 5 Err: 0.127020 \n", | |
"Model: vgg13_bn \t Top 1 Err: 0.328260 \t Top 5 Err: 0.120140 \n", | |
"Model: vgg16 \t Top 1 Err: 0.315920 \t Top 5 Err: 0.113100 \n", | |
"Model: vgg16_bn \t Top 1 Err: 0.297100 \t Top 5 Err: 0.099320 \n", | |
"Model: vgg19 \t Top 1 Err: 0.305620 \t Top 5 Err: 0.107440 \n", | |
"Model: vgg19_bn \t Top 1 Err: 0.289160 \t Top 5 Err: 0.096620 \n" | |
] | |
} | |
], | |
"source": [ | |
"from mxnet.gluon.model_zoo.model_store import _model_sha1\n", | |
"\n", | |
"test_result = []\n", | |
"acc_top1 = mx.metric.Accuracy()\n", | |
"acc_top5 = mx.metric.TopKAccuracy(5)\n", | |
"\n", | |
"for model in sorted(_model_sha1.keys()):\n", | |
" if model == 'inceptionv3':\n", | |
" continue\n", | |
" net = models.get_model(model, pretrained=True, ctx=ctx)\n", | |
" acc_top1.reset()\n", | |
" acc_top5.reset() \n", | |
" for _, batch in enumerate(val_data): \n", | |
" data = gluon.utils.split_and_load(batch[0], ctx)\n", | |
" label = gluon.utils.split_and_load(batch[1], ctx) \n", | |
" outputs = [net(X) for X in data]\n", | |
" acc_top1.update(label, outputs)\n", | |
" acc_top5.update(label, outputs)\n", | |
" # print_str = 'Top 1 Err: %4f \\t Top 5 Err: %4f '%(1 - top1, 1 - top5)\n", | |
" # pbar.set_description(\"%s\" % print_str)\n", | |
" _, top1 = acc_top1.get()\n", | |
" _, top5 = acc_top5.get()\n", | |
" print('Model: %s \\t Top 1 Err: %4f \\t Top 5 Err: %4f '%(model, 1 - top1, 1 - top5))\n", | |
" test_result.append((model, 1 - top1, 1 - top5)) \n", | |
" " | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": { | |
"attributes": { | |
"classes": [], | |
"id": "", | |
"n": "6" | |
} | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>model</th>\n", | |
" <th>top_1_err</th>\n", | |
" <th>top_5_err</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>12</th>\n", | |
" <td>resnet152_v2</td>\n", | |
" <td>0.21976</td>\n", | |
" <td>0.06084</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>densenet161</td>\n", | |
" <td>0.22388</td>\n", | |
" <td>0.06068</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>10</th>\n", | |
" <td>resnet101_v2</td>\n", | |
" <td>0.22752</td>\n", | |
" <td>0.06348</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>11</th>\n", | |
" <td>resnet152_v1</td>\n", | |
" <td>0.23038</td>\n", | |
" <td>0.06482</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>densenet201</td>\n", | |
" <td>0.23100</td>\n", | |
" <td>0.06526</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" model top_1_err top_5_err\n", | |
"12 resnet152_v2 0.21976 0.06084\n", | |
"2 densenet161 0.22388 0.06068\n", | |
"10 resnet101_v2 0.22752 0.06348\n", | |
"11 resnet152_v1 0.23038 0.06482\n", | |
"4 densenet201 0.23100 0.06526" | |
] | |
}, | |
"execution_count": 43, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"summary = pd.DataFrame(test_result, columns=['model', 'top_1_err', 'top_5_err'])\n", | |
"summary = summary.sort_values('top_1_err')\n", | |
"summary.head()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": { | |
"attributes": { | |
"classes": [], | |
"id": "", | |
"n": "7" | |
} | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"| resnet152_v2 | 21.98 | 6.08 |\n", | |
"| densenet161 | 22.39 | 6.07 |\n", | |
"| resnet101_v2 | 22.75 | 6.35 |\n", | |
"| resnet152_v1 | 23.04 | 6.48 |\n", | |
"| densenet201 | 23.10 | 6.53 |\n", | |
"| resnet101_v1 | 23.37 | 6.66 |\n", | |
"| densenet169 | 23.84 | 6.82 |\n", | |
"| resnet50_v2 | 24.04 | 7.19 |\n", | |
"| resnet50_v1 | 24.92 | 7.48 |\n", | |
"| densenet121 | 25.44 | 7.80 |\n", | |
"| resnet34_v2 | 27.44 | 8.90 |\n", | |
"| vgg19_bn | 28.92 | 9.66 |\n", | |
"| resnet34_v1 | 29.46 | 9.81 |\n", | |
"| vgg16_bn | 29.71 | 9.93 |\n", | |
"| mobilenet1.0 | 29.78 | 10.36 |\n", | |
"| vgg19 | 30.56 | 10.74 |\n", | |
"| resnet18_v2 | 31.05 | 11.35 |\n", | |
"| vgg16 | 31.59 | 11.31 |\n", | |
"| vgg11_bn | 32.70 | 12.18 |\n", | |
"| vgg13_bn | 32.83 | 12.01 |\n", | |
"| mobilenet0.75 | 33.43 | 12.75 |\n", | |
"| resnet18_v1 | 33.63 | 12.69 |\n", | |
"| vgg13 | 33.70 | 12.70 |\n", | |
"| vgg11 | 34.86 | 13.50 |\n", | |
"| mobilenet0.5 | 38.09 | 15.92 |\n", | |
"| squeezenet1.0 | 45.95 | 22.67 |\n", | |
"| alexnet | 47.09 | 23.35 |\n", | |
"| squeezenet1.1 | 47.78 | 23.71 |\n", | |
"| mobilenet0.25 | 50.00 | 25.53 |\n" | |
] | |
} | |
], | |
"source": [ | |
"for i, (model, top1, top5) in summary.iterrows():\n", | |
" print('| %s | %.2f | %.2f |'%(model, top1 * 100, top5 * 100))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": { | |
"attributes": { | |
"classes": [], | |
"id": "", | |
"n": "8" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"summary.to_csv('model_zoo_test_result.csv', index=None)" | |
] | |
} | |
], | |
"metadata": {}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |