Last active
May 17, 2017 13:39
-
-
Save mitmul/57c3ba5fb0bf64b6c503 to your computer and use it in GitHub Desktop.
Deep Residual Network (ResNet) definitions in Chainer
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
# -*- coding: utf-8 -*- | |
import math | |
import chainer | |
import chainer.links as L | |
import chainer.functions as F | |
class BottleNeck(chainer.Chain):
    """Bottleneck residual unit (1x1 -> 3x3 -> 1x1) with a parameter-free shortcut.

    When the residual branch changes the channel count, the identity shortcut
    is zero-padded along the channel axis. When the branch also changes the
    spatial size, the shortcut is subsampled with a 1x1, stride-2 average
    pooling so it can be added to the branch output.

    Args:
        in_size: number of input channels.
        ch1: output channels of conv1.
        ch2: output channels of conv2 (3x3, pad 1).
        ch3: output channels of conv3, i.e. of the whole unit.
        stride: stride of conv1, the only strided convolution in the unit.
        ksize: kernel size of conv1 (no padding).
    """

    def __init__(self, in_size, ch1, ch2, ch3, stride=1, ksize=1):
        w = math.sqrt(2)
        super(BottleNeck, self).__init__(
            conv1=L.Convolution2D(in_size, ch1, ksize, stride, 0, w),
            bn1=L.BatchNormalization(ch1),
            conv2=L.Convolution2D(ch1, ch2, 3, 1, 1, w),
            bn2=L.BatchNormalization(ch2),
            conv3=L.Convolution2D(ch2, ch3, 1, 1, 0, w),
            bn3=L.BatchNormalization(ch3),
        )

    def __call__(self, x, train):
        """Apply the unit to ``x``.

        Args:
            x: input Variable; its ``data`` is indexed as (n, c, h, w) below.
            train: if True, batch normalization uses mini-batch statistics;
                otherwise it runs in test mode.

        Returns:
            ``relu(branch(x) + shortcut(x))``.

        Raises:
            ValueError: if the branch outputs fewer channels than ``x`` has,
                which a zero-padding shortcut cannot represent.
        """
        test = not train
        h = F.relu(self.bn1(self.conv1(x), test=test))
        h = F.relu(self.bn2(self.conv2(h), test=test))
        h = self.bn3(self.conv3(h), test=test)

        n, c, hh, ww = x.data.shape
        pad_c = h.data.shape[1] - c
        if pad_c < 0:
            raise ValueError(
                'residual branch has fewer channels ({}) than its input '
                '({})'.format(h.data.shape[1], c))
        if pad_c > 0:
            xp = chainer.cuda.get_array_module(x.data)
            p = xp.zeros((n, pad_c, hh, ww), dtype=x.data.dtype)
            # Give the padding the same volatility as the input. Deriving it
            # from ``train`` could disagree with x's flag and make concat
            # mix flags.
            p = chainer.Variable(p, volatile=x.volatile)
            x = F.concat((p, x))
        if x.data.shape[2:] != h.data.shape[2:]:
            # Subsample the shortcut to match a strided branch.
            x = F.average_pooling_2d(x, 1, 2)
        return F.relu(h + x)
class ResNet(chainer.Chain):
    """Bottleneck ResNet classifier for ``insize`` x ``insize`` RGB images.

    Layout: conv1 (7x7, stride 2) -> bn1 -> ReLU -> 3x3/2 max pooling ->
    four stages of ``block_class`` units -> 7x7 average pooling -> linear
    classifier. With the default ``n_blocks`` of (3, 8, 36, 3) this is the
    152-layer configuration: 1 + 3 * 50 convolutions plus 1 linear layer.

    Residual units are registered as links named ``res<k>``, where ``k`` is
    the unit's position in the layer list.
    """

    insize = 224

    def __init__(self, block_class, n_blocks=(3, 8, 36, 3), n_class=1000):
        """Build the layer list.

        Args:
            block_class: residual unit constructor, called as
                ``block_class(in_size, ch1, ch2, ch3, stride)`` and then as
                ``unit(x, train)`` in the forward pass.
            n_blocks: number of units in each of the four stages. Every entry
                must be at least 1.
            n_class: output size of the final linear layer.

        Raises:
            ValueError: if ``n_blocks`` does not have four positive entries.
        """
        super(ResNet, self).__init__()
        # (bottleneck width, output channels, stride of the stage's first unit)
        stages = ((64, 256, 1), (128, 512, 2), (256, 1024, 2), (512, 2048, 2))
        n_blocks = tuple(n_blocks)
        if len(n_blocks) != len(stages) or min(n_blocks) < 1:
            raise ValueError(
                'n_blocks must have {} positive entries, got {!r}'.format(
                    len(stages), n_blocks))
        w = math.sqrt(2)
        # NOTE(review): pad=0 differs from the usual pad=3 ResNet stem. For
        # 224x224 inputs the map is still 7x7 at the average pool. Confirm
        # this is intended.
        links = [('conv1', L.Convolution2D(3, 64, 7, 2, 0, w))]
        links += [('bn1', L.BatchNormalization(64))]
        links += [('_mpool1', F.MaxPooling2D(3, 2, 0, True, True))]
        in_ch = 64
        for (mid, out, stride), n in zip(stages, n_blocks):
            for i in range(n):
                # Only a stage's first unit may change width or resolution.
                links += [('res{}'.format(len(links)),
                           block_class(in_ch, mid, mid, out,
                                       stride if i == 0 else 1))]
                in_ch = out
        links += [('_apool{}'.format(len(links)),
                   F.AveragePooling2D(7, 1, 0, False, True))]
        links += [('fc{}'.format(len(links)), L.Linear(2048, n_class))]
        # Names starting with '_' are parameter-free functions. They stay in
        # the forward list but are not registered as child links.
        for name, link in links:
            if not name.startswith('_'):
                self.add_link(name, link)
        self.forward = links
        self.train = True

    def clear(self):
        """Reset the loss and accuracy recorded by the last call."""
        self.loss = None
        self.accuracy = None

    def __call__(self, x, t=None):
        """Run the network.

        Args:
            x: input image batch Variable (presumably (N, 3, insize, insize);
                confirm against callers).
            t: optional target labels.

        Returns:
            The softmax cross-entropy loss when ``t`` is given (also stored
            in ``self.loss``, with accuracy in ``self.accuracy``). Otherwise
            returns the output of the final linear layer.
        """
        self.clear()
        for name, f in self.forward:
            if 'res' in name:
                x = f(x, self.train)
            elif name == 'bn1':
                # The stem BN must honor the train/test switch like the
                # residual units do, and is followed by a ReLU
                # (conv-BN-ReLU) before pooling.
                x = F.relu(f(x, test=not self.train))
            else:
                x = f(x)
        if t is not None:
            self.loss = F.softmax_cross_entropy(x, t)
            self.accuracy = F.accuracy(x, t)
            return self.loss
        return x
# Module-level network instance, constructed when this file is imported.
model = ResNet(block_class=BottleNeck)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
# -*- coding: utf-8 -*- | |
import math | |
import chainer | |
import chainer.links as L | |
import chainer.functions as F | |
class BottleNeck(chainer.Chain):
    """Bottleneck residual unit (1x1 -> 3x3 -> 1x1) with an optional projection.

    When ``proj`` is True the shortcut is a 1x1 convolution (link ``res``)
    with the same stride as conv1. Otherwise the input is added unchanged.

    Args:
        in_size: number of input channels.
        ch1: output channels of conv1.
        ch2: output channels of conv2 (3x3, pad 1).
        ch3: output channels of conv3, i.e. of the whole unit.
        proj: whether the shortcut goes through a projection convolution.
        stride: stride of conv1 and of the projection.
        ksize: kernel size of conv1 (no padding).

    Raises:
        ValueError: if ``proj`` is False but the identity shortcut cannot
            match the branch output (``in_size != ch3`` or ``stride != 1``).
    """

    def __init__(self, in_size, ch1, ch2, ch3, proj=False, stride=1, ksize=1):
        if not proj and (in_size != ch3 or stride != 1):
            raise ValueError(
                'identity shortcut needs in_size == ch3 and stride == 1 '
                '(got in_size={}, ch3={}, stride={}); pass proj=True'.format(
                    in_size, ch3, stride))
        w = math.sqrt(2)
        super(BottleNeck, self).__init__(
            conv1=L.Convolution2D(in_size, ch1, ksize, stride, 0, w),
            bn1=L.BatchNormalization(ch1),
            conv2=L.Convolution2D(ch1, ch2, 3, 1, 1, w),
            bn2=L.BatchNormalization(ch2),
            conv3=L.Convolution2D(ch2, ch3, 1, 1, 0, w),
            bn3=L.BatchNormalization(ch3),
        )
        # Public flag for the forward pass, so it does not have to inspect
        # Chainer's private child registry.
        self.proj = bool(proj)
        if proj:
            # NOTE(review): unlike the branch convolutions, the projection
            # is not followed by BatchNormalization. Confirm this is intended.
            self.add_link('res', L.Convolution2D(
                in_size, ch3, 1, stride, 0, w))

    def __call__(self, x, train):
        """Apply the unit to ``x``.

        Args:
            x: input Variable.
            train: if True, batch normalization uses mini-batch statistics;
                otherwise it runs in test mode.

        Returns:
            ``relu(branch(x) + shortcut(x))``.
        """
        test = not train
        h = F.relu(self.bn1(self.conv1(x), test=test))
        h = F.relu(self.bn2(self.conv2(h), test=test))
        h = self.bn3(self.conv3(h), test=test)
        if self.proj:
            x = self.res(x)
        return F.relu(h + x)
class ResNet(chainer.Chain):
    """Bottleneck ResNet classifier for ``insize`` x ``insize`` RGB images.

    Layout: conv1 (7x7, stride 2) -> bn1 -> ReLU -> 3x3/2 max pooling ->
    four stages of ``block_class`` units -> 7x7 average pooling -> linear
    classifier. With the default ``n_blocks`` of (3, 8, 36, 3) this is the
    152-layer configuration: 1 + 3 * 50 branch convolutions plus 1 linear
    layer. The first unit of every stage is built with a projection
    shortcut.

    Residual units are registered as links named ``res<k>``, where ``k`` is
    the unit's position in the layer list.
    """

    insize = 224

    def __init__(self, block_class, n_blocks=(3, 8, 36, 3), n_class=1000):
        """Build the layer list.

        Args:
            block_class: residual unit constructor, called as
                ``block_class(in_size, ch1, ch2, ch3, proj, stride)`` and
                then as ``unit(x, train)`` in the forward pass.
            n_blocks: number of units in each of the four stages. Every entry
                must be at least 1.
            n_class: output size of the final linear layer.

        Raises:
            ValueError: if ``n_blocks`` does not have four positive entries.
        """
        super(ResNet, self).__init__()
        # (bottleneck width, output channels, stride of the stage's first unit)
        stages = ((64, 256, 1), (128, 512, 2), (256, 1024, 2), (512, 2048, 2))
        n_blocks = tuple(n_blocks)
        if len(n_blocks) != len(stages) or min(n_blocks) < 1:
            raise ValueError(
                'n_blocks must have {} positive entries, got {!r}'.format(
                    len(stages), n_blocks))
        w = math.sqrt(2)
        # NOTE(review): pad=0 differs from the usual pad=3 ResNet stem. For
        # 224x224 inputs the map is still 7x7 at the average pool. Confirm
        # this is intended.
        links = [('conv1', L.Convolution2D(3, 64, 7, 2, 0, w))]
        links += [('bn1', L.BatchNormalization(64))]
        links += [('_mpool1', F.MaxPooling2D(3, 2, 0, True, True))]
        in_ch = 64
        for (mid, out, stride), n in zip(stages, n_blocks):
            for i in range(n):
                # Only a stage's first unit changes width/resolution, so only
                # it needs a projection shortcut.
                first = i == 0
                links += [('res{}'.format(len(links)),
                           block_class(in_ch, mid, mid, out, first,
                                       stride if first else 1))]
                in_ch = out
        links += [('_apool{}'.format(len(links)),
                   F.AveragePooling2D(7, 1, 0, False, True))]
        links += [('fc{}'.format(len(links)), L.Linear(2048, n_class))]
        # Names starting with '_' are parameter-free functions. They stay in
        # the forward list but are not registered as child links.
        for name, link in links:
            if not name.startswith('_'):
                self.add_link(name, link)
        self.forward = links
        self.train = True

    def clear(self):
        """Reset the loss and accuracy recorded by the last call."""
        self.loss = None
        self.accuracy = None

    def __call__(self, x, t=None):
        """Run the network.

        Args:
            x: input image batch Variable (presumably (N, 3, insize, insize);
                confirm against callers).
            t: optional target labels.

        Returns:
            The softmax cross-entropy loss when ``t`` is given (also stored
            in ``self.loss``, with accuracy in ``self.accuracy``). Otherwise
            returns the output of the final linear layer.
        """
        self.clear()
        for name, f in self.forward:
            if 'res' in name:
                x = f(x, self.train)
            elif name == 'bn1':
                # The stem BN must honor the train/test switch like the
                # residual units do, and is followed by a ReLU
                # (conv-BN-ReLU) before pooling.
                x = F.relu(f(x, test=not self.train))
            else:
                x = f(x)
        if t is not None:
            self.loss = F.softmax_cross_entropy(x, t)
            self.accuracy = F.accuracy(x, t)
            return self.loss
        return x
# Module-level network instance (projection-shortcut units), constructed
# when this file is imported.
model = ResNet(block_class=BottleNeck)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import ResNet_A | |
import ResNet_B | |
import numpy as np | |
from chainer import Variable | |
from chainer import computational_graph | |
def forward_test(resnet, graph_fn):
    """Push one random image through ``resnet`` and dump its graph as DOT.

    Args:
        resnet: model called as ``resnet(Variable)`` that returns the output
            Variable. Its ``insize`` attribute (224 if absent) sets the
            height and width of the dummy input.
        graph_fn: path of the Graphviz DOT file to write.

    Returns:
        The Variable returned by ``resnet``.
    """
    insize = getattr(resnet, 'insize', 224)
    x = np.random.rand(1, 3, insize, insize).astype(np.float32)
    pred = resnet(Variable(x, volatile=False))
    # Build the graph before opening the file, so a failure here does not
    # leave an empty or truncated DOT file behind.
    graph = computational_graph.build_computational_graph(
        (pred,), remove_split=True)
    with open(graph_fn, 'w') as o:
        o.write(graph.dump())
    return pred
# Trace the zero-padding-shortcut network first and keep its output in `ret`.
ret = forward_test(ResNet_A.model, 'resnet_A.dot')
# Then trace the projection-shortcut network; `resnet` ends up bound to it.
resnet = ResNet_B.model
forward_test(resnet, 'resnet_B.dot')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment