@felixgwu
Created April 2, 2017 05:22
FC-DenseNet Implementation in PyTorch
import torch
from torch import nn

__all__ = ['FCDenseNet', 'fcdensenet_tiny', 'fcdensenet56_nodrop',
           'fcdensenet56', 'fcdensenet67', 'fcdensenet103',
           'fcdensenet103_nodrop']


class DenseBlock(nn.Module):
    """Dense block: each layer sees the concatenation of all previous
    feature maps. With only_new=True, only the features produced inside
    this block are returned (used on the upsampling path to keep the
    channel count bounded)."""

    def __init__(self, nIn, growth_rate, depth, drop_rate=0, only_new=False,
                 bottle_neck=False):
        super(DenseBlock, self).__init__()
        self.only_new = only_new
        self.depth = depth
        self.growth_rate = growth_rate
        # Layer i receives nIn + i * growth_rate input channels.
        self.layers = nn.ModuleList([self.get_transform(
            nIn + i * growth_rate, growth_rate, bottle_neck,
            drop_rate) for i in range(depth)])

    def forward(self, x):
        if self.only_new:
            outputs = []
            for i in range(self.depth):
                tx = self.layers[i](x)
                x = torch.cat((x, tx), 1)
                outputs.append(tx)
            # Return only the depth * growth_rate newly produced channels.
            return torch.cat(outputs, 1)
        else:
            for i in range(self.depth):
                x = torch.cat((x, self.layers[i](x)), 1)
            return x

    def get_transform(self, nIn, nOut, bottle_neck=None, drop_rate=0):
        # Plain BN-ReLU-Conv3x3 layer; a 1x1 bottleneck is inserted only
        # when bottle_neck is set and the input is wide enough.
        if not bottle_neck or nIn <= nOut * bottle_neck:
            return nn.Sequential(
                nn.BatchNorm2d(nIn),
                nn.ReLU(True),
                nn.Conv2d(nIn, nOut, 3, stride=1, padding=1, bias=True),
                nn.Dropout(drop_rate),
            )
        else:
            nBottle = nOut * bottle_neck
            return nn.Sequential(
                nn.BatchNorm2d(nIn),
                nn.ReLU(True),
                nn.Conv2d(nIn, nBottle, 1, stride=1, padding=0, bias=True),
                nn.BatchNorm2d(nBottle),
                nn.ReLU(True),
                nn.Conv2d(nBottle, nOut, 3, stride=1, padding=1, bias=True),
                nn.Dropout(drop_rate),
            )


class FCDenseNet(nn.Module):
    def __init__(self, depths, growth_rates, n_scales=5, n_channel_start=48,
                 n_classes=12, drop_rate=0, bottle_neck=False):
        super(FCDenseNet, self).__init__()
        self.n_scales = n_scales
        self.n_classes = n_classes
        self.n_channel_start = n_channel_start
        # An int is broadcast to all 2 * n_scales + 1 dense blocks
        # (n_scales down, one bottleneck, n_scales up).
        self.depths = [depths] * (2 * n_scales + 1) \
            if isinstance(depths, int) else depths
        self.growth_rates = [growth_rates] * (2 * n_scales + 1) \
            if isinstance(growth_rates, int) else growth_rates
        self.drop_rate = drop_rate
        assert len(self.depths) == len(self.growth_rates) == 2 * n_scales + 1

        self.conv_first = nn.Conv2d(
            3, n_channel_start, 3, stride=1, padding=1, bias=True)
        self.dense_blocks = nn.ModuleList([])
        self.transition_downs = nn.ModuleList([])
        self.transition_ups = nn.ModuleList([])

        # Downsampling path: one dense block + transition down per scale.
        nskip = []
        nIn = self.n_channel_start
        for i in range(n_scales):
            self.dense_blocks.append(
                DenseBlock(nIn, self.growth_rates[i], self.depths[i],
                           drop_rate=drop_rate, bottle_neck=bottle_neck))
            nIn += self.growth_rates[i] * self.depths[i]
            nskip.append(nIn)  # channels carried by the skip connection
            self.transition_downs.append(self.get_TD(nIn, drop_rate))

        # Bottleneck dense block; only_new=True keeps its output at
        # growth_rate * depth channels.
        self.dense_blocks.append(
            DenseBlock(nIn, self.growth_rates[n_scales], self.depths[n_scales],
                       only_new=True, drop_rate=drop_rate,
                       bottle_neck=bottle_neck))
        nIn = self.growth_rates[n_scales] * self.depths[n_scales]

        # Upsampling path: one transition up + dense block per scale.
        for i in range(n_scales - 1):
            self.transition_ups.append(nn.ConvTranspose2d(
                nIn, nIn, 3, stride=2, padding=1, bias=True))
            nIn += nskip.pop()
            self.dense_blocks.append(
                DenseBlock(nIn, self.growth_rates[n_scales + 1 + i],
                           self.depths[n_scales + 1 + i],
                           only_new=True, drop_rate=drop_rate,
                           bottle_neck=bottle_neck))
            nIn = self.growth_rates[n_scales + 1 + i] * \
                self.depths[n_scales + 1 + i]

        # Last dense block keeps all of its input features (only_new=False).
        self.transition_ups.append(nn.ConvTranspose2d(
            nIn, nIn, 3, stride=2, padding=1, bias=True))
        nIn += nskip.pop()
        self.dense_blocks.append(
            DenseBlock(nIn, self.growth_rates[2 * n_scales],
                       self.depths[2 * n_scales], drop_rate=drop_rate,
                       bottle_neck=bottle_neck))
        nIn += self.growth_rates[2 * n_scales] * self.depths[2 * n_scales]

        self.conv_last = nn.Conv2d(nIn, n_classes, 1, bias=True)
        self.logsoftmax = nn.LogSoftmax(dim=1)

    def forward(self, x):
        x = self.conv_first(x)
        skip_connects = []
        # downsampling path
        for i in range(self.n_scales):
            x = self.dense_blocks[i](x)
            skip_connects.append(x)
            x = self.transition_downs[i](x)
        # bottleneck
        x = self.dense_blocks[self.n_scales](x)
        # upsampling path
        for i in range(self.n_scales):
            skip = skip_connects.pop()
            TU = self.transition_ups[i]
            # Adjust the padding so the transposed convolution can match
            # the spatial size of the skip connection exactly.
            TU.padding = (((x.size(2) - 1) * TU.stride[0] - skip.size(2)
                           + TU.kernel_size[0] + 1) // 2,
                          ((x.size(3) - 1) * TU.stride[1] - skip.size(3)
                           + TU.kernel_size[1] + 1) // 2)
            x = TU(x, output_size=skip.size())
            x = torch.cat((skip, x), 1)
            x = self.dense_blocks[self.n_scales + 1 + i](x)
        x = self.conv_last(x)
        return self.logsoftmax(x)

    def get_TD(self, nIn, drop_rate):
        # Transition down: BN-ReLU-Conv1x1 (+ optional dropout) followed by
        # 2x2 max pooling, halving the spatial resolution.
        layers = [nn.BatchNorm2d(nIn),
                  nn.ReLU(True),
                  nn.Conv2d(nIn, nIn, 1, bias=True)]
        if drop_rate > 0:
            layers.append(nn.Dropout(drop_rate))
        layers.append(nn.MaxPool2d(2))
        return nn.Sequential(*layers)


def fcdensenet_tiny(drop_rate=0):
    return FCDenseNet(2, 6, drop_rate=drop_rate)


def fcdensenet56_nodrop():
    return FCDenseNet(4, 12, drop_rate=0)


def fcdensenet56(drop_rate=0.2):
    return FCDenseNet(4, 12, drop_rate=drop_rate)


def fcdensenet67(drop_rate=0.2):
    return FCDenseNet(5, 16, drop_rate=drop_rate)


def fcdensenet103(drop_rate=0.2):
    return FCDenseNet([4, 5, 7, 10, 12, 15, 12, 10, 7, 5, 4], 16,
                      drop_rate=drop_rate)


def fcdensenet103_nodrop(drop_rate=0):
    return FCDenseNet([4, 5, 7, 10, 12, 15, 12, 10, 7, 5, 4], 16,
                      drop_rate=drop_rate)
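
A quick way to sanity-check the network is to push a random batch through it. This is a minimal sketch, not part of the original gist: the file name fc_densenet.py and the 224x224 input size are assumptions (any size divisible by 2**n_scales = 32 keeps the pooling/upsampling sizes aligned).

    import torch
    from fc_densenet import fcdensenet56  # assumes the gist is saved as fc_densenet.py

    model = fcdensenet56()
    model.eval()                      # disable dropout for the check
    x = torch.randn(1, 3, 224, 224)   # dummy RGB batch
    with torch.no_grad():
        out = model(x)
    print(out.size())  # torch.Size([1, 12, 224, 224]): per-pixel log-probabilities
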
@hmishfaq commented Aug 10, 2017

@felixgwu, what does "n_scales" refer to in FCDenseNet?

@youngfly11

@hmishfaq, n_scales is the number of dense blocks in the downsampling path.
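
For illustration (a sketch using the factory functions above, not from the gist): with the default n_scales=5 there are 2 * 5 + 1 = 11 dense blocks in total, which is exactly why fcdensenet103 passes an 11-entry depth list.

    # 5 blocks down, 1 bottleneck, 5 blocks up = 11 dense blocks
    model = fcdensenet103()
    print(len(model.dense_blocks))  # -> 11
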

@youngfly11

@felixgwu, would you share your data loader for this project?

@Emma20102

Hi, is there any contribution I could make to this architecture to improve its performance? Many thanks.
