UNet with a configurable bottom (bottleneck) feature-map scale parameter (supports x1, x1/2, x1/2^2, x1/2^3, and x1/2^4 only)
# Based on https://github.com/mateuszbuda/brain-segmentation-pytorch/blob/master/unet.py
# The original version is licensed under the MIT License by mateuszbuda.
from collections import OrderedDict

import torch
import torch.nn as nn


class UNet(nn.Module):
    def __init__(self, in_channels=3, out_channels=1, init_features=32, bottom_inv_scale=5):
        super(UNet, self).__init__()

        # bottom_inv_scale selects the bottleneck resolution relative to the input:
        # 1 -> x1, 2 -> x1/2, 3 -> x1/2^2, 4 -> x1/2^3, 5 -> x1/2^4.
        assert bottom_inv_scale in [1, 2, 3, 4, 5]

        # Each pooling/upsampling stage downsamples by 2 only while the requested
        # scale still needs it; kernel_size=1 with stride=1 acts as an identity.
        K1 = 2 if bottom_inv_scale > 4 else 1
        K2 = 2 if bottom_inv_scale > 3 else 1
        K3 = 2 if bottom_inv_scale > 2 else 1
        K4 = 2 if bottom_inv_scale > 1 else 1

        features = init_features
        self.encoder1 = UNet._block(in_channels, features, name="enc1")
        self.pool1 = nn.MaxPool2d(kernel_size=K1, stride=K1)
        self.encoder2 = UNet._block(features, features * 2, name="enc2")
        self.pool2 = nn.MaxPool2d(kernel_size=K2, stride=K2)
        self.encoder3 = UNet._block(features * 2, features * 4, name="enc3")
        self.pool3 = nn.MaxPool2d(kernel_size=K3, stride=K3)
        self.encoder4 = UNet._block(features * 4, features * 8, name="enc4")
        self.pool4 = nn.MaxPool2d(kernel_size=K4, stride=K4)

        self.bottleneck = UNet._block(features * 8, features * 16, name="bottleneck")

        self.upconv4 = nn.ConvTranspose2d(
            features * 16, features * 8, kernel_size=K4, stride=K4
        )
        self.decoder4 = UNet._block((features * 8) * 2, features * 8, name="dec4")
        self.upconv3 = nn.ConvTranspose2d(
            features * 8, features * 4, kernel_size=K3, stride=K3
        )
        self.decoder3 = UNet._block((features * 4) * 2, features * 4, name="dec3")
        self.upconv2 = nn.ConvTranspose2d(
            features * 4, features * 2, kernel_size=K2, stride=K2
        )
        self.decoder2 = UNet._block((features * 2) * 2, features * 2, name="dec2")
        self.upconv1 = nn.ConvTranspose2d(
            features * 2, features, kernel_size=K1, stride=K1
        )
        self.decoder1 = UNet._block(features * 2, features, name="dec1")

        self.conv = nn.Conv2d(
            in_channels=features, out_channels=out_channels, kernel_size=1
        )

    def forward(self, x):
        enc1 = self.encoder1(x)
        enc2 = self.encoder2(self.pool1(enc1))
        enc3 = self.encoder3(self.pool2(enc2))
        enc4 = self.encoder4(self.pool3(enc3))

        # Debug output: intermediate encoder feature-map shapes.
        print(enc1.shape)
        print(enc2.shape)
        print(enc3.shape)
        print(enc4.shape)

        bottleneck = self.bottleneck(self.pool4(enc4))
        print('Bottom:', bottleneck.shape)

        # Decoder path: upsample, concatenate the matching skip connection, convolve.
        dec4 = self.upconv4(bottleneck)
        print(dec4.shape)
        dec4 = torch.cat((dec4, enc4), dim=1)
        dec4 = self.decoder4(dec4)
        dec3 = self.upconv3(dec4)
        print(dec3.shape)
        dec3 = torch.cat((dec3, enc3), dim=1)
        dec3 = self.decoder3(dec3)
        dec2 = self.upconv2(dec3)
        print(dec2.shape)
        dec2 = torch.cat((dec2, enc2), dim=1)
        dec2 = self.decoder2(dec2)
        dec1 = self.upconv1(dec2)
        print(dec1.shape)
        dec1 = torch.cat((dec1, enc1), dim=1)
        dec1 = self.decoder1(dec1)
        return torch.sigmoid(self.conv(dec1))

    @staticmethod
    def _block(in_channels, features, name):
        # Two 3x3 convolutions, each followed by batch normalization and ReLU.
        return nn.Sequential(
            OrderedDict(
                [
                    (
                        name + "conv1",
                        nn.Conv2d(
                            in_channels=in_channels,
                            out_channels=features,
                            kernel_size=3,
                            padding=1,
                            bias=False,
                        ),
                    ),
                    (name + "norm1", nn.BatchNorm2d(num_features=features)),
                    (name + "relu1", nn.ReLU(inplace=True)),
                    (
                        name + "conv2",
                        nn.Conv2d(
                            in_channels=features,
                            out_channels=features,
                            kernel_size=3,
                            padding=1,
                            bias=False,
                        ),
                    ),
                    (name + "norm2", nn.BatchNorm2d(num_features=features)),
                    (name + "relu2", nn.ReLU(inplace=True)),
                ]
            )
        )


if __name__ == '__main__':
    model = UNet(
        in_channels=3,
        out_channels=1,
        bottom_inv_scale=2,
    )
    model.eval()

    image = torch.randn((1, 3, 256, 256))
    print('In:', image.shape)

    with torch.no_grad():
        y = model(image)

    print('Out:', y.shape)
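As a quick sanity check, the sketch below (not part of the original gist; the module name unet_scaled is an assumption) instantiates the model at every supported bottom_inv_scale and confirms that the output keeps the input resolution while only the bottleneck shrinks. With a 256x256 input, the bottleneck side length works out to 256 / 2**(bottom_inv_scale - 1), because every pool and transposed convolution configured with kernel_size=1 and stride=1 is effectively an identity.

# Sanity-check sketch. Assumption: the class above is saved as unet_scaled.py
# (the file name is hypothetical).
import torch

from unet_scaled import UNet

for inv_scale in [1, 2, 3, 4, 5]:
    model = UNet(in_channels=3, out_channels=1, bottom_inv_scale=inv_scale)
    model.eval()
    with torch.no_grad():
        y = model(torch.randn((1, 3, 256, 256)))
    # forward() prints the bottleneck shape; its side length should be
    # 256 / 2 ** (inv_scale - 1), while the output stays at the input size.
    assert y.shape == (1, 1, 256, 256)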