Skip to content

Instantly share code, notes, and snippets.

@BeMg
Created October 30, 2019 13:07
Show Gist options
  • Save BeMg/83845c453e46032bf69ad640ba90df1e to your computer and use it in GitHub Desktop.
Save BeMg/83845c453e46032bf69ad640ba90df1e to your computer and use it in GitHub Desktop.
Rearrange the ImageNet validation set into per-class folders and evaluate a pretrained Gluon model on it.
import mxnet as mx
import numpy as np
from mxnet import gluon, nd
from mxnet import autograd as ag
from mxnet.gluon import nn
from mxnet.gluon.model_zoo import vision as models
from mxnet.gluon.data import vision
from mxnet.gluon.data.vision import transforms
import shutil
import time
import os
import pandas as pd
# Build an ImageFolder-style copy of the validation set made of symlinks:
#   <val_path>/<synsets[label]>/<fname> -> <image_path>/<fname>
val_path = 'val'
# Symlink targets are stored verbatim, so this must be an absolute path; a
# relative target would be resolved against the link's own directory.
image_path = '/home/bemg/dataset/imagenet_img'  # use absolute path

# Read both index files before touching the output directory, so a missing
# or unreadable input fails before `val_path` is deleted.
# synsets.txt: line i (0-based) is the folder name used for label i.
# rstrip('\r\n') also tolerates CRLF line endings.
with open('./synsets.txt', 'r') as synsets_file:
    synsets = [line.rstrip('\r\n') for line in synsets_file]

# val.txt: one "<file name> <integer label>" pair per line; blank lines
# (e.g. a trailing empty line) are skipped instead of crashing the unpack.
with open('./val.txt', 'r') as val_file:
    val_entries = [line.split() for line in val_file if line.strip()]

# Rebuild the layout from scratch on every run.
if os.path.exists(val_path):
    shutil.rmtree(val_path)

for fname, idx in val_entries:
    label_path = os.path.join(val_path, synsets[int(idx)])
    os.makedirs(label_path, exist_ok=True)
    os.symlink(os.path.join(image_path, fname), os.path.join(label_path, fname))
# Evaluation preprocessing: Resize(256), then a central 224x224 crop,
# conversion to a tensor, and per-channel normalization with the mean/std
# lists below.
norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]
eval_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
]
transform_test = transforms.Compose(eval_steps)

# Run settings.
# NOTE(review): num_gpus and val_total are not referenced anywhere else in
# this script as shown; evaluation runs only on the CPU context below.
batch_size = 1
num_gpus = 8
val_total = 50000
num_workers = 10
ctx = [mx.cpu()]

# ImageFolderDataset labels each image by its sub-folder under val_path;
# transform_first applies transform_test to the image only, not the label.
val_dataset = vision.ImageFolderDataset(val_path)
val_data = gluon.data.DataLoader(
    val_dataset.transform_first(transform_test),
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
)
from mxnet.gluon.model_zoo.model_store import _model_sha1

# Model-zoo names to evaluate; each result is recorded under the same name
# that was used to build the network. To sweep the whole pretrained zoo:
#   eval_models = [m for m in sorted(_model_sha1) if m != 'inceptionv3']
# NOTE(review): inceptionv3 is presumably excluded because it expects
# 299x299 input rather than the 224x224 crop made by transform_test — confirm.
eval_models = ['mobilenetv2_1.0']

test_result = []  # list of (model name, top-1 error, top-5 error)
acc_top1 = mx.metric.Accuracy()
acc_top5 = mx.metric.TopKAccuracy(5)
cnt = 0  # batches processed so far, across all models
log_interval = 1000  # print progress every this many batches

for model in eval_models:
    net = models.get_model(model, pretrained=True, ctx=ctx)
    acc_top1.reset()
    acc_top5.reset()
    for batch in val_data:
        if cnt % log_interval == 0:
            print(cnt)
        cnt += 1
        # batch[0] holds the images and batch[1] the labels; split each
        # across the contexts in ctx and run the net on every slice.
        data = gluon.utils.split_and_load(batch[0], ctx)
        label = gluon.utils.split_and_load(batch[1], ctx)
        outputs = [net(X) for X in data]
        acc_top1.update(label, outputs)
        acc_top5.update(label, outputs)
    _, top1 = acc_top1.get()
    _, top5 = acc_top5.get()
    print('Model: %s \t Top 1 Err: %4f \t Top 5 Err: %4f ' % (model, 1 - top1, 1 - top5))
    test_result.append((model, 1 - top1, 1 - top5))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment