Created
March 15, 2017 09:19
-
-
Save xmfbit/55135dd28eea8eed935853135aba3d5f to your computer and use it in GitHub Desktop.
Convert a YOLO cfg file to a Caffe prototxt file. If your Caffe fork defines the `reorg` and `region` layer parameters differently, adjust the code accordingly.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# convert yolo cfg file to caffe prototxt file | |
import sys | |
import argparse | |
# Layer categories this converter knows how to emit.
TYPES = ["Conv", "ReLU", "Pool", "Route", "Reorg", "Region"]
# Accepted values for the yolo `activation=` option.
ACTIVATION_TYPES = ['leaky', 'linear']
# Global registry of layer names already emitted, used to detect duplicates.
layer_names = []
def HasConflictNameError(conflicted_name):
    """Report a duplicate layer name and abort the conversion.

    Uses print() calls (identical output on Python 2 for a single
    argument) so the file is not tied to the py2-only print statement.
    """
    print('Error! The layer name \"{}\" has been in the list.'.format(conflicted_name))
    print('layer name list:')
    print(', '.join(layer_names))
    sys.exit(1)
def HasBracketDismatchError(content):
    """Report a section title with missing/mismatched brackets and abort.

    BUGFIX: the error message said "dismatch"; corrected to "mismatched".
    """
    print('in the content: {}, brackets are not found or mismatched'.format(content))
    sys.exit(1)
def HasTypeDismatchError(content, type_string):
    """Report that a section title does not carry the expected layer type, then abort."""
    print('in the content: {}, cannot find type: {}'.format(content, type_string))
    sys.exit(1)
def get_substring_location(line, sub):
    """Locate `sub` inside `line`.

    Returns a (start, end) pair such that `line[start:end] == sub`,
    or (-1, -1) when `sub` does not occur in `line`.
    """
    start = line.find(sub)
    if start < 0:
        return -1, -1
    return start, start + len(sub)
class ConvParam(object):
    """Holds the caffe `convolution_param` fields for one conv layer.

    The `train` flag is accepted for interface compatibility but is
    not used when rendering the prototxt text.
    """
    def __init__(self, out_filters, kernel_size=3, stride=1, pad=0, train=False):
        self.num_output = out_filters
        self.kernel_size = kernel_size
        self.stride = stride
        self.pad = pad
    def __str__(self):
        # Assemble line by line; tabs match the surrounding layer block.
        parts = [
            'convolution_param {',
            '\t\tnum_output: {}'.format(self.num_output),
            '\t\tkernel_size: {}'.format(self.kernel_size),
            '\t\tpad: {}'.format(self.pad),
            '\t\tstride: {}'.format(self.stride),
            '\t}',
        ]
        return '\n'.join(parts)
class Layer(object):
    """Base class for layers in the network.

    Parses one cfg section (a list of raw lines whose first line is the
    `[type]` title) and offers typed config lookups for subclasses.
    Subclasses must set `self.type_string` before calling __init__.
    """
    def __init__(self, section, idx):
        self.idx = idx
        # Drop empty entries and remove all whitespace inside each line,
        # so e.g. 'filters = 32' becomes 'filters=32'.
        # (A list comprehension instead of map() so indexing works on
        # Python 3 as well.)
        self.configs = [line.strip().replace(' ', '') for line in section if line]
        self.check()
    def check(self):
        """Verify the section title is `[<self.type_string>]`; exit otherwise."""
        title = self.configs[0]
        l, r = title.find('['), title.find(']')
        if l == -1 or r == -1:
            HasBracketDismatchError(title)
        if title[l+1:r] != self.type_string:
            HasTypeDismatchError(title, self.type_string)
    def print_configs(self):
        """Print the parsed config lines (debugging helper)."""
        print('=' * 40)
        print('\n'.join(self.configs))
        print('=' * 40)
    def get_int_config(self, mark, config_name=None, necessary=True):
        """Get a config parameter of `int` type.

        params: mark -> substring to find (e.g. 'filters=')
                config_name -> parameter name for error messages;
                               defaults to `mark` without the trailing '='
                necessary -> if True and the parameter is missing, the
                             program exits; otherwise 0 is returned
        """
        for line in self.configs[1:]:
            s, e = get_substring_location(line, mark)
            if s != -1:
                return int(line[e:])
        if not necessary:
            return 0
        if config_name is None:
            config_name = mark[:-1]
        print('Error! cannot find config `{}` in configs:'.format(config_name))
        sys.exit(1)
    def get_string_config(self, mark, config_name=None, necessary=True):
        """Get a config parameter of `string` type.

        BUGFIX: a missing required value used to fall through and return
        None, which crashed later (e.g. `float(None)` in RegionLayer);
        now it exits with a clear message, mirroring get_int_config.
        """
        for line in self.configs[1:]:
            s, e = get_substring_location(line, mark)
            if s != -1:
                return line[e:]
        if necessary:
            if config_name is None:
                config_name = mark[:-1]
            print('Error! cannot find config `{}` in configs:'.format(config_name))
            sys.exit(1)
        return None
    def __str__(self):
        """String representation; subclasses override with prototxt text."""
        return ""
class DataLayer(object):
    """Network input layer; renders the prototxt `input`/`input_dim` header."""
    def __init__(self, height, width, batch_size=1, channels=3):
        self.type_string = 'data'
        self.name = 'data'
        self.height = height
        self.width = width
        self.batch_size = batch_size
        self.channels = channels
    def __str__(self):
        # Caffe expects dims in N, C, H, W order.
        dims = (self.batch_size, self.channels, self.height, self.width)
        text = 'input: "data"\n'
        for d in dims:
            text += 'input_dim: {}\n'.format(d)
        return text
class ConvLayerUnit(Layer):
    """Conv-BN-ReLU unit built from one yolo [convolutional] section.

    Emits a Convolution layer, an optional BatchNorm+Scale pair, and a
    ReLU (leaky activations get negative_slope 0.1).
    """
    def __init__(self, section, idx, prev_layers):
        """Initialize from the cfg section; prev_layers holds the single input layer."""
        assert len(prev_layers) == 1
        self.type_string = 'convolutional'
        Layer.__init__(self, section, idx)
        self.bn = self.use_bn()
        self.out_filters = self.get_filter_number()
        self.stride = self.get_stride()
        self.kernel_size = self.get_kernel_size()
        self.pad = self.get_pad()
        self.activation_type = self.get_activation_type()
        self.name = 'conv_bn_relu_{}'.format(self.idx)
        if isinstance(prev_layers[0], DataLayer):
            self.input_blob_name = 'data'
        else:
            assert isinstance(prev_layers[0], Layer)
            self.input_blob_name = prev_layers[0].output_blob_name
        self.output_blob_name = 'conv_{}'.format(self.idx)
        self.conv_name = 'conv_{}'.format(self.idx)
        self._register_name(self.conv_name)
        if self.bn:
            self.bn_name = 'bn_{}'.format(self.conv_name)
            self._register_name(self.bn_name)
            self.scale_name = 'scale_{}'.format(self.conv_name)
            self._register_name(self.scale_name)
        self.relu_name = 'relu_{}'.format(self.conv_name)
        self._register_name(self.relu_name)
    def _register_name(self, name):
        """Record `name` in the global registry, aborting on duplicates."""
        if name in layer_names:
            HasConflictNameError(name)
        layer_names.append(name)
    def use_bn(self):
        """Whether this unit uses batch normalization (`batch_normalize=1`)."""
        ret = self.get_int_config('batch_normalize=', necessary=False)
        assert ret == 0 or ret == 1
        return ret == 1
    def get_filter_number(self):
        """Output filter number (`filters=`)."""
        return self.get_int_config('filters=')
    def get_kernel_size(self):
        """Kernel size of the filter (`size=`)."""
        return self.get_int_config('size=')
    def get_stride(self):
        """Stride (`stride=`)."""
        return self.get_int_config('stride=')
    def get_pad(self):
        """Padding size (`pad=`).

        NOTE(review): darknet treats `pad=1` as a flag meaning
        padding = size/2, while this code uses the value literally;
        the two agree only for 3x3 convs -- TODO confirm for the cfgs in use.
        """
        return self.get_int_config('pad=')
    def get_activation_type(self):
        """Non-linear activation type (`activation=`); must be in ACTIVATION_TYPES."""
        ret = self.get_string_config('activation=')
        assert ret in ACTIVATION_TYPES
        return ret
    def str_conv_layer(self):
        """Emit the Convolution layer.

        BUGFIX: the original passed the format arguments as
        (input, output, conv_name, param) while the template expects
        (name, bottom, top, param), scrambling the generated prototxt.
        """
        param = ConvParam(self.out_filters, self.kernel_size, self.stride, self.pad, False)
        return """layer {{
\tname: "{}"
\ttype: "Convolution"
\tbottom: "{}"
\ttop: "{}"
\t{}
}}""".format(self.conv_name, self.input_blob_name, self.output_blob_name, str(param))
    def str_relu_layer(self):
        """Emit the ReLU layer; leaky activations carry negative_slope 0.1."""
        if self.activation_type == 'leaky':
            relu_param = """relu_param{
\t\tnegative_slope: 0.1
\t}\n"""
        else:  # 'linear' -- guaranteed by get_activation_type's assert
            relu_param = ""
        ret = """layer {{
\tname: "{}"
\ttype: "ReLU"
\tbottom: "{}"
\ttop: "{}"
\t{}
}}\n""".format(self.relu_name, self.output_blob_name, self.output_blob_name, relu_param)
        return ret
    def str_bn_layer(self):
        """Emit the BatchNorm + Scale pair, or '' when bn is disabled."""
        if not self.bn:
            return ""
        bn_str = """layer {{
\tname: "{}"
\ttype: "BatchNorm"
\tbottom: "{}"
\ttop: "{}"
}}\n""".format(self.bn_name, self.output_blob_name, self.output_blob_name)
        scale_str = """layer {{
\tname: "{}"
\ttype: "Scale"
\tbottom: "{}"
\ttop: "{}"
\tscale_param {{
\t\tbias_term: true
\t}}
}}\n""".format(self.scale_name, self.output_blob_name, self.output_blob_name)
        return bn_str + scale_str
    def __str__(self):
        """The whole unit in protobuf style, wrapped in start/end markers."""
        start = "\n#--- {} --- start\n".format(self.conv_name)
        end = '#--- {} --- end\n'.format(self.conv_name)
        return start + self.str_conv_layer() + '\n' + self.str_bn_layer() + '\n' + self.str_relu_layer() + end
