Created: July 1, 2019 14:43
MobileNetV2 embedding with output normalization
import math

import torch.nn as nn
import torch.nn.functional as F


def conv_bn(inp, oup, stride):
    # 3x3 convolution + batch norm + ReLU6
    return nn.Sequential(
        nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
        nn.BatchNorm2d(oup),
        nn.ReLU6(inplace=True)
    )


def conv_1x1_bn(inp, oup):
    # 1x1 pointwise convolution + batch norm + ReLU6
    return nn.Sequential(
        nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
        nn.BatchNorm2d(oup),
        nn.ReLU6(inplace=True)
    )


class InvertedResidual(nn.Module):
    def __init__(self, inp, oup, stride, expand_ratio):
        super(InvertedResidual, self).__init__()
        self.stride = stride
        assert stride in [1, 2]

        hidden_dim = round(inp * expand_ratio)
        # residual connection only when the block keeps resolution and channel count
        self.use_res_connect = self.stride == 1 and inp == oup

        if expand_ratio == 1:
            self.conv = nn.Sequential(
                # dw
                nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1,
                          groups=hidden_dim, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.ReLU6(inplace=True),
                # pw-linear
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup),
            )
        else:
            self.conv = nn.Sequential(
                # pw
                nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.ReLU6(inplace=True),
                # dw
                nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1,
                          groups=hidden_dim, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.ReLU6(inplace=True),
                # pw-linear
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup),
            )

    def forward(self, x):
        if self.use_res_connect:
            return x + self.conv(x)
        else:
            return self.conv(x)


class MobileNetV2(nn.Module):
    def __init__(self, input_size=224, width_mult=1.):
        super().__init__()
        block = InvertedResidual
        input_channel = 32
        last_channel = 1280
        inverted_residual_setting = [
            # t, c, n, s
            [1, 16, 1, 1],
            [6, 24, 2, 2],
            [6, 32, 3, 2],
            [6, 64, 4, 2],
            [6, 96, 3, 1],
            [6, 160, 3, 2],
            [6, 320, 1, 1],
        ]

        # building first layer
        assert input_size % 32 == 0
        input_channel = int(input_channel * width_mult)
        if width_mult > 1.0:
            self.last_channel = int(last_channel * width_mult)
        else:
            self.last_channel = last_channel
        self.features = [conv_bn(3, input_channel, 2)]
        # building inverted residual blocks
        for t, c, n, s in inverted_residual_setting:
            output_channel = int(c * width_mult)
            for i in range(n):
                if i == 0:
                    self.features.append(block(input_channel, output_channel, s,
                                               expand_ratio=t))
                else:
                    self.features.append(block(input_channel, output_channel, 1,
                                               expand_ratio=t))
                input_channel = output_channel
        # building last several layers
        self.features.append(conv_1x1_bn(input_channel, self.last_channel))
        # make it nn.Sequential
        self.features = nn.Sequential(*self.features)

        self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def forward(self, x):
        x = self.features(x)
        # global average pooling over spatial dims, then L2-normalize the embedding
        x = x.mean(3).mean(2)
        x = F.normalize(x, dim=1)
        return x
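A minimal usage sketch: it assumes a standard 224x224 RGB input (here just random tensors) and shows that the model returns a unit-norm 1280-dimensional embedding per image. The commented-out checkpoint path is a placeholder, not a real file; loading pretrained weights would typically need strict=False since this variant has no classifier head.

import torch

model = MobileNetV2(input_size=224, width_mult=1.)
# Optionally load pretrained backbone weights; the path below is hypothetical.
# state_dict = torch.load("mobilenet_v2.pth", map_location="cpu")
# model.load_state_dict(state_dict, strict=False)
model.eval()

with torch.no_grad():
    batch = torch.randn(4, 3, 224, 224)  # 4 dummy RGB images
    emb = model(batch)                   # shape (4, 1280)

print(emb.shape)          # torch.Size([4, 1280])
print(emb.norm(dim=1))    # ~1.0 per row, since the output is L2-normalized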