Skip to content

Instantly share code, notes, and snippets.

@kice
Created November 20, 2018 05:25
Show Gist options
  • Save kice/972608997df10de3aa3c8b5c21606e8d to your computer and use it in GitHub Desktop.
Save kice/972608997df10de3aa3c8b5c21606e8d to your computer and use it in GitHub Desktop.
import warnings
warnings.filterwarnings("ignore")
import mxnet as mx
import numpy as np
from skimage.measure import compare_ssim, compare_psnr
from PIL import Image
from mxnet.contrib.quantization import *
import sys, os, argparse, time, ntpath, logging, logging.handlers
def save_params(fname, arg_params, aux_params):
    """Save argument and auxiliary parameter dicts to a single MXNet file.

    Keys are written as ``arg:<name>`` and ``aux:<name>``; every array is
    copied to the CPU context before saving.

    Parameters
    ----------
    fname : str
        Destination path passed to ``mx.nd.save``.
    arg_params : dict[str, mx.nd.NDArray]
        Argument (learned) parameters, keyed by parameter name.
    aux_params : dict[str, mx.nd.NDArray]
        Auxiliary states, keyed by parameter name.
    """
    # Use mx.cpu() explicitly: a bare ``cpu()`` only resolves if the
    # ``from mxnet.contrib.quantization import *`` star import re-exports it.
    cpu_ctx = mx.cpu()
    save_dict = {'arg:%s' % k: v.as_in_context(cpu_ctx) for k, v in arg_params.items()}
    save_dict.update({'aux:%s' % k: v.as_in_context(cpu_ctx) for k, v in aux_params.items()})
    mx.nd.save(fname, save_dict)
def forward(model, data, ctx):
    """Run one inference pass of *model* on a single image.

    Parameters
    ----------
    model : mx.mod.Module
        A module already bound for a batch of one NCHW image, with params set.
    data : numpy.ndarray
        The image as H x W x C, or H x W (treated as a single channel).
    ctx : mx.Context
        Not used by this function; kept for interface compatibility.

    Returns
    -------
    numpy.ndarray
        The module's first output converted to H x W x C, as float64.
    """
    image = np.asarray(data, dtype=np.float32)
    # Accept plain grayscale arrays by giving them an explicit channel axis.
    if image.ndim == 2:
        image = image[:, :, np.newaxis]
    # HWC -> CHW, then add the leading batch axis to get NCHW.
    batch = mx.nd.array(image.transpose(2, 0, 1)[np.newaxis], dtype='float32')
    model.forward(data_batch=mx.io.DataBatch(data=[batch]), is_train=False)
    pred = model.get_outputs()[0].asnumpy()
    # Drop the batch axis and go back from CHW to HWC.
    return pred[0].transpose(1, 2, 0).astype("float")
def _load_input(path, benchmark, scale):
    """Load *path* and return ``(lr, org)`` as float arrays scaled to [0, 1].

    In benchmark mode the image is cropped to a multiple of *scale* and
    bicubically downscaled by *scale* to form the network input ``lr``;
    ``org`` is the (cropped) original used as the PSNR/SSIM reference.
    Otherwise the image itself is the input. ``org`` always has a channel axis.
    """
    image = Image.open(path)
    # Palette / alpha / other modes would hand the network the wrong channel
    # layout; keep RGB and single-channel "L" as-is and convert the rest.
    if image.mode not in ("RGB", "L"):
        image = image.convert("RGB")
    w, h = image.size
    if benchmark:
        if w % scale != 0 or h % scale != 0:
            image = image.crop((0, 0, w // scale * scale, h // scale * scale))
            w, h = image.size
        lr = np.array(image.resize((w // scale, h // scale), Image.BICUBIC)) / 255.0
    else:
        lr = np.array(image).astype("float") / 255.0
    org = np.array(image).astype("float") / 255.0
    if org.ndim == 2:
        org = np.expand_dims(org, axis=2)
    return lr, org


def _split_checkpoint(saved):
    """Split a ``{'arg:<name>' | 'aux:<name>': NDArray}`` dict into (arg, aux) dicts."""
    arg_params, aux_params = {}, {}
    for key, value in saved.items():
        # Match on the exact prefix: a substring test would also route e.g.
        # "aux:charge_bn_moving_mean" into the arg dict.
        prefix, _, param_name = key.partition(":")
        if prefix == "arg":
            arg_params[param_name] = value
        elif prefix == "aux":
            aux_params[param_name] = value
    return arg_params, aux_params


def _quantize(net, arg_params, aux_params, ctx, logger):
    """Quantize *net* to int8 without calibration, save it, and return (sym, arg, aux)."""
    # NOTE(review): presumably the first and last convolutions are kept in
    # FP32 for accuracy; these names are specific to this network — confirm.
    excluded_sym_names = ['convolution0', 'convolution19']
    qsym, qarg_params, qaux_params = quantize_model(
        sym=net, arg_params=arg_params, aux_params=aux_params, ctx=ctx,
        excluded_sym_names=excluded_sym_names, calib_mode='none',
        quantized_dtype='int8', logger=logger)
    qsym.save('laopo2x_no_noise_int8-symbol.json')
    save_params('laopo2x_no_noise_int8-0000.params', qarg_params, qaux_params)
    return qsym, qarg_params, qaux_params


def _save_image(output, path):
    """Write an H x W x C float image in [0, 1] to *path* as an 8-bit image."""
    # Round rather than truncate so values are not biased downward.
    pixels = np.rint(output * 255.0).astype(np.uint8)
    # A single-channel result must be 2-D for PIL to pick mode "L".
    if pixels.ndim == 3 and pixels.shape[2] == 1:
        pixels = pixels[:, :, 0]
    Image.fromarray(pixels).save(fp=path, compress_level=9)


def eval(name, out, sym_json, params, epoch, benchmark=False, comment='', use_monger=True):
    """Quantize a model to int8, run it on one image, and save the result.

    NOTE: this shadows the builtin ``eval``; the name is kept for callers.
    It reads the module-level globals ``scale`` (resize factor used in
    benchmark mode) and ``gpu`` (device index, ``-1`` selects the CPU).

    Parameters
    ----------
    name : str
        Path of the input image.
    out : str
        Path where the output image is written.
    sym_json : str
        Path of the symbol JSON loaded with ``mx.symbol.load``.
    params : str
        Path of a params file whose keys are ``arg:<name>`` / ``aux:<name>``.
    epoch, comment, use_monger :
        Unused; kept for interface compatibility.
    benchmark : bool
        If True, the input is the image downscaled by ``scale`` and PSNR/SSIM
        against the (cropped) original are printed.

    Side effects: writes ``laopo2x_no_noise_int8-symbol.json`` and
    ``laopo2x_no_noise_int8-0000.params`` to the current directory, writes
    the output image to *out*, and prints progress to stdout.
    """
    lr, org = _load_input(name, benchmark, scale)
    print("org.shape: ", org.shape)
    img = lr
    print("input image:", img.shape)

    ctx = mx.cpu() if gpu == -1 else mx.gpu(gpu)

    logging.basicConfig()
    logger = logging.getLogger('logger')
    logger.setLevel(logging.INFO)

    # Load symbol and params, then swap in the int8-quantized versions.
    net = mx.symbol.load(sym_json)
    arg_param, aux_param = _split_checkpoint(mx.nd.load(params))
    net, arg_param, aux_param = _quantize(net, arg_param, aux_param, ctx, logger)

    if img.ndim == 2:
        img = np.expand_dims(img, axis=2)
    dshape = (1, img.shape[2], img.shape[0], img.shape[1])
    print("Forward data_shape=", dshape)

    model = mx.mod.Module(net, context=ctx, data_names=['data'])
    model.bind(data_shapes=[('data', dshape)], for_training=False, grad_req='null')
    model.set_params(arg_params=arg_param, aux_params=aux_param)

    # time.clock() was removed in Python 3.8; perf_counter is the portable timer.
    start = time.perf_counter()
    output = forward(model, img, ctx)
    print("output.shape:", output.shape)
    output = np.clip(output, 0.0, 1.0)

    psnr = float('nan')
    ssim = float('nan')
    if benchmark:
        print(org.shape, " vs. ", output.shape)
        psnr = compare_psnr(org, output, data_range=1)
        ssim = compare_ssim(org, output, data_range=1, multichannel=True)

    print("saving...")
    _save_image(output, out)
    print('finished in %.2fs psnr: %.2f dB ssim: %.4f' % (time.perf_counter() - start, psnr, ssim))
# Script configuration. ``eval`` reads ``scale`` and ``gpu`` as module
# globals, so these stay at module level rather than inside the main guard.
input_file = '2631_x2_HR.png'
benchmark = True
gpu = 0            # device index; -1 runs on the CPU
scale = 2          # resize factor used when benchmarking
epoch = 0
network = "laopo2x_no_noise"
param = "./%s-%04d.params" % (network, epoch)
sym = "./%s-symbol.json" % network
name, ext = os.path.splitext(input_file)
out = "./%s_x%d_%s_%d.png" % (ntpath.basename(name), scale, network, epoch)

if __name__ == "__main__":
    eval(input_file, out, sym, param, epoch, benchmark,
         comment=ntpath.basename(input_file))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment