Skip to content

Instantly share code, notes, and snippets.

@aoirint
Last active March 9, 2020 16:50
Show Gist options
  • Save aoirint/fb04ec16aa47c0f8e4d341b4b12da83d to your computer and use it in GitHub Desktop.
Save aoirint/fb04ec16aa47c0f8e4d341b4b12da83d to your computer and use it in GitHub Desktop.
UNet with a bottom-layer feature-map scale parameter (only the scales x1, x1/2, x1/2^2, x1/2^3 and x1/2^4 are supported)
# Based on https://github.com/mateuszbuda/brain-segmentation-pytorch/blob/master/unet.py
# Original version is licensed under MIT License by mateuszbuda.
from collections import OrderedDict
import torch
import torch.nn as nn
class UNet(nn.Module):
    """U-Net whose bottleneck resolution is selected by ``bottom_inv_scale``.

    The network consists of four encoder stages, a bottleneck and four
    decoder stages joined by skip connections (channel-wise concatenation).
    ``bottom_inv_scale`` controls how many of the four pooling layers really
    halve the resolution; the others are built with kernel size and stride 1
    and leave the spatial size untouched.  Down-sampling is switched on from
    the deepest stage upwards:

    ==================  ==========================  =====================
    bottom_inv_scale    pools with stride 2         bottleneck resolution
    ==================  ==========================  =====================
    1                   none                        x1
    2                   pool4                       x1/2
    3                   pool3, pool4                x1/4
    4                   pool2, pool3, pool4         x1/8
    5                   pool1 .. pool4              x1/16
    ==================  ==========================  =====================

    Every transposed convolution uses the same kernel size/stride as the
    pooling layer it undoes, so the decoder restores the encoder resolution
    stage by stage.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels of the output map.
        init_features: Channel width of the first encoder stage; each deeper
            stage doubles it (the bottleneck uses ``init_features * 16``).
        bottom_inv_scale: Integer in ``1..5`` choosing the bottleneck
            resolution as shown in the table above.
        verbose: When true, ``forward`` prints the shape of every encoder
            output, of the bottleneck and of every up-sampled decoder input.
            Off by default so that forward passes do not write to stdout.

    Raises:
        ValueError: If ``bottom_inv_scale`` is not one of 1, 2, 3, 4, 5.
    """

    # Accepted values for ``bottom_inv_scale``.
    _VALID_BOTTOM_INV_SCALES = (1, 2, 3, 4, 5)

    # Class-level fallback so instances that were created (e.g. unpickled)
    # without running this ``__init__`` still have the attribute.
    verbose = False

    def __init__(
        self,
        in_channels=3,
        out_channels=1,
        init_features=32,
        bottom_inv_scale=5,
        *,
        verbose=False,
    ):
        super().__init__()

        # Explicit exception instead of ``assert``: asserts are removed when
        # Python runs with -O, which would let invalid values through.
        if bottom_inv_scale not in self._VALID_BOTTOM_INV_SCALES:
            raise ValueError(
                "bottom_inv_scale must be one of "
                f"{self._VALID_BOTTOM_INV_SCALES}, got {bottom_inv_scale!r}"
            )

        self.verbose = verbose

        # Kernel size == stride for pool1..pool4 (and the matching upconvs).
        # Stage i down-samples by 2 only when the requested scale is deep
        # enough; otherwise 1 turns the pool into a no-op on the resolution.
        k1, k2, k3, k4 = (
            2 if bottom_inv_scale > threshold else 1
            for threshold in (4, 3, 2, 1)
        )

        features = init_features

        # NOTE: attribute names and creation order are kept stable so that
        # state_dict keys and seeded parameter initialisation do not change.
        self.encoder1 = UNet._block(in_channels, features, name="enc1")
        self.pool1 = nn.MaxPool2d(kernel_size=k1, stride=k1)
        self.encoder2 = UNet._block(features, features * 2, name="enc2")
        self.pool2 = nn.MaxPool2d(kernel_size=k2, stride=k2)
        self.encoder3 = UNet._block(features * 2, features * 4, name="enc3")
        self.pool3 = nn.MaxPool2d(kernel_size=k3, stride=k3)
        self.encoder4 = UNet._block(features * 4, features * 8, name="enc4")
        self.pool4 = nn.MaxPool2d(kernel_size=k4, stride=k4)

        self.bottleneck = UNet._block(features * 8, features * 16, name="bottleneck")

        # Decoder blocks take twice the channels of their upconv output
        # because the matching encoder feature map is concatenated to it.
        self.upconv4 = nn.ConvTranspose2d(
            features * 16, features * 8, kernel_size=k4, stride=k4
        )
        self.decoder4 = UNet._block((features * 8) * 2, features * 8, name="dec4")
        self.upconv3 = nn.ConvTranspose2d(
            features * 8, features * 4, kernel_size=k3, stride=k3
        )
        self.decoder3 = UNet._block((features * 4) * 2, features * 4, name="dec3")
        self.upconv2 = nn.ConvTranspose2d(
            features * 4, features * 2, kernel_size=k2, stride=k2
        )
        self.decoder2 = UNet._block((features * 2) * 2, features * 2, name="dec2")
        self.upconv1 = nn.ConvTranspose2d(
            features * 2, features, kernel_size=k1, stride=k1
        )
        self.decoder1 = UNet._block(features * 2, features, name="dec1")

        # 1x1 convolution mapping the last decoder features to the output.
        self.conv = nn.Conv2d(
            in_channels=features, out_channels=out_channels, kernel_size=1
        )

    def _trace(self, *values):
        """Print ``values`` only when shape tracing was requested."""
        if self.verbose:
            print(*values)

    def forward(self, x):
        """Run the network on ``x`` and return the sigmoid of the output map.

        Args:
            x: Input batch with ``in_channels`` channels in dimension 1.
                Height and width should be divisible by
                ``2 ** (bottom_inv_scale - 1)``; otherwise a pooled map is
                rounded down and no longer matches its skip connection when
                the two are concatenated.

        Returns:
            Tensor with ``out_channels`` channels and values in ``[0, 1]``.
        """
        # Encoder path; every stage output is kept for its skip connection.
        enc1 = self.encoder1(x)
        enc2 = self.encoder2(self.pool1(enc1))
        enc3 = self.encoder3(self.pool2(enc2))
        enc4 = self.encoder4(self.pool3(enc3))
        self._trace(enc1.shape)
        self._trace(enc2.shape)
        self._trace(enc3.shape)
        self._trace(enc4.shape)

        bottleneck = self.bottleneck(self.pool4(enc4))
        self._trace('Bottom:', bottleneck.shape)

        # Decoder path: up-sample, concatenate the skip along channels, refine.
        dec4 = self.upconv4(bottleneck)
        self._trace(dec4.shape)
        dec4 = self.decoder4(torch.cat((dec4, enc4), dim=1))

        dec3 = self.upconv3(dec4)
        self._trace(dec3.shape)
        dec3 = self.decoder3(torch.cat((dec3, enc3), dim=1))

        dec2 = self.upconv2(dec3)
        self._trace(dec2.shape)
        dec2 = self.decoder2(torch.cat((dec2, enc2), dim=1))

        dec1 = self.upconv1(dec2)
        self._trace(dec1.shape)
        dec1 = self.decoder1(torch.cat((dec1, enc1), dim=1))

        return torch.sigmoid(self.conv(dec1))

    @staticmethod
    def _block(in_channels, features, name):
        """Build a ``(3x3 conv -> batch norm -> ReLU) x 2`` block.

        Args:
            in_channels: Channels entering the first convolution.
            features: Channels produced by both convolutions.
            name: Prefix for the layer names inside the returned
                ``nn.Sequential`` (e.g. ``"enc1"`` gives ``"enc1conv1"``).

        Returns:
            An ``nn.Sequential`` with named layers.  ``padding=1`` keeps the
            spatial size; the convolutions have no bias term.
        """
        return nn.Sequential(
            OrderedDict(
                [
                    (
                        name + "conv1",
                        nn.Conv2d(
                            in_channels=in_channels,
                            out_channels=features,
                            kernel_size=3,
                            padding=1,
                            bias=False,
                        ),
                    ),
                    (name + "norm1", nn.BatchNorm2d(num_features=features)),
                    (name + "relu1", nn.ReLU(inplace=True)),
                    (
                        name + "conv2",
                        nn.Conv2d(
                            in_channels=features,
                            out_channels=features,
                            kernel_size=3,
                            padding=1,
                            bias=False,
                        ),
                    ),
                    (name + "norm2", nn.BatchNorm2d(num_features=features)),
                    (name + "relu2", nn.ReLU(inplace=True)),
                ]
            )
        )
if __name__ == '__main__':
    # Smoke test: build the network with bottom_inv_scale=2, switch it to
    # evaluation mode and push one random 3-channel 256x256 batch through it.
    net = UNet(in_channels=3, out_channels=1, bottom_inv_scale=2).eval()

    sample = torch.randn(1, 3, 256, 256)
    print('In:', sample.shape)

    # Gradients are not needed for a shape check.
    with torch.no_grad():
        prediction = net(sample)

    print('Out:', prediction.shape)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment