PyTorch implementation to generate different families of RegNet models (Designing Network Design Spaces, Facebook AI Research, March 2020)
'''
Name: Shreejal Trivedi
Description: Generation script for RegNetX and RegNetY models
References: Designing Network Design Spaces, Facebook AI Research, March 2020
'''

# Importing libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import argparse

# Downsampling used in the first bottleneck block of every layer in RegNet
class Downsample(nn.Module):
    def __init__(self, in_filters, out_filters, stride):
        super(Downsample, self).__init__()
        self.conv1x1 = nn.Conv2d(in_filters, out_filters, kernel_size=1, stride=stride, bias=False)
        self.bn = nn.BatchNorm2d(out_filters)

    def forward(self, x):
        return self.bn(self.conv1x1(x))

# SE attention module for RegNetY
class SqueezeExcitation(nn.Module):
    def __init__(self, in_filters, se_ratio):
        super(SqueezeExcitation, self).__init__()
        # Calculate the number of bottleneck SE filters
        out_filters = int(in_filters * se_ratio)
        # Global average pooling layer
        self.avgpool = nn.AdaptiveAvgPool2d(output_size=1)
        # Squeeze
        self.conv1_1x1 = nn.Conv2d(in_filters, out_filters, kernel_size=1, bias=True)
        # Excite
        self.conv2_1x1 = nn.Conv2d(out_filters, in_filters, kernel_size=1, bias=True)

    def forward(self, x):
        out = self.avgpool(x)
        out = F.relu(self.conv1_1x1(out))
        out = self.conv2_1x1(out).sigmoid()
        out = x * out
        return out

# Bottleneck residual block used inside each layer
class Bottleneck(nn.Module):
    def __init__(self, in_filters, out_filters, bottleneck_ratio, group_size, stride=1, se_ratio=0):
        super(Bottleneck, self).__init__()
        # 1x1 bottleneck convolution block
        bottleneck_filters = in_filters // bottleneck_ratio
        self.conv1_1x1 = nn.Conv2d(in_filters, bottleneck_filters, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(bottleneck_filters)
        # 3x3 convolution block with group convolutions ---> ResNeXt-like structure
        num_groups = bottleneck_filters // group_size
        self.conv2_3x3 = nn.Conv2d(bottleneck_filters, bottleneck_filters, kernel_size=3, stride=stride, padding=1, groups=num_groups, bias=False)
        self.bn2 = nn.BatchNorm2d(bottleneck_filters)
        # Squeeze-Excitation block: only used for RegNetY, i.e. when 0 < se_ratio < 1
        self.se_module = SqueezeExcitation(bottleneck_filters, se_ratio) if 0 < se_ratio < 1 else None
        # Downsample the identity branch if the spatial size or channel count changes
        self.downsample = Downsample(in_filters, out_filters, stride) if stride != 1 or in_filters != out_filters else None
        # Final 1x1 convolution block
        self.conv3_1x1 = nn.Conv2d(bottleneck_filters, out_filters, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_filters)

    def forward(self, x):
        residual = x
        out = F.relu(self.bn1(self.conv1_1x1(x)))
        out = F.relu(self.bn2(self.conv2_3x3(out)))
        if self.se_module is not None:
            out = self.se_module(out)
        out = self.bn3(self.conv3_1x1(out))
        if self.downsample is not None:
            residual = self.downsample(x)
        out += residual
        out = F.relu(out)
        return out
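
# Shape sketch for the Bottleneck block above (illustrative comment, not part of the original gist):
#   input     (N, in_filters, H, W)
#   conv1_1x1 -> (N, in_filters // bottleneck_ratio, H, W)
#   conv2_3x3 -> (N, in_filters // bottleneck_ratio, H/stride, W/stride)   (grouped convolution)
#   conv3_1x1 -> (N, out_filters, H/stride, W/stride), then added to the (possibly downsampled) identity branch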

# Stem: a single 3x3 stride-2 convolution applied to the input image
class Stem(nn.Module):
    def __init__(self, out_filters, in_filters=3):
        super(Stem, self).__init__()
        self.conv3x3 = nn.Conv2d(in_filters, out_filters, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(out_filters)

    def forward(self, x):
        return F.relu(self.bn(self.conv3x3(x)))

# Layer (stage): a sequence of bottleneck blocks; the first block downsamples with stride 2
class Layer(nn.Module):
    def __init__(self, in_filters, depth, width, bottleneck_ratio, group_size, se_ratio):
        super(Layer, self).__init__()
        layers = []
        # Total number of bottleneck blocks in a layer = depth d
        for i in range(depth):
            stride = 2 if i == 0 else 1
            bottleneck = Bottleneck(in_filters, width, bottleneck_ratio, group_size, stride, se_ratio)
            layers.append(bottleneck)
            in_filters = width
        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        out = self.layers(x)
        return out

# Head: global average pooling followed by a fully connected classifier
class Head(nn.Module):
    def __init__(self, in_filters, classes):
        super(Head, self).__init__()
        self.avgpool = nn.AdaptiveAvgPool2d(output_size=1)
        self.fc = nn.Linear(in_filters, classes)

    def forward(self, x):
        out = self.avgpool(x)
        out = torch.flatten(out, 1)
        out = self.fc(out)
        return out

class RegNet(nn.Module):
    def __init__(self, parameters, classes=2):
        super(RegNet, self).__init__()
        # Model parameter initialization
        self.in_filters = 32
        self.w, self.d, self.b, self.g, self.se_ratio = parameters
        self.num_layers = 4
        # Stem part of the generic ResNet/ResNeXt-style architecture
        self.stem = Stem(self.in_filters)
        body = []
        for i in range(self.num_layers):
            layer = Layer(self.in_filters, self.d[i], self.w[i], self.b, self.g, self.se_ratio)
            body.append(layer)
            self.in_filters = self.w[i]
        # Body part: four layers containing bottleneck residual blocks
        self.body = nn.Sequential(*body)
        # Head part: classification step, average pool + FC
        self.head = Head(self.w[-1], classes)

    def forward(self, x):
        out = self.stem(x)
        out = self.body(out)
        out = self.head(out)
        return out

# Generate per-stage widths and depths from the quantized linear parameterization of the paper
def generate_parameters_regnet(D, w0, wa, wm, b, g, q):
    u = w0 + wa * np.arange(D)  # Equation 1: linear parameterization u_j = w0 + wa * j
    s = np.log(u / w0) / np.log(wm)  # Equation 2: solve u_j = w0 * wm^s_j for s_j
    s = np.round(s)  # Round the per-block exponents s_j
    w = w0 * np.power(wm, s)  # Equation 3: quantized per-block widths w_j = w0 * wm^round(s_j)
    w = np.round(w / 8) * 8  # Make every width divisible by 8
    w, d = np.unique(w.astype(int), return_counts=True)  # Per-stage width and depth lists
    gtemp = np.minimum(g, w // b)
    w = (np.round(w // b / gtemp) * gtemp).astype(int)  # Make the widths compatible with the group size of the 3x3 convolutions
    g = np.unique(gtemp // b)[0]
    return (w, d, b, g, q)
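
# Illustrative trace (added comment, not in the original gist): with the default CLI arguments
# below (D=13, w0=24, wa=36, wm=2.5, b=1, g=8, q=1), this function yields per-stage widths
# w = [24, 64, 152, 376] and depths d = [1, 1, 4, 7], i.e. the 13 blocks are split across the
# four stages assembled by the RegNet class above.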

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="RegNetX | RegNetY model generation")
    parser.add_argument('-D', default=13, type=int, help='Network depth (total number of blocks): range [12, 13, ..., 28]')
    parser.add_argument('-w0', default=24, type=int, help='Initial width of the first layer, w0 > 0')
    parser.add_argument('-wa', default=36, type=int, help='Slope parameter wa: range [0, 1, 2, ..., 255]')
    parser.add_argument('-wm', default=2.5, type=float, help='Width quantization parameter wm: range [1.5, 3]')
    parser.add_argument('-b', default=1, type=int, help='Bottleneck ratio: one of {1, 2, 4}')
    parser.add_argument('-g', default=8, type=int, help='Group size: one of {1, 2, 4, 8, 16, 32} or {16, 24, 32, 40, 48, 56, 64}')
    parser.add_argument('-q', default=1, type=float, help='SE ratio: use 0 < q < 1 for RegNetY; the default q = 1 disables SE (RegNetX)')
    args = parser.parse_args()

    parameters = generate_parameters_regnet(args.D, args.w0, args.wa, args.wm, args.b, args.g, args.q)
    model = RegNet(parameters)
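
A minimal usage sketch (not part of the original gist): assuming the file above is saved as regnet.py, the snippet below builds a model with the default hyperparameters and runs a dummy forward pass. The 224x224 input resolution and the 1000-class head are illustrative assumptions.

# usage_sketch.py -- hypothetical companion script; assumes the gist above is saved as regnet.py
import torch
from regnet import RegNet, generate_parameters_regnet

# Default hyperparameters from the argument parser above (q=1 disables SE, i.e. a RegNetX-style model)
parameters = generate_parameters_regnet(D=13, w0=24, wa=36, wm=2.5, b=1, g=8, q=1)
model = RegNet(parameters, classes=1000)

# Dummy ImageNet-sized batch to verify that the forward pass produces the expected output shape
dummy = torch.randn(2, 3, 224, 224)
logits = model(dummy)
print(logits.shape)  # torch.Size([2, 1000])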