Skip to content

Instantly share code, notes, and snippets.

@woolpeeker
Last active February 6, 2021 02:59
Show Gist options
  • Save woolpeeker/129a1061f0487ee71815b76d01eb6af2 to your computer and use it in GitHub Desktop.
Save woolpeeker/129a1061f0487ee71815b76d01eb6af2 to your computer and use it in GitHub Desktop.
Merge BatchNorm layers into the preceding Convolution's weights for an MXNet symbol (BN folding for inference).
import os
import mxnet as mx
from mxnet import ndarray as nd
from easydict import EasyDict as edict
import json
import copy
SRC_SYM = './models/vgg-softmax-emore/model'
TGT_SYM = './models/vgg-softmax-emore-bnMerged/model'
def merge_weight(kernel, bias=None, beta=None, gamma=None, running_mean=None, running_var=None, eps=0.001):
    """Fold BatchNorm statistics into a convolution's kernel and bias.

    Given conv output y = W*x + b followed by BN(y) = gamma * (y - mean)/sqrt(var + eps) + beta,
    returns (W', b') such that W'*x + b' == BN(W*x + b).

    Parameters
    ----------
    kernel : nd.NDArray
        Convolution weight of shape (out_channels, ...); scaled per output channel.
    bias : nd.NDArray or None
        Convolution bias; treated as zeros when None.
    beta, gamma, running_mean, running_var : nd.NDArray
        BatchNorm parameters/statistics, one value per output channel.
    eps : float
        BatchNorm epsilon (MXNet's BatchNorm default is 1e-3).

    Returns
    -------
    (kernel, bias) : tuple of nd.NDArray
        The merged convolution weight and bias.
    """
    var_sqrt = nd.sqrt(running_var + eps)
    if bias is None:
        # BUG FIX: zeros_like expects an array, not a shape tuple —
        # the original passed running_mean.shape.
        bias = nd.zeros_like(running_mean)
    # Per-output-channel scale, broadcast over (out, in, kh, kw).
    kernel = kernel * (gamma / var_sqrt).reshape([kernel.shape[0], 1, 1, 1])
    bias = (bias - running_mean) / var_sqrt * gamma + beta
    return kernel, bias
if __name__ == '__main__':
    # Load symbol + params of the source checkpoint (epoch 1).
    _r, arg_params, aux_params = mx.model.load_checkpoint(SRC_SYM, 1)
    net = json.loads(_r.tojson())
    new_net = copy.deepcopy(net)
    new_arg_params = copy.deepcopy(arg_params)
    new_aux_params = copy.deepcopy(aux_params)
    nodes = new_net['nodes']
    for node in nodes:
        if node['op'] == 'BatchNorm':
            # BatchNorm inputs in symbol JSON: data, gamma, beta, moving_mean, moving_var.
            inp_node = nodes[node['inputs'][0][0]]
            gamma_node = nodes[node['inputs'][1][0]]
            beta_node = nodes[node['inputs'][2][0]]
            mean_node = nodes[node['inputs'][3][0]]
            var_node = nodes[node['inputs'][4][0]]
            gamma = arg_params[gamma_node['name']]
            beta = arg_params[beta_node['name']]
            running_mean = aux_params[mean_node['name']]
            running_var = aux_params[var_node['name']]
            # Use the node's own eps when present; MXNet's BatchNorm default is 1e-3.
            eps = float(node.get('attrs', {}).get('eps', 1e-3))
            if inp_node['op'] == 'Convolution':
                assert len(inp_node['inputs']) == 3, 'conv must have bias, otherwise we need add node'
                kernel_node = nodes[inp_node['inputs'][1][0]]
                kernel = arg_params[kernel_node['name']]
                bias_node = nodes[inp_node['inputs'][2][0]]
                bias = arg_params[bias_node['name']]
                # BUG FIX: call with keyword arguments. The original positional call
                # merge_weight(kernel, bias, gamma, beta, ...) passed gamma into the
                # `beta` slot and beta into `gamma` (signature order is bias, beta, gamma).
                new_kernel, new_bias = merge_weight(
                    kernel, bias=bias, beta=beta, gamma=gamma,
                    running_mean=running_mean, running_var=running_var, eps=eps)
                new_arg_params[kernel_node['name']] = new_kernel
                new_arg_params[bias_node['name']] = new_bias
                # Drop the now-unused BN parameters so the saved checkpoint
                # has no orphan entries that strict param binding would reject.
                new_arg_params.pop(gamma_node['name'], None)
                new_arg_params.pop(beta_node['name'], None)
            else:
                raise Exception('Unsupported layer type for bn merging: ', inp_node['op'])
            # Replace the BN node with an identity (_copy) on its data input.
            node['op'] = '_copy'
            node['inputs'] = [node['inputs'][0]]
            node['name'] = node['name'] + '_merged'
    json_str = json.dumps(new_net)
    merged_net = mx.symbol.load_json(json_str)
    # All BN moving stats were folded into conv weights, so aux_params is empty.
    mx.model.save_checkpoint(
        prefix=TGT_SYM,
        epoch=1,
        symbol=merged_net,
        arg_params=new_arg_params,
        aux_params={}
    )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment