Last active
February 6, 2021 02:59
-
-
Save woolpeeker/129a1061f0487ee71815b76d01eb6af2 to your computer and use it in GitHub Desktop.
Merge (fold) BatchNorm layers into the preceding Convolution weights of an MXNet symbol/checkpoint.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import mxnet as mx | |
from mxnet import ndarray as nd | |
from easydict import EasyDict as edict | |
import json | |
import copy | |
# Checkpoint prefix of the source model (contains BatchNorm layers).
SRC_SYM = './models/vgg-softmax-emore/model'
# Checkpoint prefix where the BN-merged model is written.
TGT_SYM = './models/vgg-softmax-emore-bnMerged/model'
def merge_weight(kernel, bias=None, beta=None, gamma=None, running_mean=None, running_var=None, eps=0.001):
    """Fold BatchNorm statistics into the preceding convolution's weights.

    Computes the standard BN-folding transform:
        kernel' = kernel * gamma / sqrt(running_var + eps)
        bias'   = (bias - running_mean) / sqrt(running_var + eps) * gamma + beta

    Parameters
    ----------
    kernel : NDArray
        Convolution weight of shape (out_channels, ...); scaled per output channel.
    bias : NDArray or None
        Convolution bias; if None, treated as zeros of the same shape as
        ``running_mean`` (one value per output channel).
    beta, gamma : NDArray
        BatchNorm shift and scale parameters.
    running_mean, running_var : NDArray
        BatchNorm moving statistics.
    eps : float
        BatchNorm epsilon (MXNet's BatchNorm default is 0.001).

    Returns
    -------
    (NDArray, NDArray)
        The folded (kernel, bias) pair.
    """
    var_sqrt = nd.sqrt(running_var + eps)
    if bias is None:
        # BUG FIX: nd.zeros_like takes an NDArray, not a shape tuple —
        # the original passed running_mean.shape, which raises on this path.
        bias = nd.zeros_like(running_mean)
    # Reshape to (out_channels, 1, 1, 1) so the scale broadcasts over the
    # kernel's spatial/input dims. NOTE(review): assumes a 4-D conv kernel —
    # confirm if 1-D/3-D convolutions ever reach this helper.
    kernel = kernel * (gamma / var_sqrt).reshape([kernel.shape[0], 1, 1, 1])
    bias = (bias - running_mean) / var_sqrt * gamma + beta
    return kernel, bias
if __name__ == '__main__':
    # Load the source checkpoint: symbol graph plus arg/aux parameter dicts.
    sym, arg_params, aux_params = mx.model.load_checkpoint(SRC_SYM, 1)
    net = json.loads(sym.tojson())
    new_net = copy.deepcopy(net)
    new_arg_params = copy.deepcopy(arg_params)
    new_aux_params = copy.deepcopy(aux_params)
    nodes = new_net['nodes']
    for node in nodes:
        if node['op'] == 'BatchNorm':
            # BatchNorm node inputs are ordered: [data, gamma, beta, moving_mean, moving_var]
            inp_node = nodes[node['inputs'][0][0]]
            gamma_node = nodes[node['inputs'][1][0]]
            beta_node = nodes[node['inputs'][2][0]]
            mean_node = nodes[node['inputs'][3][0]]
            var_node = nodes[node['inputs'][4][0]]
            gamma = arg_params[gamma_node['name']]
            beta = arg_params[beta_node['name']]
            running_mean = aux_params[mean_node['name']]
            running_var = aux_params[var_node['name']]
            if inp_node['op'] == 'Convolution':
                kernel_node = nodes[inp_node['inputs'][1][0]]
                kernel = arg_params[kernel_node['name']]
                assert len(inp_node['inputs']) == 3, 'conv must have bias, otherwise we need add node'
                bias_node = nodes[inp_node['inputs'][2][0]]
                bias = arg_params[bias_node['name']]
                # Honor the node's own eps when present; fall back to the
                # BatchNorm default. (Attr dict key is 'attrs' in recent
                # MXNet JSON; .get keeps older graphs working.)
                bn_eps = float(node.get('attrs', {}).get('eps', 0.001))
                # BUG FIX: pass beta/gamma by keyword. The original called
                # merge_weight(kernel, bias, gamma, beta, ...) positionally
                # against the signature (kernel, bias, beta, gamma, ...),
                # silently swapping gamma and beta in the folded bias.
                new_kernel, new_bias = merge_weight(
                    kernel, bias=bias, beta=beta, gamma=gamma,
                    running_mean=running_mean, running_var=running_var,
                    eps=bn_eps)
                new_arg_params[kernel_node['name']] = new_kernel
                new_arg_params[bias_node['name']] = new_bias
                # The merged graph no longer references the BN parameters;
                # drop them so the saved checkpoint has no stale entries.
                new_arg_params.pop(gamma_node['name'], None)
                new_arg_params.pop(beta_node['name'], None)
                new_aux_params.pop(mean_node['name'], None)
                new_aux_params.pop(var_node['name'], None)
            else:
                raise Exception('Unsupported layer type for bn merging: ', inp_node['op'])
            # Replace the BN node with an identity pass-through of its data
            # input; renaming avoids a name clash with the original BN op.
            node['op'] = '_copy'
            node['inputs'] = [node['inputs'][0]]
            node['name'] = node['name'] + '_merged'
    merged_net = mx.symbol.load_json(json.dumps(new_net))
    # Save the cleaned aux dict (instead of discarding all aux params) so
    # any non-BatchNorm auxiliary state survives the merge.
    mx.model.save_checkpoint(
        prefix=TGT_SYM,
        epoch=1,
        symbol=merged_net,
        arg_params=new_arg_params,
        aux_params=new_aux_params
    )
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment