Last active
March 12, 2019 02:13
-
-
Save zhreshold/850fcf0444121422144388a231f81aec to your computer and use it in GitHub Desktop.
YOLOv3 converter from darknet to GluonCV
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import os, sys | |
import gluoncv as gcv | |
import mxnet as mx | |
# Load the reference parameter shapes from ``ref.txt``, one shape per line.
# The integers are taken from the text between the first '[' and the first
# ']' on the line, split on commas; anything outside that span is ignored.
shapes_ref = []
with open('ref.txt', 'r') as fid:
    for lineno, line in enumerate(fid, 1):
        if not line.strip():
            # A blank line (e.g. an extra newline at the end of the file)
            # carries no shape; skip it instead of failing on int('').
            continue
        pos = line.find('[')
        pos2 = line.find(']')
        t = line[pos + 1:pos2]
        try:
            s = [int(x) for x in t.split(',')]
        except ValueError:
            # Same exception type as before, but say which line is broken.
            raise ValueError(
                'ref.txt line {}: cannot parse shape from {!r}'.format(lineno, line))
        shapes_ref.append(s)
# Instantiate the GluonCV YOLOv3 (darknet53 backbone, COCO) network without
# downloading pretrained backbone weights; every parameter gets overwritten
# from the darknet file further down.
net = gcv.model_zoo.get_model(
    'yolo3_416_darknet53_coco',
    pretrained_base=False,
)
net.initialize()

# Run one forward pass on an all-zero 1x3x416x416 batch. Presumably this is
# what makes the parameter shapes/data available for the code below (Gluon
# deferred initialization) -- TODO confirm.
x = mx.nd.zeros((1, 3, 416, 416))
net(x)
NUM = 366

# Maps an integer slot -> (parameter name, Parameter). Iterating the slots in
# ascending order gives the order in which values are later consumed from the
# flat darknet weight array.
keys = {}

# Parameter-name patterns, visited in this order when assigning slots.
orders = [
    '.*darknet',
    '.*yolodetectionblockv30',
    '.*yolooutputv30',
    '.*yolov30_conv0|.*yolov30_batchnorm0',
    '.*yolodetectionblockv31',
    '.*yolooutputv31',
    '.*yolov30_conv1|.*yolov30_batchnorm1',
    '.*yolodetectionblockv32',
    '.*yolooutputv32',
]


def _slot_offset(pattern, pname):
    """Return the slot offset, relative to the running counter, for *pname*.

    Output blocks (pattern contains 'yolooutput') only accept weight/bias
    parameters: weight goes one slot ahead, bias one slot back, anything else
    raises RuntimeError. For all other patterns: conv weight -> +4,
    beta -> -2, gamma -> 0, everything else -> -1.

    NOTE(review): assuming collect_params yields conv weight, gamma, beta,
    running_mean, running_var in that order, these offsets reorder each group
    to beta, gamma, mean, var, conv weight -- presumably the darknet on-disk
    layout; confirm against the darknet loader.
    """
    if 'yolooutput' in pattern:
        if 'weight' in pname:
            return 1
        if 'bias' in pname:
            return -1
        raise RuntimeError('invalid:{}'.format(pname))
    if 'conv' in pname and 'weight' in pname:
        return 4
    if 'beta' in pname:
        return -2
    if 'gamma' in pname:
        return 0
    return -1


count = 0
for select in orders:
    for pname, param in net.collect_params(select=select).items():
        # Parameters whose name contains 'offset' or 'anchor' are not mapped
        # to a slot and do not advance the counter.
        if 'offset' in pname or 'anchor' in pname:
            print('skip:', pname)
            continue
        slot = count + _slot_offset(select, pname)
        assert slot not in keys, "{} already exists, {}".format(slot, pname)
        keys[slot] = (pname, param)
        count += 1
# Sanity check: walk the assigned slots in ascending order and compare each
# parameter's shape with the next reference shape read from ref.txt.
print(len(keys))
ptr = 0
# Iterate the slots that actually exist rather than probing a fixed
# range(400) and swallowing KeyError, so no slot is silently left unchecked
# if the network has more parameters than the old bound.
for i in sorted(keys):
    pname, param = keys[i]
    expected = tuple(shapes_ref[ptr])  # IndexError if ref.txt has too few shapes
    print(i, pname, param.shape, expected)
    assert expected == param.shape, '{}, {}'.format(param.shape, expected)
    ptr += 1
# Read the darknet checkpoint: the first five int32 values are kept as the
# header, and the remainder of the file is interpreted as one flat float32
# array.
# NOTE(review): a 5 x int32 header is assumed for this yolov3.weights file;
# confirm before reusing with other darknet weight files.
with open('yolov3.weights', 'rb') as weight_file:
    header = np.fromfile(weight_file, dtype=np.int32, count=5)
    print(header)
    weights = np.fromfile(weight_file, dtype=np.float32)
print(len(weights))
# Copy the flat darknet weight array into the network parameters, slot by
# slot in ascending order, then save the converted parameters.
ptr = 0
# Iterate the slots that exist instead of a hard-coded range(400), so
# parameters in slots >= 400 are not silently skipped.
for i in sorted(keys):
    name, param = keys[i]
    shape = param.shape
    offset = int(np.prod(shape))  # number of values this parameter consumes
    raw_data = weights[ptr:ptr + offset]
    if raw_data.size != offset:
        # Slicing past the end truncates silently; report it explicitly
        # (reshape would raise ValueError here anyway, with less context).
        raise ValueError(
            'weights exhausted at {}: need {} values, {} left'.format(
                name, offset, raw_data.size))
    ptr += offset
    # Print the parameter mean before and after the copy as a visual check
    # that the values actually changed.
    before = param.data().mean().asscalar()
    param.set_data(mx.nd.array(np.array(raw_data).reshape(shape)))
    after = param.data().mean().asscalar()
    print(name, shape, before, after)
# Every value in the file must have been consumed.
assert len(weights) == ptr
print(ptr)
net.save_parameters('yolo3_416_darknet53_coco-converted.params')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment