Last active
February 6, 2018 11:30
-
-
Save nlgranger/18adbea7fac3ddb944d0ad679709c1de to your computer and use it in GitHub Desktop.
Wide ResNet in Lasagne/Theano
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from lasagne.layers import BatchNormLayer, NonlinearityLayer, Conv2DLayer, \ | |
DropoutLayer, ElemwiseSumLayer, GlobalPoolLayer | |
from lasagne.nonlinearities import rectify | |
from lasagne.init import HeNormal | |
def wide_resnet(l_in, d, k, dropout=0.):
    """Build a Wide ResNet WRN-d-k [Zagoruyko2016]_.

    Parameters
    ----------
    l_in : Layer
        Input layer.
    d : int
        Network depth; must satisfy d = 6 * n + 4 where n is the number
        of residual blocks per group.
    k : int
        Widening factor.
    dropout : float, optional
        Dropout rate applied between the two convolutions of each
        residual block; 0 (the default) disables dropout.

    Returns
    -------
    Layer
        The global mean-pooling output layer of the network.

    Raises
    ------
    ValueError
        If ``d`` is not of the form 6 * n + 4.

    .. [Zagoruyko2016] Zagoruyko, S., & Komodakis, N. (2016).
       Wide Residual Networks. In Proceedings of the British Machine Vision
       Conference 2016, BMVC 2016, York, UK, September 19-22, 2016.
    """
    if (d - 4) % 6 != 0:
        raise ValueError("d should be of the form d = 6 * n + 4")

    n = (d - 4) // 6
    he_norm = HeNormal(gain='relu')

    def basic_block(incoming, num_filters, stride, shortcut, name=None):
        # Pre-activation residual block: BN -> ReLU -> conv3x3 -> BN -> ReLU
        # (-> dropout) -> conv3x3, summed with an identity or 1x1-conv
        # ("linear") shortcut.
        name = name + "_" if name is not None else ""

        conv_path = BatchNormLayer(incoming)
        conv_path = NonlinearityLayer(conv_path, nonlinearity=rectify)
        rectified_input = conv_path  # reused by the linear shortcut

        # TODO: not clear if we should dropout here, authors code doesn't seem to
        conv_path = Conv2DLayer(
            conv_path, num_filters=num_filters, filter_size=(3, 3),
            stride=stride, pad='same',
            W=he_norm, b=None, nonlinearity=None, name=name + "conv1")

        conv_path = BatchNormLayer(conv_path)
        conv_path = NonlinearityLayer(conv_path, nonlinearity=rectify)

        if dropout > 0:
            conv_path = DropoutLayer(conv_path, p=dropout)

        conv_path = Conv2DLayer(
            conv_path, num_filters=num_filters, filter_size=(3, 3),
            pad='same',
            W=he_norm, b=None, nonlinearity=None, name=name + "conv2")

        if shortcut == 'identity':
            # An identity shortcut cannot change the spatial resolution.
            # Explicit raise instead of `assert`: asserts vanish under -O.
            if stride != (1, 1) and stride != 1:
                raise ValueError("identity shortcut requires stride 1")
            short_path = incoming
        elif shortcut == 'linear':
            # 1x1 convolution matches filter count (and resolution via stride).
            short_path = Conv2DLayer(
                rectified_input, num_filters=num_filters, filter_size=(1, 1),
                stride=stride, pad='same',
                W=he_norm, b=None, nonlinearity=None)
        else:
            raise ValueError("invalid parameter value for shortcut")

        return ElemwiseSumLayer([conv_path, short_path], name=name + "sum")

    # Initial 3x3 convolution (16 filters, not widened).
    net = Conv2DLayer(
        l_in, num_filters=16, filter_size=(3, 3),
        pad='same',
        W=he_norm, b=None, nonlinearity=None)

    # Three groups of n blocks each.  The first block of every group uses a
    # linear shortcut to adjust the filter count (and, for groups 2 and 3, to
    # halve the resolution); the remaining blocks use identity shortcuts.
    groups = [(16 * k, (1, 1)), (32 * k, (2, 2)), (64 * k, (2, 2))]
    for g, (num_filters, first_stride) in enumerate(groups, start=1):
        net = basic_block(net, num_filters, stride=first_stride,
                          shortcut='linear', name="block" + str(g) + "1")
        for i in range(1, n):
            net = basic_block(net, num_filters, stride=(1, 1),
                              shortcut='identity',
                              name="block" + str(g) + str(i + 1))

    net = BatchNormLayer(net)
    net = NonlinearityLayer(net, nonlinearity=rectify)
    net = GlobalPoolLayer(net, name="MeanPool")

    return net
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment