Created
April 26, 2017 19:54
-
-
Save zhreshold/bb9ddae7e3ba371e469b9084c5cccd8c to your computer and use it in GitHub Desktop.
Benchmark simulation for vgg with depth-wise convolution
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
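"""Benchmark variant of VGG-11 in which each 3x3 convolution is replaced by a
depthwise separable convolution (depthwise conv followed by a 1x1 pointwise conv).

References:
Simonyan, Karen, and Andrew Zisserman. "Very deep convolutional networks for
large-scale image recognition." arXiv preprint arXiv:1409.1556 (2014).
"""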
import mxnet as mx | |
def depthwise_conv(data, kernel, pad, num_filter, name, num_group,
                   use_batchnorm=False):
    """Build a separable convolution: grouped conv -> ReLU -> 1x1 pointwise conv.

    Parameters
    ----------
    data : mx.symbol.Symbol
        Input symbol. The grouped stage is a true depthwise convolution only
        when the input channel count equals ``num_group`` (callers in this
        file pass exactly that) -- assumption, confirm for new callers.
    kernel : tuple of int
        Kernel size of the grouped (depthwise) convolution.
    pad : tuple of int
        Padding of the grouped (depthwise) convolution.
    num_filter : int
        Number of output channels of the 1x1 pointwise convolution.
    name : str
        Prefix for layer names: ``<name>_depthwise`` and ``<name>_pointwise``
        (plus ``_bn`` suffixed variants when ``use_batchnorm`` is True).
    num_group : int
        Number of groups of the first convolution; it also emits
        ``num_group`` channels, i.e. one output channel per group.
    use_batchnorm : bool, optional
        Insert a BatchNorm after each convolution. Defaults to False, which
        reproduces the original benchmark graph (BatchNorm omitted).

    Returns
    -------
    mx.symbol.Symbol
        Output of the pointwise convolution (or of its BatchNorm). No
        activation is applied to it; the caller adds one.
    """
    conv = mx.symbol.Convolution(data=data, kernel=kernel, pad=pad,
                                 num_filter=num_group, name=name + '_depthwise',
                                 num_group=num_group)
    if use_batchnorm:
        conv = mx.symbol.BatchNorm(data=conv, name=name + '_depthwise_bn')
    relu = mx.symbol.Activation(data=conv, act_type='relu')
    conv2 = mx.symbol.Convolution(data=relu, kernel=(1, 1), num_filter=num_filter,
                                  name=name + '_pointwise')
    if use_batchnorm:
        conv2 = mx.symbol.BatchNorm(data=conv2, name=name + '_pointwise_bn')
    return conv2
def get_symbol(num_classes, **kwargs):
    """Build a VGG-11 (configuration A) network using separable convolutions.

    Every 3x3 convolution of VGG-11 is replaced by ``depthwise_conv``. Each is
    followed by a ReLU, and each stage ends in a 2x2/stride-2 max pool. Then
    come two 4096-unit fully connected layers with ReLU and 0.5 dropout, a
    final classifier and a softmax output.

    Parameters
    ----------
    num_classes : int
        Number of output units of the final ``fc8`` layer.
    **kwargs
        Accepted and ignored (presumably for signature compatibility with
        other ``get_symbol`` factories -- confirm).

    Returns
    -------
    mx.symbol.Symbol
        ``SoftmaxOutput`` symbol named ``'softmax'``.
    """
    # Output channels of each conv layer, grouped by pooling stage (VGG-11).
    stages = ((64,), (128,), (256, 256), (512, 512), (512, 512))

    body = mx.symbol.Variable(name="data")
    # One depthwise group per input channel. The first layer uses 3 groups,
    # as in the original graph (assumes a 3-channel input -- confirm).
    in_channels = 3
    for stage, stage_filters in enumerate(stages, start=1):
        for layer, num_filter in enumerate(stage_filters, start=1):
            suffix = "%d_%d" % (stage, layer)
            body = depthwise_conv(
                data=body, kernel=(3, 3), pad=(1, 1), num_filter=num_filter,
                name="conv" + suffix, num_group=in_channels)
            body = mx.symbol.Activation(
                data=body, act_type="relu", name="relu" + suffix)
            in_channels = num_filter
        body = mx.symbol.Pooling(
            data=body, pool_type="max", kernel=(2, 2), stride=(2, 2),
            name="pool%d" % stage)

    # group 6
    flatten = mx.symbol.Flatten(data=body, name="flatten")
    fc6 = mx.symbol.FullyConnected(data=flatten, num_hidden=4096, name="fc6")
    relu6 = mx.symbol.Activation(data=fc6, act_type="relu", name="relu6")
    drop6 = mx.symbol.Dropout(data=relu6, p=0.5, name="drop6")
    # group 7
    fc7 = mx.symbol.FullyConnected(data=drop6, num_hidden=4096, name="fc7")
    relu7 = mx.symbol.Activation(data=fc7, act_type="relu", name="relu7")
    drop7 = mx.symbol.Dropout(data=relu7, p=0.5, name="drop7")
    # output
    fc8 = mx.symbol.FullyConnected(data=drop7, num_hidden=num_classes, name="fc8")
    softmax = mx.symbol.SoftmaxOutput(data=fc8, name='softmax')
    return softmax
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment