@BeMg
Created September 11, 2019 09:18
aaaa
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Validation of Gluon Models on ImageNet\n",
"\n",
"We evaluate the pretrained Gluon models in the [model zoo](https://github.com/apache/incubator-mxnet/tree/master/python/mxnet/gluon/model_zoo/vision) on the ImageNet validation dataset. The validation dataset consists of 50,000 images, 50 per class.\n",
"\n",
"To begin with, import the necessary modules."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"attributes": {
"classes": [],
"id": "",
"n": "1"
}
},
"outputs": [],
"source": [
"import mxnet as mx\n",
"import numpy as np\n",
"from mxnet import gluon, nd\n",
"from mxnet import autograd as ag\n",
"from mxnet.gluon import nn\n",
"from mxnet.gluon.model_zoo import vision as models\n",
"from mxnet.gluon.data import vision\n",
"from mxnet.gluon.data.vision import transforms\n",
"\n",
"import shutil\n",
"import time\n",
"import os\n",
"import pandas as pd"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Data Download\n",
"\n",
"First we need to prepare the [validation images](http://www.image-net.org/download-images) and preprocess them. Because Gluon's `ImageFolderDataset` reads JPEG files directly, we do not need to convert them to `rec` format.\n",
"\n",
"`ImageFolderDataset` expects the images to be arranged in subdirectories named after their labels. However, the validation images are stored in a single folder, so we create a subdirectory for every label and place a symbolic link to each image in the subdirectory for its label."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"attributes": {
"classes": [],
"id": "",
"n": "2"
}
},
"outputs": [],
"source": [
"val_path = 'val'\n",
"image_path = '/unsullied/sharefs/luojing/exp/Dataset/ImageNet2012Val/ILSVRC2012_img_val' # use an absolute path so the symlinks resolve\n",
"\n",
"# Start from a clean output directory\n",
"if os.path.exists(val_path):\n",
"    shutil.rmtree(val_path)\n",
"\n",
"# synsets.txt has one synset ID per line; the line index is the class index\n",
"with open('../basics/synsets.txt', 'r') as synsets_file:\n",
"    synsets = [line.rstrip('\\n') for line in synsets_file]\n",
"\n",
"# Each line of val.txt holds an image file name and its class index\n",
"with open('../basics/val.txt', 'r') as val_file:\n",
"    for line in val_file:\n",
"        fname, idx = line.split()\n",
"        label_path = '%s/%s' % (val_path, synsets[int(idx)])\n",
"        if not os.path.exists(label_path):\n",
"            os.makedirs(label_path)\n",
"        os.symlink('%s/%s' % (image_path, fname), '%s/%s' % (label_path, fname))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The common ImageNet validation benchmark reports top-1 and top-5 error on a `224x224` center crop taken after resizing the image to `256`. The transform function is defined accordingly."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"attributes": {
"classes": [],
"id": "",
"n": "3"
}
},
"outputs": [],
"source": [
"transform_test = transforms.Compose([\n",
"    transforms.Resize(256),\n",
"    transforms.CenterCrop(224),\n",
"    transforms.ToTensor(),\n",
"    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n",
"])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We use `ImageFolderDataset` to read the validation dataset."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"attributes": {
"classes": [],
"id": "",
"n": "4"
}
},
"outputs": [],
"source": [
"batch_size = 25\n",
"num_gpus = 8\n",
"batch_size *= max(1, num_gpus)\n",
"val_total = 50000\n",
"ctx = [mx.gpu(i) for i in range(num_gpus)] if num_gpus > 0 else [mx.cpu()]\n",
"num_workers = 10\n",
"\n",
"val_dataset = vision.ImageFolderDataset(val_path)\n",
"val_data = gluon.data.DataLoader(\n",
"    val_dataset.transform_first(transform_test),\n",
"    batch_size=batch_size, shuffle=False, num_workers=num_workers)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Load Models and Test"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"attributes": {
"classes": [],
"id": "",
"n": "5"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: alexnet \t Top 1 Err: 0.470920 \t Top 5 Err: 0.233480 \n",
"Model: densenet121 \t Top 1 Err: 0.254400 \t Top 5 Err: 0.078040 \n",
"Model: densenet161 \t Top 1 Err: 0.223880 \t Top 5 Err: 0.060680 \n",
"Model: densenet169 \t Top 1 Err: 0.238380 \t Top 5 Err: 0.068240 \n",
"Model: densenet201 \t Top 1 Err: 0.231000 \t Top 5 Err: 0.065260 \n",
"Model: mobilenet0.25 \t Top 1 Err: 0.499980 \t Top 5 Err: 0.255280 \n",
"Model: mobilenet0.5 \t Top 1 Err: 0.380900 \t Top 5 Err: 0.159200 \n",
"Model: mobilenet0.75 \t Top 1 Err: 0.334320 \t Top 5 Err: 0.127500 \n",
"Model: mobilenet1.0 \t Top 1 Err: 0.297760 \t Top 5 Err: 0.103600 \n",
"Model: resnet101_v1 \t Top 1 Err: 0.233680 \t Top 5 Err: 0.066580 \n",
"Model: resnet101_v2 \t Top 1 Err: 0.227520 \t Top 5 Err: 0.063480 \n",
"Model: resnet152_v1 \t Top 1 Err: 0.230380 \t Top 5 Err: 0.064820 \n",
"Model: resnet152_v2 \t Top 1 Err: 0.219760 \t Top 5 Err: 0.060840 \n",
"Model: resnet18_v1 \t Top 1 Err: 0.336340 \t Top 5 Err: 0.126860 \n",
"Model: resnet18_v2 \t Top 1 Err: 0.310500 \t Top 5 Err: 0.113540 \n",
"Model: resnet34_v1 \t Top 1 Err: 0.294600 \t Top 5 Err: 0.098060 \n",
"Model: resnet34_v2 \t Top 1 Err: 0.274400 \t Top 5 Err: 0.089000 \n",
"Model: resnet50_v1 \t Top 1 Err: 0.249180 \t Top 5 Err: 0.074800 \n",
"Model: resnet50_v2 \t Top 1 Err: 0.240360 \t Top 5 Err: 0.071920 \n",
"Model: squeezenet1.0 \t Top 1 Err: 0.459460 \t Top 5 Err: 0.226700 \n",
"Model: squeezenet1.1 \t Top 1 Err: 0.477840 \t Top 5 Err: 0.237060 \n",
"Model: vgg11 \t Top 1 Err: 0.348600 \t Top 5 Err: 0.135040 \n",
"Model: vgg11_bn \t Top 1 Err: 0.327040 \t Top 5 Err: 0.121760 \n",
"Model: vgg13 \t Top 1 Err: 0.337020 \t Top 5 Err: 0.127020 \n",
"Model: vgg13_bn \t Top 1 Err: 0.328260 \t Top 5 Err: 0.120140 \n",
"Model: vgg16 \t Top 1 Err: 0.315920 \t Top 5 Err: 0.113100 \n",
"Model: vgg16_bn \t Top 1 Err: 0.297100 \t Top 5 Err: 0.099320 \n",
"Model: vgg19 \t Top 1 Err: 0.305620 \t Top 5 Err: 0.107440 \n",
"Model: vgg19_bn \t Top 1 Err: 0.289160 \t Top 5 Err: 0.096620 \n"
]
}
],
"source": [
"from mxnet.gluon.model_zoo.model_store import _model_sha1\n",
"\n",
"test_result = []\n",
"acc_top1 = mx.metric.Accuracy()\n",
"acc_top5 = mx.metric.TopKAccuracy(5)\n",
"\n",
"for model in sorted(_model_sha1.keys()):\n",
"    # inceptionv3 expects 299x299 inputs, so skip it with this 224x224 pipeline\n",
"    if model == 'inceptionv3':\n",
"        continue\n",
"    net = models.get_model(model, pretrained=True, ctx=ctx)\n",
"    acc_top1.reset()\n",
"    acc_top5.reset()\n",
"    for batch in val_data:\n",
"        # Split each batch across the devices in ctx\n",
"        data = gluon.utils.split_and_load(batch[0], ctx)\n",
"        label = gluon.utils.split_and_load(batch[1], ctx)\n",
"        outputs = [net(X) for X in data]\n",
"        acc_top1.update(label, outputs)\n",
"        acc_top5.update(label, outputs)\n",
"    _, top1 = acc_top1.get()\n",
"    _, top5 = acc_top5.get()\n",
"    print('Model: %s \\t Top 1 Err: %4f \\t Top 5 Err: %4f ' % (model, 1 - top1, 1 - top5))\n",
"    test_result.append((model, 1 - top1, 1 - top5))"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"attributes": {
"classes": [],
"id": "",
"n": "6"
}
},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>model</th>\n",
" <th>top_1_err</th>\n",
" <th>top_5_err</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>12</th>\n",
" <td>resnet152_v2</td>\n",
" <td>0.21976</td>\n",
" <td>0.06084</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>densenet161</td>\n",
" <td>0.22388</td>\n",
" <td>0.06068</td>\n",
" </tr>\n",
" <tr>\n",
" <th>10</th>\n",
" <td>resnet101_v2</td>\n",
" <td>0.22752</td>\n",
" <td>0.06348</td>\n",
" </tr>\n",
" <tr>\n",
" <th>11</th>\n",
" <td>resnet152_v1</td>\n",
" <td>0.23038</td>\n",
" <td>0.06482</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>densenet201</td>\n",
" <td>0.23100</td>\n",
" <td>0.06526</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" model top_1_err top_5_err\n",
"12 resnet152_v2 0.21976 0.06084\n",
"2 densenet161 0.22388 0.06068\n",
"10 resnet101_v2 0.22752 0.06348\n",
"11 resnet152_v1 0.23038 0.06482\n",
"4 densenet201 0.23100 0.06526"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"summary = pd.DataFrame(test_result, columns=['model', 'top_1_err', 'top_5_err'])\n",
"summary = summary.sort_values('top_1_err')\n",
"summary.head()"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"attributes": {
"classes": [],
"id": "",
"n": "7"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"| resnet152_v2 | 21.98 | 6.08 |\n",
"| densenet161 | 22.39 | 6.07 |\n",
"| resnet101_v2 | 22.75 | 6.35 |\n",
"| resnet152_v1 | 23.04 | 6.48 |\n",
"| densenet201 | 23.10 | 6.53 |\n",
"| resnet101_v1 | 23.37 | 6.66 |\n",
"| densenet169 | 23.84 | 6.82 |\n",
"| resnet50_v2 | 24.04 | 7.19 |\n",
"| resnet50_v1 | 24.92 | 7.48 |\n",
"| densenet121 | 25.44 | 7.80 |\n",
"| resnet34_v2 | 27.44 | 8.90 |\n",
"| vgg19_bn | 28.92 | 9.66 |\n",
"| resnet34_v1 | 29.46 | 9.81 |\n",
"| vgg16_bn | 29.71 | 9.93 |\n",
"| mobilenet1.0 | 29.78 | 10.36 |\n",
"| vgg19 | 30.56 | 10.74 |\n",
"| resnet18_v2 | 31.05 | 11.35 |\n",
"| vgg16 | 31.59 | 11.31 |\n",
"| vgg11_bn | 32.70 | 12.18 |\n",
"| vgg13_bn | 32.83 | 12.01 |\n",
"| mobilenet0.75 | 33.43 | 12.75 |\n",
"| resnet18_v1 | 33.63 | 12.69 |\n",
"| vgg13 | 33.70 | 12.70 |\n",
"| vgg11 | 34.86 | 13.50 |\n",
"| mobilenet0.5 | 38.09 | 15.92 |\n",
"| squeezenet1.0 | 45.95 | 22.67 |\n",
"| alexnet | 47.09 | 23.35 |\n",
"| squeezenet1.1 | 47.78 | 23.71 |\n",
"| mobilenet0.25 | 50.00 | 25.53 |\n"
]
}
],
"source": [
"for i, (model, top1, top5) in summary.iterrows():\n",
"    print('| %s | %.2f | %.2f |' % (model, top1 * 100, top5 * 100))"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"attributes": {
"classes": [],
"id": "",
"n": "8"
}
},
"outputs": [],
"source": [
"summary.to_csv('model_zoo_test_result.csv', index=False)"
]
}
],
"metadata": {},
"nbformat": 4,
"nbformat_minor": 2
}
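
The notebook reports top-1 and top-5 error as one minus the value returned by `mx.metric.Accuracy` and `mx.metric.TopKAccuracy(5)`. To make clear what those numbers measure, here is a minimal pure-Python sketch of top-k error. The function name `topk_error` and the toy scores are hypothetical and are not part of MXNet; this is an illustration of the idea, not the library's implementation.

```python
def topk_error(scores, labels, k=1):
    """Fraction of samples whose true label is not among the k highest scores."""
    if len(scores) != len(labels) or not labels:
        raise ValueError("scores and labels must be non-empty and the same length")
    misses = 0
    for row, label in zip(scores, labels):
        # Indices of the k largest scores for this sample
        topk = sorted(range(len(row)), key=lambda i: row[i], reverse=True)[:k]
        if label not in topk:
            misses += 1
    return misses / len(labels)


# Toy example: 4 samples, 3 classes (made-up scores, not real model output)
scores = [
    [0.7, 0.2, 0.1],
    [0.1, 0.3, 0.6],
    [0.5, 0.4, 0.1],
    [0.2, 0.3, 0.5],
]
labels = [0, 1, 1, 0]

top1_err = topk_error(scores, labels, k=1)
top2_err = topk_error(scores, labels, k=2)
print('Top 1 Err: %.2f \t Top 2 Err: %.2f' % (top1_err, top2_err))
```

A larger `k` can only lower the error, which is why every model in the results has a top-5 error below its top-1 error.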
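
The folder-per-label layout built in the Data Download step can be tried on a throwaway directory before touching the real dataset. The sketch below assumes made-up synset IDs and file names in place of `synsets.txt` and `val.txt`, and the helper `build_label_dirs` is a hypothetical name introduced only for this example. It needs a platform where creating symbolic links is permitted.

```python
import os
import tempfile


def build_label_dirs(image_dir, out_dir, synsets, val_lines):
    """Create out_dir/<synset>/<fname> symlinks pointing at image_dir/<fname>."""
    for line in val_lines:
        fname, idx = line.split()
        label_dir = os.path.join(out_dir, synsets[int(idx)])
        os.makedirs(label_dir, exist_ok=True)
        # Absolute target so the link resolves regardless of the working directory
        os.symlink(os.path.abspath(os.path.join(image_dir, fname)),
                   os.path.join(label_dir, fname))


# Made-up stand-ins for the contents of synsets.txt and val.txt
synsets = ['n00000001', 'n00000002']
val_lines = ['img_a.JPEG 1', 'img_b.JPEG 0', 'img_c.JPEG 1']

root = tempfile.mkdtemp()
image_dir = os.path.join(root, 'images')
out_dir = os.path.join(root, 'val')
os.makedirs(image_dir)
for line in val_lines:
    # Empty placeholder files stand in for the real JPEGs
    open(os.path.join(image_dir, line.split()[0]), 'w').close()

build_label_dirs(image_dir, out_dir, synsets, val_lines)

# Map each label directory to the sorted list of links it contains
layout = {d: sorted(os.listdir(os.path.join(out_dir, d)))
          for d in sorted(os.listdir(out_dir))}
print(layout)
```

Each key of `layout` is a label directory, which is the structure `ImageFolderDataset` reads to assign class indices.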