# U-Net in PyTorch
# GitHub gist: shishironline/80fc9a5eca3ffb1811e18b336f9f5ef4
"""
U-Net architecture in PyTorch (https://arxiv.org/abs/1505.04597)
Author: Jacob Reinhold ([email protected])
"""
import torch | |
from torch import nn | |
from torch.nn import functional as F | |
class ConvLayer(nn.Sequential):
    """A 3x3 convolution followed by batch normalization and an in-place ReLU.

    The convolution uses ``padding=1``, so the spatial size is preserved. It has
    no bias term because the BatchNorm that follows would cancel any constant
    offset anyway. Children are registered as ``conv``, ``norm`` and ``relu``.

    Args:
        in_channels: number of input feature channels.
        out_channels: number of output feature channels.
    """

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        stages = (
            ('conv', nn.Conv2d(in_channels, out_channels,
                               kernel_size=3, padding=1, bias=False)),
            ('norm', nn.BatchNorm2d(out_channels)),
            ('relu', nn.ReLU(inplace=True)),
        )
        for stage_name, stage in stages:
            self.add_module(stage_name, stage)
class UNetBlock(nn.Sequential):
    """Double-convolution unit used at every level of the U-Net.

    Chains two ``ConvLayer`` modules, registered as ``block1`` and ``block2``.
    Only the first changes the channel count (``in_channels`` ->
    ``out_channels``); the second keeps it at ``out_channels``.

    Args:
        in_channels: number of input feature channels.
        out_channels: number of output feature channels.
    """

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        # (child name, input width) per stage; every stage emits out_channels.
        for child_name, width_in in (('block1', in_channels),
                                     ('block2', out_channels)):
            self.add_module(child_name, ConvLayer(width_in, out_channels))
class UNet(nn.Module):
    """U-Net encoder/decoder (https://arxiv.org/abs/1505.04597) built from UNetBlocks.

    Encoder: ``depth`` UNetBlocks, each followed by 2x2 max pooling. The block at
    level ``i`` outputs ``channel_base * 2**i`` channels.

    Bottleneck: one more UNetBlock that doubles the width again.

    Decoder: mirrors the encoder. At each level the coarser feature map is
    bilinearly resized to the matching skip connection's spatial size and
    concatenated with it along the channel axis. The result is passed through
    a UNetBlock. A final 1x1 convolution maps to ``out_channels``.

    Because upsampling resizes to the skip tensor's exact spatial size, inputs
    whose height/width are not multiples of ``2**depth`` are also supported.
    H and W should still be at least ``2**depth`` so the deepest max-pool does
    not produce an empty feature map.

    Args:
        in_channels: number of channels in the input image.
        out_channels: number of channels produced by the final 1x1 conv.
        channel_base: width of the first encoder block (default 64).
        depth: number of encoder/decoder levels (default 4, which reproduces
            the original fixed architecture). Must be at least 1.

    Raises:
        ValueError: if ``depth`` is less than 1.
    """

    def __init__(self, in_channels: int, out_channels: int,
                 channel_base: int = 64, depth: int = 4):
        super().__init__()
        if depth < 1:
            raise ValueError(f"depth must be >= 1, got {depth}")
        # widths[i] is the channel count at level i; widths[depth] is the bottleneck.
        widths = [channel_base * 2**i for i in range(depth + 1)]

        self.down_layers = nn.ModuleList([UNetBlock(in_channels, widths[0])])
        for i in range(depth - 1):
            self.down_layers.append(UNetBlock(widths[i], widths[i + 1]))

        self.bottleneck = UNetBlock(widths[depth - 1], widths[depth])

        # Each decoder block consumes the upsampled coarser features concatenated
        # with the same-level skip connection, hence the summed input width.
        self.up_layers = nn.ModuleList([])
        for i in reversed(range(1, depth)):
            self.up_layers.append(UNetBlock(widths[i + 1] + widths[i], widths[i]))
        # The last decoder level also carries the 1x1 projection to out_channels.
        self.up_layers.append(nn.Sequential(
            UNetBlock(widths[1] + widths[0], widths[0]),
            nn.Conv2d(widths[0], out_channels, 1)))

    @staticmethod
    def interp_cat(x, skip):
        """Resize ``x`` to ``skip``'s spatial size and concatenate them on dim 1.

        ``x`` is resized with bilinear interpolation (``align_corners=True``);
        the result is placed before ``skip`` in the channel dimension.
        """
        x = F.interpolate(x, skip.shape[2:], mode='bilinear', align_corners=True)
        return torch.cat((x, skip), 1)

    def forward(self, x):
        """Run the encoder, bottleneck and decoder on ``x``.

        Args:
            x: 4-D input batch ``(N, C, H, W)`` with ``C == in_channels``.
                Bilinear interpolation in ``interp_cat`` requires 4-D tensors.

        Returns:
            Tensor of shape ``(N, out_channels, H, W)``.
        """
        skip_connections = []
        for down_layer in self.down_layers:
            x = down_layer(x)
            # Keep the pre-pooling features for the matching decoder level.
            skip_connections.append(x)
            x = F.max_pool2d(x, 2)
        x = self.bottleneck(x)
        for up_layer in self.up_layers:
            # Deepest skip first; popping releases each skip once it is used.
            skip = skip_connections.pop()
            x = self.interp_cat(x, skip)
            x = up_layer(x)
        return x
if __name__ == "__main__":
    # Smoke test: build the default single-channel model, show its layers, and
    # check that a forward pass preserves the input's spatial size.
    model = UNet(1, 1)
    print(model)
    x = torch.randn(1, 1, 128, 128)
    # A shape check needs no gradients, so skip building the autograd graph.
    with torch.no_grad():
        y = model(x)
    if y.shape != (1, 1, 128, 128):
        raise RuntimeError(f"unexpected output shape {tuple(y.shape)}")
    print(f"output shape: {tuple(y.shape)}")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment