Last active
March 31, 2023 08:22
-
-
Save TerenceLiu98/b31f24dd035a5a0757ba0ded00dc0189 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.nn as nn | |
import torchvision | |
import numpy as np | |
## U-Net ## | |
class DualConv(nn.Module):
    """Two stacked ``3x3 conv -> InstanceNorm -> ReLU`` stages.

    A single 2-pixel reflection pad is applied up front. Each of the two
    unpadded 3x3 convolutions trims 1 pixel per side, so the pad exactly
    compensates and the output has the same spatial size as the input.

    Args:
        in_channel: Number of channels of the incoming feature map.
        out_channel: Number of channels produced by both convolutions.
    """

    def __init__(self, in_channel, out_channel):
        super().__init__()

        def stage(n_in):
            # Unpadded conv -> parameter-free instance norm -> in-place ReLU.
            return [
                nn.Conv2d(n_in, out_channel, kernel_size=3, padding=0),
                nn.InstanceNorm2d(out_channel, affine=False),
                nn.ReLU(inplace=True),
            ]

        # Everything lives in one nn.Sequential named ``conv`` so parameter
        # names (conv.1.*, conv.4.*) are stable for checkpoint loading.
        layers = [nn.ReflectionPad2d(padding=2)]
        layers += stage(in_channel)
        layers += stage(out_channel)
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the padded double-convolution stack to ``x``."""
        return self.conv(x)
class Encoder(nn.Module):
    """Contracting path of the U-Net.

    One ``DualConv`` block is built for every consecutive pair in
    ``channels``. 2x2 max-pooling halves the resolution between blocks.

    Args:
        channels: Channel counts ``(in, c1, c2, ...)``. ``channels[0]`` is the
            input channel count: 3 for RGB, 1 for grayscale. A tuple default
            is used so the default cannot be shared and mutated across
            instances.
    """

    def __init__(self, channels=(3, 64, 128, 256, 512, 1024)):
        super().__init__()
        self.encblocks = nn.ModuleList(
            [DualConv(c_in, c_out) for c_in, c_out in zip(channels[:-1], channels[1:])]
        )
        self.maxpool2d = nn.MaxPool2d(2)

    def forward(self, x):
        """Return the output of every encoder block, shallowest first.

        Entry ``i`` is produced after ``i`` max-pool steps. Each step halves
        the spatial size (rounding down).
        """
        output = []
        for i, encblock in enumerate(self.encblocks):
            # Pool only *between* blocks. A pool after the last block would
            # be computed and then discarded, since nothing consumes it.
            if i:
                x = self.maxpool2d(x)
            x = encblock(x)
            output.append(x)
        return output
class Decoder(nn.Module):
    """Expanding path of the U-Net.

    Each stage performs three steps:
    1. Upsample ``x`` by 2.
    2. Concatenate the matching encoder feature map, center-cropped to the
       upsampled size, along the channel axis.
    3. Fuse the result with a ``DualConv``.

    Args:
        channels: Channel counts per stage, deepest first. Stage ``i``
            upsamples ``channels[i]`` -> ``channels[i + 1]`` channels. Its
            ``DualConv`` expects ``channels[i]`` input channels, so the skip
            map for stage ``i`` must carry ``channels[i] - channels[i + 1]``
            channels. A tuple default avoids a shared mutable default.
        bilinear: If True, upsample with bilinear interpolation followed by a
            1x1 convolution. Otherwise use a stride-2 transposed convolution.
    """

    def __init__(self, channels=(1024, 512, 256, 128, 64), bilinear=False):
        super().__init__()
        # Private list copy: no aliasing of the caller's (or a default) list.
        self.channels = list(channels)
        pairs = list(zip(self.channels[:-1], self.channels[1:]))
        if bilinear:
            self.upsampling = nn.ModuleList(
                [
                    nn.Sequential(
                        nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),
                        nn.Conv2d(c_in, c_out, kernel_size=1),
                    )
                    for c_in, c_out in pairs
                ]
            )
        else:
            self.upsampling = nn.ModuleList(
                [nn.ConvTranspose2d(c_in, c_out, kernel_size=2, stride=2) for c_in, c_out in pairs]
            )
        self.decblocks = nn.ModuleList([DualConv(c_in, c_out) for c_in, c_out in pairs])

    def forward(self, x, encoder_features):
        """Run every decoder stage.

        Args:
            x: Deepest feature map, 4-D with ``channels[0]`` channels.
            encoder_features: Skip-connection maps ordered deepest first.
                ``encoder_features[i]`` is consumed by stage ``i``. An
                ``IndexError`` is raised if there are fewer maps than stages.

        Returns:
            Feature map with ``channels[-1]`` channels.
        """
        for i, (upsample, decblock) in enumerate(zip(self.upsampling, self.decblocks)):
            x = upsample(x)
            skip = self.crop(encoder_features[i], x)
            x = torch.cat([x, skip], dim=1)
            x = decblock(x)
        return x

    def crop(self, enc_ftrs, x):
        """Center-crop ``enc_ftrs`` to the spatial size (H, W) of ``x``.

        The skip map can be larger than the upsampled map when an encoder
        feature had an odd height or width before pooling.
        """
        _, _, H, W = x.shape
        return torchvision.transforms.CenterCrop([H, W])(enc_ftrs)
class UNet(nn.Module):
    """U-Net: encoder, skip-connected decoder, and a 1x1 conv + sigmoid head.

    Args:
        enc_channels: Encoder channel list. ``enc_channels[0]`` is the input
            channel count: 3 for RGB, 1 for grayscale.
        dec_channels: Decoder channel list, deepest first.
            ``dec_channels[0]`` should equal ``enc_channels[-1]``.
        num_class: Number of output channels.
        bilinear: Use bilinear upsampling (+1x1 conv) in the decoder instead
            of transposed convolutions.

    Tuple defaults avoid shared mutable default arguments.
    """

    def __init__(self, enc_channels=(3, 64, 128, 256, 512, 1024),
                 dec_channels=(1024, 512, 256, 128, 64), num_class=1, bilinear=False):
        super().__init__()
        self.encoder = Encoder(channels=enc_channels)
        # Pass the flag through so ``bilinear=True`` actually selects
        # bilinear upsampling in the decoder.
        self.decoder = Decoder(channels=dec_channels, bilinear=bilinear)
        self.output = nn.Sequential(nn.Conv2d(dec_channels[-1], num_class, 1), nn.Sigmoid())

    def forward(self, x):
        """Return sigmoid-activated per-pixel scores with ``num_class`` channels."""
        enc_out = self.encoder(x)
        # The decoder wants the deepest map first, then the skips deepest-first.
        reversed_feats = enc_out[::-1]
        out = self.decoder(reversed_feats[0], reversed_feats[1:])
        return self.output(out)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment