Skip to content

Instantly share code, notes, and snippets.

@shishironline
Forked from jcreinhold/unet.py
Created July 30, 2020 04:25
Show Gist options
  • Save shishironline/80fc9a5eca3ffb1811e18b336f9f5ef4 to your computer and use it in GitHub Desktop.
Save shishironline/80fc9a5eca3ffb1811e18b336f9f5ef4 to your computer and use it in GitHub Desktop.
U-Net in PyTorch
"""
U-Net architecture in PyTorch (https://arxiv.org/abs/1505.04597)
Author: Jacob Reinhold ([email protected])
"""
import torch
from torch import nn
from torch.nn import functional as F
class ConvLayer(nn.Sequential):
    """Conv(3x3, padding=1, no bias) -> BatchNorm2d -> in-place ReLU.

    Args:
        in_channels: channels expected in the input.
        out_channels: channels produced by the convolution (and normalized
            by the batch norm).

    The submodules are registered as ``conv``, ``norm`` and ``relu`` (in
    that order), which fixes both the forward order and the state_dict keys.
    """

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        # Built in forward order so parameter initialization consumes the
        # RNG in the same sequence as registering them one by one.
        stages = (
            ('conv', nn.Conv2d(in_channels, out_channels, 3,
                               padding=1, bias=False)),
            ('norm', nn.BatchNorm2d(out_channels)),
            ('relu', nn.ReLU(inplace=True)),
        )
        for stage_name, stage in stages:
            self.add_module(stage_name, stage)
class UNetBlock(nn.Sequential):
    """Two ConvLayers in sequence: in_channels -> out_channels -> out_channels.

    Submodules are registered as ``block1`` and ``block2``; these names
    determine the state_dict keys.
    """

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        # (registered name, input channels, output channels), in forward order.
        layer_specs = (
            ('block1', in_channels, out_channels),
            ('block2', out_channels, out_channels),
        )
        for layer_name, c_in, c_out in layer_specs:
            self.add_module(layer_name, ConvLayer(c_in, c_out))
class UNet(nn.Module):
    """U-Net (https://arxiv.org/abs/1505.04597) with an interpolation decoder.

    Encoder: ``depth`` UNetBlocks, each followed by 2x2 max pooling. The
    channel width starts at ``channel_base`` and doubles at every level.
    Decoder: at each level the deeper features are bilinearly resized to the
    matching skip connection and concatenated with it along dim 1, then fed
    through a UNetBlock. A final 1x1 convolution maps ``channel_base``
    channels to ``out_channels``.

    Args:
        in_channels: number of channels in the input tensor.
        out_channels: number of channels in the output tensor.
        channel_base: channel width of the first (full-resolution) level.
        depth: number of encoder levels before the bottleneck. The default of
            4 builds exactly the original fixed architecture (same submodules,
            same state_dict keys, same construction order).

    Raises:
        ValueError: if ``depth`` is less than 1.
    """

    def __init__(self, in_channels:int, out_channels:int, channel_base:int=64,
                 depth:int=4):
        super().__init__()
        if depth < 1:
            raise ValueError(f'depth must be >= 1, got {depth}')

        def n_chan(level: int) -> int:
            # Channel width at a given level: doubles with every level down.
            return channel_base * 2 ** level

        # Encoder: level 0 lifts the input to channel_base channels; every
        # later level doubles the width.
        self.down_layers = nn.ModuleList([UNetBlock(in_channels, n_chan(0))])
        for i in range(depth - 1):
            self.down_layers.append(UNetBlock(n_chan(i), n_chan(i + 1)))
        self.bottleneck = UNetBlock(n_chan(depth - 1), n_chan(depth))
        # Decoder, deepest level first: each block consumes the resized
        # deeper features (n_chan(i+1)) concatenated with the skip (n_chan(i)).
        self.up_layers = nn.ModuleList([])
        for i in reversed(range(1, depth)):
            self.up_layers.append(
                UNetBlock(n_chan(i + 1) + n_chan(i), n_chan(i)))
        # Shallowest decoder level also projects to out_channels (1x1 conv).
        self.up_layers.append(nn.Sequential(
            UNetBlock(n_chan(1) + n_chan(0), n_chan(0)),
            nn.Conv2d(n_chan(0), out_channels, 1)))

    @staticmethod
    def interp_cat(x, skip):
        """Resize ``x`` to ``skip``'s spatial size, then concatenate on dim 1.

        Uses bilinear interpolation, so both tensors are expected to be 4-D
        (presumably N, C, H, W).
        """
        x = F.interpolate(x, skip.shape[2:], mode='bilinear', align_corners=True)
        return torch.cat((x, skip), 1)

    def forward(self, x):
        """Run encoder, bottleneck and decoder; return the 1x1-conv output."""
        skip_connections = []
        for down_layer in self.down_layers:
            x = down_layer(x)
            # Keep the pre-pooling features for the matching decoder level.
            skip_connections.append(x)
            x = F.max_pool2d(x, 2)
        x = self.bottleneck(x)
        # Decoder levels consume the skips deepest-first (LIFO).
        for up_layer in self.up_layers:
            skip = skip_connections.pop()
            x = self.interp_cat(x, skip)
            x = up_layer(x)
        return x
if __name__ == "__main__":
    # Smoke test: build the default network, print it, and verify that a
    # forward pass keeps the input's spatial size and yields 1 channel.
    model = UNet(1, 1)
    print(model)
    x = torch.randn(1, 1, 128, 128)
    with torch.no_grad():  # forward only; no autograd graph needed
        y = model(x)
    print(f'input {tuple(x.shape)} -> output {tuple(y.shape)}')
    if tuple(y.shape) != (1, 1, 128, 128):
        raise RuntimeError(f'unexpected output shape {tuple(y.shape)}')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment