Skip to content

Instantly share code, notes, and snippets.

@TerenceLiu98
Last active March 31, 2023 08:22
Show Gist options
  • Save TerenceLiu98/b31f24dd035a5a0757ba0ded00dc0189 to your computer and use it in GitHub Desktop.
Save TerenceLiu98/b31f24dd035a5a0757ba0ded00dc0189 to your computer and use it in GitHub Desktop.
import torch
import torch.nn as nn
import torchvision
import numpy as np
## U-Net ##
class DualConv(nn.Module):
    """Two (3x3 conv -> InstanceNorm -> ReLU) stages that keep spatial size.

    A single 2-pixel reflection pad up front compensates for the two
    unpadded 3x3 convolutions (each trims one pixel per side), so an
    ``H x W`` input yields an ``H x W`` output with ``out_channel`` channels.
    ``ReflectionPad2d`` requires each spatial dim to exceed the pad width,
    i.e. inputs must be at least 3x3.

    Args:
        in_channel: Channels of the incoming tensor.
        out_channel: Channels produced by both convolutions.
    """

    def __init__(self, in_channel, out_channel):
        super().__init__()
        stages = [nn.ReflectionPad2d(padding=2)]
        # First conv maps in_channel -> out_channel, second out -> out.
        for src_channels in (in_channel, out_channel):
            stages += [
                nn.Conv2d(src_channels, out_channel, kernel_size=3, padding=0),
                nn.InstanceNorm2d(out_channel, affine=False),
                nn.ReLU(inplace=True),
            ]
        # Kept as one nn.Sequential named ``conv`` so state_dict keys
        # (conv.1.*, conv.4.*) stay compatible with existing checkpoints.
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the pad/conv/norm/ReLU stack to ``x``."""
        return self.conv(x)
class Encoder(nn.Module):
    """Contracting path of the U-Net.

    One ``DualConv`` is built per consecutive pair in ``channels``; between
    stages the feature map is halved with 2x2 max-pooling.

    Args:
        channels: Channel count per stage. ``channels[0]`` is the input
            channel count (3 is for RGB, 1 is for grayscale). A tuple default
            is used to avoid the shared-mutable-default pitfall.
    """

    def __init__(self, channels=(3, 64, 128, 256, 512, 1024)):
        super(Encoder, self).__init__()
        self.encblocks = nn.ModuleList(
            [DualConv(c_in, c_out) for c_in, c_out in zip(channels[:-1], channels[1:])]
        )
        self.maxpool2d = nn.MaxPool2d(2)

    def forward(self, x):
        """Return the per-stage feature maps, shallowest (full resolution) first.

        The list has one entry per block; each entry is the block output
        before pooling, which is what the decoder's skip connections use.
        """
        output = []
        for stage, encblock in enumerate(self.encblocks):
            if stage > 0:
                # Pool only *between* stages: pooling after the last block
                # produced a tensor that was never used (and fails on 1x1 maps).
                x = self.maxpool2d(x)
            x = encblock(x)
            output.append(x)
        return output
class Decoder(nn.Module):
    """Expanding path of the U-Net.

    Each stage doubles the spatial size of ``x``, concatenates the matching
    encoder feature map (center-cropped to the upsampled size) along the
    channel axis, and fuses the result with a ``DualConv``.

    Args:
        channels: Channel counts per stage, deepest first. Because stage ``i``
            feeds ``channels[i + 1]`` upsampled channels plus the skip
            channels into ``DualConv(channels[i], ...)``, the skip feature at
            stage ``i`` must carry ``channels[i] - channels[i + 1]`` channels.
            A tuple default avoids the shared-mutable-default pitfall.
        bilinear: If True, upsample with bilinear interpolation followed by a
            1x1 convolution; otherwise use a stride-2 2x2 transposed conv.
    """

    def __init__(self, channels=(1024, 512, 256, 128, 64), bilinear=False):
        super().__init__()
        self.channels = channels
        pairs = list(zip(channels[:-1], channels[1:]))
        if bilinear:
            self.upsampling = nn.ModuleList([
                nn.Sequential(
                    nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),
                    nn.Conv2d(c_in, c_out, kernel_size=1),
                )
                for c_in, c_out in pairs
            ])
        else:
            self.upsampling = nn.ModuleList([
                nn.ConvTranspose2d(c_in, c_out, kernel_size=2, stride=2)
                for c_in, c_out in pairs
            ])
        self.decblocks = nn.ModuleList([DualConv(c_in, c_out) for c_in, c_out in pairs])

    def forward(self, x, encoder_features):
        """Run all decoder stages.

        Args:
            x: Deepest (bottleneck) feature map.
            encoder_features: Skip features ordered deepest first;
                ``encoder_features[i]`` is used at stage ``i``. Too few
                entries raise ``IndexError``; extra entries are ignored.
        """
        for stage, (upsample, decblock) in enumerate(zip(self.upsampling, self.decblocks)):
            x = upsample(x)
            skip = self.crop(encoder_features[stage], x)
            x = decblock(torch.cat([x, skip], dim=1))
        return x

    def crop(self, enc_ftrs, x):
        """Center-crop ``enc_ftrs`` to the spatial size (last two dims) of ``x``.

        Reproduces ``torchvision.transforms.CenterCrop`` on tensors without
        building a transform object per call. If ``enc_ftrs`` is smaller
        along an axis it is first zero-padded (odd extra goes bottom/right).
        Offsets are ``round(diff / 2)``, matching torchvision.
        """
        _, _, H, W = x.shape
        h, w = enc_ftrs.shape[-2:]
        pad_h, pad_w = max(H - h, 0), max(W - w, 0)
        if pad_h or pad_w:
            # F.pad order for the last two dims: (left, right, top, bottom).
            enc_ftrs = nn.functional.pad(
                enc_ftrs, (pad_w // 2, (pad_w + 1) // 2, pad_h // 2, (pad_h + 1) // 2)
            )
            h, w = enc_ftrs.shape[-2:]
        top = int(round((h - H) / 2.0))
        left = int(round((w - W) / 2.0))
        return enc_ftrs[..., top:top + H, left:left + W]
class UNet(nn.Module):
    """U-Net: ``Encoder`` + ``Decoder`` with skip connections and a 1x1 head.

    Args:
        enc_channels: Encoder channel counts; ``enc_channels[0]`` is the
            input channel count.
        dec_channels: Decoder channel counts, deepest first. These must be
            consistent with ``enc_channels`` for the skip concatenations to
            line up (not validated here).
        num_class: Output channels of the final 1x1 convolution.
        bilinear: Passed to ``Decoder``; selects bilinear upsampling instead
            of transposed convolution.
    Tuple defaults avoid the shared-mutable-default pitfall.
    """

    def __init__(self, enc_channels=(3, 64, 128, 256, 512, 1024), dec_channels=(1024, 512, 256, 128, 64), num_class=1, bilinear=False):
        super(UNet, self).__init__()
        self.encoder = Encoder(channels=enc_channels)
        # Forward the caller's choice; this was previously hard-coded to False,
        # so UNet(bilinear=True) silently built a transposed-conv decoder.
        self.decoder = Decoder(channels=dec_channels, bilinear=bilinear)
        # Sigmoid squashes each output channel independently into (0, 1).
        self.output = nn.Sequential(nn.Conv2d(dec_channels[-1], num_class, 1), nn.Sigmoid())

    def forward(self, x):
        """Encode ``x``, decode from the deepest features, and apply the head."""
        *skips, bottleneck = self.encoder(x)
        # Encoder returns shallowest-first; the decoder wants skips deepest-first.
        out = self.decoder(bottleneck, skips[::-1])
        return self.output(out)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment