An example U-Net using timm features_only functionality.
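For context, timm's features_only=True wraps a classification backbone so that calling it returns a list of intermediate feature maps instead of logits, with channel counts and strides reported via feature_info. A minimal sketch of that API, separate from the file below (the backbone name and printed values are illustrative):

import timm
import torch

backbone = timm.create_model('resnet50', features_only=True, pretrained=False)
print(backbone.feature_info.channels())   # e.g. [64, 256, 512, 1024, 2048]
print(backbone.feature_info.reduction())  # e.g. [2, 4, 8, 16, 32] (feature strides)
features = backbone(torch.randn(1, 3, 224, 224))
print([f.shape for f in features])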
""" A simple U-Net w/ timm backbone encoder | |
Based off an old version of Unet in https://github.com/qubvel/segmentation_models.pytorch | |
Hacked together by Ross Wightman | |
""" | |
from typing import Optional, List

import torch
import torch.nn as nn
import torch.nn.functional as F

from timm import create_model
class Unet(nn.Module):
    """Unet is a fully convolutional neural network for image semantic segmentation.

    Args:
        backbone: name of the timm classification model (without the final dense layers) used as
            the feature extractor (encoder) for the segmentation model.
        backbone_kwargs: extra keyword args passed through to timm's create_model.
        backbone_indices: feature indices to select from the backbone (timm ``out_indices``),
            ``None`` selects all available features.
        decoder_use_batchnorm: if ``True``, a normalization layer is used between the ``Conv2d``
            and activation layers in the decoder blocks.
        decoder_channels: list of numbers of ``Conv2d`` layer filters in decoder blocks.
        in_chans: number of input (image) channels.
        num_classes: number of classes for output (output shape - ``(batch, classes, h, w)``).
        center: if ``True``, add an extra ``DecoderBlock`` (without upsampling) on the encoder head.
        norm_layer: normalization layer used in the decoder when ``decoder_use_batchnorm`` is ``True``.

    NOTE: This is based off an old version of Unet in https://github.com/qubvel/segmentation_models.pytorch
    """
    def __init__(
            self,
            backbone='resnet50',
            backbone_kwargs=None,
            backbone_indices=None,
            decoder_use_batchnorm=True,
            decoder_channels=(256, 128, 64, 32, 16),
            in_chans=1,
            num_classes=5,
            center=False,
            norm_layer=nn.BatchNorm2d,
    ):
        super().__init__()
        backbone_kwargs = backbone_kwargs or {}
        # NOTE some models need different backbone indices specified based on the alignment of features
        # and some models won't have a full enough range of feature strides to work properly.
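        # For instance (illustrative, not exhaustive): a resnet reports 5 features at strides
        # (2, 4, 8, 16, 32), one skip per decoder block; backbones exposing fewer or different
        # strides may need explicit backbone_indices and a matching decoder_channels length.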
        encoder = create_model(
            backbone, features_only=True, out_indices=backbone_indices, in_chans=in_chans,
            pretrained=True, **backbone_kwargs)
        encoder_channels = encoder.feature_info.channels()[::-1]
        self.encoder = encoder

        if not decoder_use_batchnorm:
            norm_layer = None
        self.decoder = UnetDecoder(
            encoder_channels=encoder_channels,
            decoder_channels=decoder_channels,
            final_channels=num_classes,
            norm_layer=norm_layer,
            center=center,
        )
    def forward(self, x: torch.Tensor):
        x = self.encoder(x)
        x.reverse()  # torchscript doesn't work with [::-1]
        x = self.decoder(x)
        return x
class Conv2dBnAct(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, padding=0,
                 stride=1, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
        super().__init__()
        # conv bias is redundant when a norm layer with an affine term follows; norm_layer=None
        # (decoder_use_batchnorm=False) now actually disables normalization instead of silently
        # falling back to the BatchNorm2d default
        self.conv = nn.Conv2d(
            in_channels, out_channels, kernel_size, stride=stride, padding=padding,
            bias=norm_layer is None)
        self.bn = norm_layer(out_channels) if norm_layer is not None else nn.Identity()
        self.act = act_layer(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.act(x)
        return x
class DecoderBlock(nn.Module):
    def __init__(self, in_channels, out_channels, scale_factor=2.0, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
        super().__init__()
        # norm_layer=None is handled inside Conv2dBnAct, so it can be passed through unconditionally
        conv_args = dict(kernel_size=3, padding=1, act_layer=act_layer, norm_layer=norm_layer)
        self.scale_factor = scale_factor
        self.conv1 = Conv2dBnAct(in_channels, out_channels, **conv_args)
        self.conv2 = Conv2dBnAct(out_channels, out_channels, **conv_args)

    def forward(self, x, skip: Optional[torch.Tensor] = None):
        if self.scale_factor != 1.0:
            x = F.interpolate(x, scale_factor=self.scale_factor, mode='nearest')
        if skip is not None:
            x = torch.cat([x, skip], dim=1)
        x = self.conv1(x)
        x = self.conv2(x)
        return x
class UnetDecoder(nn.Module):
    def __init__(
            self,
            encoder_channels,
            decoder_channels=(256, 128, 64, 32, 16),
            final_channels=1,
            norm_layer=nn.BatchNorm2d,
            center=False,
    ):
        super().__init__()
        if center:
            channels = encoder_channels[0]
            self.center = DecoderBlock(channels, channels, scale_factor=1.0, norm_layer=norm_layer)
        else:
            self.center = nn.Identity()
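        # Each decoder block consumes the previous decoder output (the encoder head for the first
        # block) concatenated with the next skip; e.g. (hypothetical values) a resnet50 encoder
        # gives encoder_channels=(2048, 1024, 512, 256, 64), so with the default decoder_channels
        # in_channels works out to [3072, 768, 384, 128, 32] (the last block has no skip, hence +0).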
        in_channels = [in_chs + skip_chs for in_chs, skip_chs in zip(
            [encoder_channels[0]] + list(decoder_channels[:-1]),
            list(encoder_channels[1:]) + [0])]
        out_channels = decoder_channels

        self.blocks = nn.ModuleList()
        for in_chs, out_chs in zip(in_channels, out_channels):
            self.blocks.append(DecoderBlock(in_chs, out_chs, norm_layer=norm_layer))
        self.final_conv = nn.Conv2d(out_channels[-1], final_channels, kernel_size=(1, 1))

        self._init_weight()
    def _init_weight(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
    def forward(self, x: List[torch.Tensor]):
        encoder_head = x[0]
        skips = x[1:]
        x = self.center(encoder_head)
        for i, b in enumerate(self.blocks):
            skip = skips[i] if i < len(skips) else None
            x = b(x, skip)
        x = self.final_conv(x)
        return x
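
# A minimal usage sketch (not part of the original gist). The backbone name, input shape, and
# class count below are arbitrary example values; note Unet hardwires pretrained=True, so the
# first run will download weights.
if __name__ == '__main__':
    model = Unet(backbone='resnet18', in_chans=3, num_classes=2)
    x = torch.randn(2, 3, 224, 224)
    out = model(x)
    # five 2x decoder blocks undo the encoder's stride-32 reduction: (2, 2, 224, 224)
    print(out.shape)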