from __future__ import print_function, absolute_import
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models


class FPN(nn.Module):
    def __init__(self,
                 backbone='resnet50',
                 feature_size=256,
                 kernel_size=3,
                 use_bias=True):
        super(FPN, self).__init__()
        self.fs = feature_size
        self.ks = kernel_size
        # Bottleneck channel counts of ResNet-50/101/152, from c5 down to c2:
        # 2048, 1024, 512, 256. Integer division keeps these ints on Python 3.
        self.backbone_c5_dim = 2048
        self.backbone_dims = [self.backbone_c5_dim,
                              self.backbone_c5_dim // 2,
                              self.backbone_c5_dim // 4,
                              self.backbone_c5_dim // 8]
        print('Building %s model' % backbone)
        self.backbone = models.__dict__[backbone](pretrained=True)

        # Lateral connections of FPN: 1x1 convs projecting each backbone
        # stage to a common feature size.
        self.lateral = nn.ModuleList()
        for d in self.backbone_dims:
            self.lateral.append(
                nn.Conv2d(d, self.fs, kernel_size=1, stride=1, padding=0)
            )

        # Top connections of FPN: kernel_size x kernel_size convs that smooth
        # the merged maps; the padding keeps spatial dimensions unchanged.
        self.top = nn.ModuleList()
        for _ in self.backbone_dims:
            self.top.append(
                nn.Conv2d(self.fs, self.fs, kernel_size=self.ks, stride=1,
                          padding=self.ks // 2)
            )
        self.relu = nn.ReLU()
        self._initialize(self.top, bias=use_bias)
        self._initialize(self.lateral, bias=use_bias)

    def _initialize(self, modules, bias=True):
        # Xavier-initialize conv weights and optionally zero the biases.
        for m in modules:
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_normal_(m.weight)
                if bias:
                    nn.init.constant_(m.bias, 0.0)

    def _upsample(self, x, y):
        # Bilinearly resize x to match the spatial size of y.
        _, _, H, W = y.size()
        return F.interpolate(x, size=(H, W), mode='bilinear',
                             align_corners=False)

    def forward(self, input):
        # Bottom-up pathway: ResNet stem followed by the four residual stages.
        x = self.backbone.conv1(input)
        x = self.backbone.bn1(x)
        x = self.backbone.relu(x)
        c1 = self.backbone.maxpool(x)
        c2 = self.backbone.layer1(c1)
        c3 = self.backbone.layer2(c2)
        c4 = self.backbone.layer3(c3)
        c5 = self.backbone.layer4(c4)
        # Intermediate backbone outputs, coarsest first; c_n has stride 2^n.
        c = [c5, c4, c3, c2]

        # Top-down pathway: lateral 1x1 conv, add the upsampled coarser map,
        # then smooth with a top conv to produce each P-layer.
        p = []
        p_up = None
        for i in range(4):
            _p = self.lateral[i](c[i])
            _p = self.relu(_p)
            if i > 0:
                _p = p_up + _p
            if i < len(c) - 1:
                p_up = self._upsample(_p, c[i + 1])
            _p = self.top[i](_p)
            p.append(_p)
        return p


def fpn(weights, **kwargs):
    # Build an FPN; optionally restore weights from a checkpoint that stores
    # the parameters under a 'state_dict' key.
    model = FPN(**kwargs)
    if weights:
        model.load_state_dict(torch.load(weights)['state_dict'])
    return model
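
# Usage sketch (the checkpoint path below is hypothetical; any file saved
# via torch.save({'state_dict': model.state_dict()}, path) will load):
#   model = fpn('checkpoints/fpn_resnet50.pth', backbone='resnet50')
#   model = fpn(None, feature_size=256)  # fresh FPN, ImageNet backbone only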


if __name__ == '__main__':
    model = FPN()
    # torch.autograd.Variable is deprecated; a plain tensor works as input.
    x = torch.randn(4, 3, 256, 256)
    out = model(x)
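    # Quick shape check (assuming ResNet-50 strides of 32/16/8/4 for c5..c2):
    # with a 256x256 input the pyramid should come out coarsest first as
    # (4, 256, 8, 8), (4, 256, 16, 16), (4, 256, 32, 32), (4, 256, 64, 64).
    for level in out:
        print(level.size())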