class Poolayer(Layer):
    """Max-pooling layer built from a yolo [maxpool] section."""
    def __init__(self, section, idx, prev_layers):
        assert len(prev_layers) == 1
        self.type_string = 'maxpool'
        Layer.__init__(self, section, idx)
        self.name = 'pool_{}'.format(idx)
        self.input_blob_name = prev_layers[0].output_blob_name
        # Output blob shares the layer name.
        self.output_blob_name = self.name
        self.stride = self.get_int_config('stride=')
        self.kernel_size = self.get_int_config('size=')
    def __str__(self):
        head = '\nlayer {{\n\tname: "{}"\n\ttype: "Pooling"\n'.format(self.name)
        blobs = '\tbottom: "{}"\n\ttop: "{}"\n'.format(
            self.input_blob_name, self.output_blob_name)
        pool = ('\tpooling_param {{\n\t\tpool: MAX\n\t\tkernel_size: {}\n'
                '\t\tstride: {}\n\t}}\n}}\n').format(self.kernel_size, self.stride)
        return head + blobs + pool
class ReorgLayer(Layer):
    """Reorg (space-to-depth) layer built from a yolo [reorg] section."""
    def __init__(self, section, idx, prev_layers):
        assert len(prev_layers) == 1
        self.type_string = 'reorg'
        Layer.__init__(self, section, idx)
        self.name = 'reorg_{}'.format(idx)
        self.input_blob_name = prev_layers[0].output_blob_name
        self.output_blob_name = 'reorg_{}'.format(self.input_blob_name)
        self.stride = self.get_int_config('stride=')
        # NOTE(review): `reverse` is never read from the cfg; assumed False.
        self.reverse = False
    def __str__(self):
        reverse_item = '\n\t\treverse: true' if self.reverse else ''
        # BUGFIX: the stride was hardcoded as `stride: 2`, silently ignoring
        # the parsed `self.stride`; it is now substituted properly.
        return """\nlayer {{
\tname: "{}"
\ttype: "Reorg"
\tbottom: "{}"
\ttop: "{}"
\treorg_param {{
\t\tstride: {}{}
\t}}
}}\n""".format(self.name, self.input_blob_name, self.output_blob_name,
               self.stride, reverse_item)
class RouteLayer(Layer):
    """Route layer -> caffe Concat over the referenced layers' outputs.

    NOTE(review): indices from `layers=` index directly into prev_layers
    (the full layer list, data layer at position 0). Negative yolo
    indices work relative to the end, but absolute indices are shifted
    by one because of the data layer -- TODO confirm against the cfgs in use.
    """
    def __init__(self, section, idx, prev_layers):
        self.type_string = 'route'
        Layer.__init__(self, section, idx)
        self.name = 'concat_{}'.format(idx)
        layer_idx = [int(tok) for tok in self.get_string_config('layers=').split(',')]
        concated_layers = [prev_layers[i] for i in layer_idx]
        self.input_blob_name = [x.output_blob_name for x in concated_layers]
        self.output_blob_name = 'concat_{}'.format(idx)
    def __str__(self):
        # BUGFIX: the original emitted `type: "Concat` with no closing
        # quote, producing an invalid prototxt.
        base = """\nlayer {{
\tname: "{}"
\ttype: "Concat"\n""".format(self.name)
        for blob in self.input_blob_name:
            base += '\tbottom: "{}"\n'.format(blob)
        base += '\ttop: "{}"\n}}\n'.format(self.output_blob_name)
        return base
class RegionLayer(Layer):
    """YOLOv2 region (detection output) layer; always the last layer of the net."""
    def __init__(self, section, prev_layers):
        assert len(prev_layers) == 1
        self.type_string = 'region'
        Layer.__init__(self, section, 0)
        self.name = 'region'
        self.input_blob_name = prev_layers[0].output_blob_name
        self.output_blob_name = 'region'
        self.classes = self.get_int_config('classes=')
        self.coords = self.get_int_config('coords=')
        self.boxes_of_each_grid = self.get_int_config('num=')
        self.softmax = self.get_int_config('softmax=') == 1
        self.thresh = float(self.get_string_config('thresh='))
        self.anchors = self.get_anchors()
    def get_anchors(self):
        """Parse `anchors=` into floats; expects one (pw, ph) pair per box."""
        anchors = [float(v) for v in self.get_string_config('anchors=').split(',')]
        assert len(anchors) == self.boxes_of_each_grid * 2
        return anchors
    def __str__(self):
        # BUGFIX: when softmax was off, the original glued `thres:` onto
        # the `boxes_of_each_grid:` line (missing newline); `thres` now
        # always starts its own line. `xrange` replaced by `range` for
        # py2/py3 portability.
        softmax_item = '\n\t\tsoftmax: true' if self.softmax else ''
        anchor_item = ""
        for i in range(self.boxes_of_each_grid):
            pw, ph = self.anchors[2*i], self.anchors[2*i+1]
            anchor_item += '\t\tanchor_coords {{pw: {} ph: {}}}\n'.format(pw, ph)
        base = """\nlayer {{
\tname: "{}"
\ttype: "Region"
\tbottom: "{}"
\ttop: "{}"
\tregion_param {{
\t\tclasses: {}
\t\tcoords: {}
\t\tboxes_of_each_grid: {}{}
\t\tthres: {}
{}\t}}
}}\n""".format(self.name, self.input_blob_name, self.output_blob_name,
               self.classes, self.coords, self.boxes_of_each_grid,
               softmax_item, self.thresh, anchor_item)
        return base
class Net(object):
    """Whole network: walks the cfg sections and builds the layer list in order."""
    def __init__(self, sections, net_name):
        self.name = net_name
        self.sections = sections
        # yolo v2 default input resolution: 1x3x416x416
        self.layer_list = [DataLayer(416, 416, 1, 3)]
        self.conv_count = 0
        self.pooling_count = 0
        self.reorg_count = 0
        self.route_count = 0
        i = 0
        while i < len(sections):
            if sections[i].find('[convolutional]') != -1:
                j = self._next_or_die(i, 'conv')
                self.layer_list.append(ConvLayerUnit(sections[i:j],
                    self.conv_count+1, [self.layer_list[-1]]))
                self.conv_count += 1
                i = j
                continue
            if sections[i].find('[maxpool]') != -1:
                j = self._next_or_die(i, 'maxpool')
                self.layer_list.append(Poolayer(sections[i:j],
                    self.pooling_count+1, [self.layer_list[-1]]))
                self.pooling_count += 1
                i = j
                continue
            if sections[i].find('[reorg]') != -1:
                j = self._next_or_die(i, 'reorg layer')
                self.layer_list.append(ReorgLayer(sections[i:j],
                    self.reorg_count+1, [self.layer_list[-1]]))
                self.reorg_count += 1
                i = j
                continue
            if sections[i].find('[route]') != -1:
                j = self._next_or_die(i, 'route layer')
                # route may reference any earlier layer, so pass them all
                route_layer = RouteLayer(sections[i:j],
                    self.route_count+1, self.layer_list)
                self.layer_list.append(route_layer)
                self.route_count += 1
                i = j
                continue
            if sections[i].find('[region]') != -1:
                if self.find_next_layer(i) != -1:
                    print('Error! region layer should be the last layer in the net.')
                    sys.exit(1)
                # BUGFIX: was `sections[i:j]` with j == -1, i.e.
                # sections[i:-1], which silently dropped the last config
                # line of the region section.
                self.layer_list.append(RegionLayer(sections[i:], [self.layer_list[-1]]))
                break
            i += 1
    def _next_or_die(self, start, kind):
        """Index of the next section title; exits if `kind` is the last section."""
        j = self.find_next_layer(start)
        if j == -1:
            print('Error! {} cannot be the last layer in the net.'.format(kind))
            sys.exit(1)
        return j
    def find_next_layer(self, start):
        """Return the index of the next line containing '[', or -1 at end of file."""
        idx = start + 1
        while idx < len(self.sections):
            if self.sections[idx].find('[') != -1:
                break
            idx += 1
        if idx == len(self.sections):
            print('Reach the end of the file.')
            return -1
        return idx
    def __str__(self):
        """Full prototxt text: net name header followed by every layer."""
        ret = 'name: \"{}\"\n'.format(self.name)
        for layer in self.layer_list:
            ret += str(layer)
        return ret
def parse_args():
    """Parse command line arguments: --cfg (input yolo cfg) and --out (output prototxt)."""
    parser = argparse.ArgumentParser(description='yolo cfg -> caffe prototxt')
    parser.add_argument('--cfg', dest='cfg_file', help='file name of yolo cfg file')
    parser.add_argument('--out', dest='out_file', help='file name of generated caffe prototxt')
    return parser.parse_args()
def main():
    """Read the yolo cfg, build the Net, and write the caffe prototxt."""
    args = parse_args()
    with open(args.cfg_file, 'r') as f:
        # keep only non-blank lines; section parsing relies on this
        sections = [line for line in f.readlines() if line != '\n']
    net = Net(sections, 'net')
    with open(args.out_file, 'w') as f:
        f.write('# auto generated by convert.py\n')
        f.write(str(net))
if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment