Skip to content

Instantly share code, notes, and snippets.

@ShigekiKarita
Last active September 2, 2016 14:11
Show Gist options
  • Save ShigekiKarita/3cc3c44b36a12020a3a4e2cf24e5586f to your computer and use it in GitHub Desktop.
import chainer
import chainer.functions as F
import chainer.links as L
import numpy as np
class DenseNet(chainer.Chain):
    """Densely Connected Convolutional Network (DenseNet).

    Ported from the reference Torch implementation:
    https://github.com/liuzhuang13/DenseNet/blob/master/densenet.lua

    Architecture: an initial 3x3 convolution, then three dense blocks
    separated by two transition layers (1x1 convolution followed by 2x2
    average pooling), and finally batch-norm -> ReLU -> 8x8 average
    pooling -> linear classifier.
    """

    def __init__(self, output_num, in_shape, depth):
        """Build the network.

        Args:
            output_num (int): number of output classes.
            in_shape (tuple): input shape as ``(channels, height, width)``;
                only ``in_shape[0]`` (the channel count) is read here.
            depth (int): total network depth; must satisfy
                ``(depth - 4) % 3 == 0`` so the three dense blocks have
                equal depth.
        """
        super(DenseNet, self).__init__()
        assert (depth - 4) % 3 == 0
        # Each of the 3 dense blocks contains (depth - 4) // 3 layers;
        # the remaining 4 are cv0, the two transitions, and the classifier.
        self.block_depth = (depth - 4) // 3
        self.growth_rate = 12    # channels added by each dense layer (k in the paper)
        self.channel_size = 24   # current channel count; grows as layers are added
        self.ksize = (3, 3)
        self.in_shape = in_shape
        self.depth = depth
        self._depth_state = 0    # running index used to name bn%d / cv%d links
        self.transitions = []    # depth indices at which a transition layer sits
        self.add_link("cv0", L.Convolution2D(
            self.in_shape[0], self.channel_size,
            ksize=self.ksize, stride=1, pad=1))
        self.add_block()
        self.add_transition()
        self.add_block()
        self.add_transition()
        self.add_block()
        self.add_link("bn_last", L.BatchNormalization(self.channel_size))
        self.add_link("fc_last", L.Linear(self.channel_size, output_num))

    def is_transition(self, d):
        """Return True if layer index ``d`` is a transition layer."""
        return d in self.transitions

    def _add_common(self, n_output, k, p):
        """Register one BatchNormalization + Convolution2D pair.

        Args:
            n_output (int): output channels of the convolution.
            k (tuple): convolution kernel size ``(kw, kh)``.
            p (int): zero padding.
        """
        n_input = self.channel_size
        self._depth_state += 1
        bn = L.BatchNormalization(n_input)
        self.add_link("bn%d" % self._depth_state, bn)
        # He (MSRA / "ResNet") initialization: std = sqrt(2 / fan_out),
        # with fan_out = n_output * kernel area.
        kw, kh = k
        n = n_output * kw * kh
        # Cast to float32: Chainer parameters are float32 by default, and
        # np.random.randn / np.zeros produce float64 otherwise.
        W = (np.random.randn(n_output, n_input, kw, kh)
             * ((2.0 / n) ** 0.5)).astype(np.float32)
        b = np.zeros((n_output,), dtype=np.float32)
        cv = L.Convolution2D(n_input, n_output, ksize=k, stride=1, pad=p,
                             initialW=W, initial_bias=b)
        self.add_link("cv%d" % self._depth_state, cv)

    def add_block(self):
        """Append one dense block of ``self.block_depth`` layers.

        Each layer emits ``growth_rate`` channels which are concatenated
        onto the running feature map, so ``channel_size`` grows by
        ``growth_rate`` per layer.
        """
        for _ in range(self.block_depth):
            self._add_common(self.growth_rate, self.ksize, 1)
            self.channel_size += self.growth_rate

    def add_transition(self):
        """Append a transition layer (1x1 conv, channel count preserved)."""
        self._add_common(self.channel_size, (1, 1), 0)
        self.transitions.append(self._depth_state)

    def __call__(self, data, train=True, drop_rate=0.5):
        """Run the forward pass.

        Args:
            data: input batch, shape ``(batch,) + in_shape``.
            train (bool): training mode flag for batch-norm and dropout.
            drop_rate (float): dropout ratio; 0 (or falsy) disables dropout.

        Returns:
            The classifier output (logits) for each input.
        """
        def common(x, d):
            # Shared BN -> ReLU -> Conv (-> Dropout) pipeline of layer d.
            x = self["bn%d" % d](x, test=not train)
            x = F.relu(x)
            x = self["cv%d" % d](x)
            if drop_rate:
                x = F.dropout(x, ratio=drop_rate, train=train)
            return x

        # input layer
        x = self.cv0(data)
        # Layers 1 .. depth-2 were registered by add_block/add_transition.
        for d in range(1, self.depth - 1):
            if self.is_transition(d):
                # transition layer: compress then downsample 2x
                x = common(x, d)
                x = F.average_pooling_2d(x, ksize=(2, 2))
            else:
                # densely connected layer: concatenate new features onto
                # everything produced so far within the block
                x_prev = x
                x = common(x, d)
                x = F.concat((x, x_prev))
        # output layer
        x = self.bn_last(x, test=not train)
        x = F.relu(x)
        x = F.average_pooling_2d(x, ksize=(8, 8))
        x = self.fc_last(x)
        return x
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment