peleenet.py
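A PyTorch implementation of the PeleeNet classification backbone (Wang et al., "Pelee: A Real-Time Object Detection System on Mobile Devices", NeurIPS 2018).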
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter
class Scale(nn.Module):
    """Per-channel affine transform (scale and shift), used after a
    BatchNorm2d created with affine=False."""

    def __init__(self, channels):
        super(Scale, self).__init__()
        # Initialize to the identity transform (the original left these
        # parameters uninitialized, which is a bug).
        self.weight = Parameter(torch.ones(channels))
        self.bias = Parameter(torch.zeros(channels))
        self.channels = channels

    def forward(self, x):
        # Broadcast the per-channel weight and bias over (N, C, H, W).
        nC = x.size(1)
        return x * self.weight.view(1, nC, 1, 1) + self.bias.view(1, nC, 1, 1)

    def __repr__(self):
        return 'Scale(channels=%d)' % self.channels
class StemConv(nn.Module):
    """Conv -> BatchNorm (no affine) -> optional ReLU -> learned Scale."""

    def __init__(self, inp, oup, kernel_size=3, stride=1, pad=1, use_relu=True):
        super(StemConv, self).__init__()
        # The original duplicated these layers in both branches of an
        # if/else, and the else branch had a stray trailing comma that made
        # self.c a tuple (a crash at forward time); collapsed and fixed here.
        self.use_relu = use_relu
        self.c = nn.Conv2d(inp, oup, kernel_size, stride, pad, bias=False)
        self.bn = nn.BatchNorm2d(oup, affine=False)
        self.act = nn.ReLU(inplace=True) if use_relu else None
        self.scale = Scale(oup)

    def forward(self, x):
        x = self.bn(self.c(x))
        if self.act is not None:
            x = self.act(x)
        x = self.scale(x)
        return x
class ConvNorm(nn.Module):
    """Conv -> BatchNorm (no affine) -> optional ReLU. Same as StemConv but
    without the Scale layer; defined here but unused by the blocks below."""

    def __init__(self, inp, oup, kernel_size=3, stride=1, pad=1, use_relu=True):
        super(ConvNorm, self).__init__()
        self.use_relu = use_relu
        self.c = nn.Conv2d(inp, oup, kernel_size, stride, pad, bias=False)
        self.bn = nn.BatchNorm2d(oup, affine=False)
        self.act = nn.ReLU(inplace=True) if use_relu else None

    def forward(self, x):
        x = self.bn(self.c(x))
        if self.act is not None:
            x = self.act(x)
        return x
class StemBlock(nn.Module):
    """PeleeNet stem: a strided 3x3 conv, then two parallel branches
    (1x1 conv -> strided 3x3 conv, and 2x2 max pool) concatenated and fused
    by a 1x1 conv. Reduces spatial resolution by 4x overall."""

    def __init__(self, inp=3, num_init_features=32):
        super(StemBlock, self).__init__()
        self.stem1 = StemConv(inp, num_init_features, 3, 2, 1)
        self.stem2a = StemConv(num_init_features, num_init_features // 2, 1, 1, 0)
        self.stem2b = StemConv(num_init_features // 2, num_init_features, 3, 2, 1)
        self.stem2p = nn.MaxPool2d(kernel_size=2, stride=2)
        self.stem3 = StemConv(num_init_features * 2, num_init_features, 1, 1, 0)

    def forward(self, x):
        stem_1_out = self.stem1(x)
        stem_2a_out = self.stem2a(stem_1_out)
        stem_2b_out = self.stem2b(stem_2a_out)
        stem_2p_out = self.stem2p(stem_1_out)
        out = self.stem3(torch.cat((stem_2b_out, stem_2p_out), 1))
        return out
class DenseBlock(nn.Module):
    """Two-way dense layer: each branch produces growth_rate channels, so
    the block grows its input by 2 * growth_rate channels."""

    def __init__(self, inp, inter_channel, growth_rate):
        super(DenseBlock, self).__init__()
        # Branch 1: 1x1 bottleneck -> 3x3 conv.
        self.branch1a = StemConv(inp, inter_channel, 1, 1, 0)
        self.branch1b = StemConv(inter_channel, growth_rate, 3, 1, 1)
        # Branch 2: 1x1 bottleneck -> two stacked 3x3 convs
        # (5x5 effective receptive field).
        self.branch2a = StemConv(inp, inter_channel, 1, 1, 0)
        self.branch2b = StemConv(inter_channel, growth_rate, 3, 1, 1)
        self.branch2c = StemConv(growth_rate, growth_rate, 3, 1, 1)

    def forward(self, x):
        cb1_a_out = self.branch1a(x)
        cb1_b_out = self.branch1b(cb1_a_out)
        cb2_a_out = self.branch2a(x)
        cb2_b_out = self.branch2b(cb2_a_out)
        cb2_c_out = self.branch2c(cb2_b_out)
        # Dense connectivity: concatenate the input with both branch outputs.
        return torch.cat((x, cb1_b_out, cb2_c_out), 1)
class TransitionBlock(nn.Module):
    """1x1 conv transition, optionally followed by 2x2 average pooling
    (pooling is skipped after the final stage)."""

    def __init__(self, inp, oup, with_pooling=True):
        super(TransitionBlock, self).__init__()
        if with_pooling:
            self.tb = nn.Sequential(StemConv(inp, oup, 1, 1, 0),
                                    nn.AvgPool2d(kernel_size=2, stride=2))
        else:
            self.tb = StemConv(inp, oup, 1, 1, 0)

    def forward(self, x):
        return self.tb(x)
class PeleeNet(nn.Module):
    """PeleeNet classification network: a stem block followed by four stages
    of dense blocks, each ending in a transition layer, then a classifier."""

    def __init__(self, num_classes=1000, num_init_features=32, growthRate=32,
                 nDenseBlocks=[3, 4, 8, 6], bottleneck_width=[1, 2, 4, 4]):
        super(PeleeNet, self).__init__()
        self.stage = nn.Sequential()
        self.num_classes = num_classes
        self.num_init_features = num_init_features

        inter_channel = list()
        total_filter = list()
        dense_inp = list()
        self.half_growth_rate = int(growthRate / 2)

        # Stage 0: stem block.
        self.stage.add_module('stage_0', StemBlock(3, num_init_features))

        # Stages 1-4: per stage, track the bottleneck width, the output
        # channel count, and the input channel count.
        for i, b_w in enumerate(bottleneck_width):
            # Round the bottleneck channels down to a multiple of 4.
            inter_channel.append(int(self.half_growth_rate * b_w / 4) * 4)

            if i == 0:
                total_filter.append(num_init_features + growthRate * nDenseBlocks[i])
                dense_inp.append(self.num_init_features)
            else:
                total_filter.append(total_filter[i - 1] + growthRate * nDenseBlocks[i])
                dense_inp.append(total_filter[i - 1])

            # No pooling after the last stage.
            with_pooling = (i != len(nDenseBlocks) - 1)

            self.stage.add_module('stage_{}'.format(i + 1),
                                  self._make_dense_transition(dense_inp[i], total_filter[i],
                                                              inter_channel[i], nDenseBlocks[i],
                                                              with_pooling=with_pooling))

        # Classifier head.
        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(total_filter[len(nDenseBlocks) - 1], self.num_classes)
        )

    def _make_dense_transition(self, dense_inp, total_filter, inter_channel,
                               ndenseblocks, with_pooling=True):
        layers = []
        for i in range(ndenseblocks):
            layers.append(DenseBlock(dense_inp, inter_channel, self.half_growth_rate))
            # Each DenseBlock adds 2 * half_growth_rate (= growthRate) channels.
            dense_inp += self.half_growth_rate * 2
        # Transition layer without compression.
        layers.append(TransitionBlock(dense_inp, total_filter, with_pooling))
        return nn.Sequential(*layers)
    def forward(self, x):
        x = self.stage(x)
        # Global average pooling; kernel_size=7 assumes a 224x224 input,
        # which yields a 7x7 feature map after the stem and three poolings.
        x = F.avg_pool2d(x, kernel_size=7)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return F.log_softmax(x, dim=1)
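A minimal smoke test for the model above, assuming a 224x224 RGB input (the size the fixed 7x7 average pool expects); with the default configuration it should produce log-probabilities of shape (batch, num_classes).

if __name__ == '__main__':
    # Sanity check: 224x224 input -> (batch, num_classes) log-probabilities.
    net = PeleeNet(num_classes=1000)
    dummy = torch.randn(2, 3, 224, 224)
    out = net(dummy)
    print(out.shape)  # expected: torch.Size([2, 1000])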