Skip to content

Instantly share code, notes, and snippets.

@jinglescode
Created December 6, 2019 14:02
Show Gist options
  • Save jinglescode/d5296e0bf3ee9b4b135b6f8d9224e1b0 to your computer and use it in GitHub Desktop.
Save jinglescode/d5296e0bf3ee9b4b135b6f8d9224e1b0 to your computer and use it in GitHub Desktop.
import torch
import torch.nn as nn
class conv_block(nn.Module):
    """Two consecutive 3x3 convolutions, each followed by batch norm and ReLU.

    Args:
        in_ch: Number of channels of the input tensor.
        out_ch: Number of channels produced by both convolutions.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        stages = []
        # First stage maps in_ch -> out_ch, second stage keeps out_ch.
        for stage_in in (in_ch, out_ch):
            stages += [
                # stride=1 and bias=True are the Conv2d defaults.
                nn.Conv2d(stage_in, out_ch, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ]
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Return ``x`` after it has passed through the conv/BN/ReLU stack."""
        return self.conv(x)
class up_conv(nn.Module):
    """Upsample by a factor of two, then apply a 3x3 conv, batch norm and ReLU.

    Args:
        in_ch: Number of channels of the input tensor.
        out_ch: Number of channels produced by the convolution.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        enlarge = nn.Upsample(scale_factor=2)
        # stride=1 and bias=True are the Conv2d defaults.
        project = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)
        normalize = nn.BatchNorm2d(out_ch)
        activate = nn.ReLU(inplace=True)
        self.up = nn.Sequential(enlarge, project, normalize, activate)

    def forward(self, x):
        """Return the upsampled and convolved version of ``x``."""
        return self.up(x)
class Attention_block(nn.Module):
    """Attention gate that rescales ``x`` with a mask computed from ``g`` and ``x``.

    Both inputs are projected to ``F_int`` channels with 1x1 convolutions,
    summed, passed through ReLU, and reduced to a single-channel sigmoid mask
    that multiplies ``x``.

    Args:
        F_g: Number of channels of ``g``.
        F_l: Number of channels of ``x``.
        F_int: Number of channels of the intermediate projections.
    """

    def __init__(self, F_g, F_l, F_int):
        super().__init__()
        # W_g is applied to g in forward(), so its input width must be F_g.
        # (Layer shapes are the same as before whenever F_g == F_l.)
        self.W_g = nn.Sequential(
            nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int),
        )
        # W_x is applied to x in forward(), so its input width must be F_l.
        self.W_x = nn.Sequential(
            nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int),
        )
        # Collapses the F_int-channel features into one mask channel in (0, 1).
        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(1),
            nn.Sigmoid(),
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        """Return ``x`` multiplied by the attention mask.

        Args:
            g: Tensor with ``F_g`` channels.
            x: Tensor with ``F_l`` channels; the projections of ``g`` and
                ``x`` are added, so their remaining dimensions must match.

        Returns:
            Tensor with the same shape as ``x``.
        """
        g1 = self.W_g(g)
        x1 = self.W_x(x)
        mask = self.psi(self.relu(g1 + x1))
        # The single-channel mask broadcasts across all channels of x.
        return x * mask
class UNet_Attention(nn.Module):
    """U-Net in which every skip connection passes through an attention gate.

    The encoder applies five ``conv_block`` stages separated by 2x2 max
    pooling. The decoder upsamples four times; at each level the matching
    encoder output is gated by an ``Attention_block`` driven by the upsampled
    features, concatenated with them along the channel axis, and fused by a
    ``conv_block``. A final 1x1 convolution maps to ``output_ch`` channels.

    Args:
        img_ch: Number of channels of the input image.
        output_ch: Number of channels of the output map.

    NOTE(review): four 2x poolings followed by four 2x upsamplings suggest the
    input height and width should be divisible by 16 for the concatenations to
    line up — confirm against callers.
    """

    def __init__(self, img_ch=3, output_ch=1):
        super().__init__()
        base = 64
        # Channel widths per level: 64, 128, 256, 512, 1024.
        c1, c2, c3, c4, c5 = (base * 2 ** level for level in range(5))

        # Downsampling layers (parameter-free).
        self.Maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder stages.
        self.Conv1 = conv_block(img_ch, c1)
        self.Conv2 = conv_block(c1, c2)
        self.Conv3 = conv_block(c2, c3)
        self.Conv4 = conv_block(c3, c4)
        self.Conv5 = conv_block(c4, c5)

        # Decoder stages: upsample, gate the skip, fuse the concatenation.
        self.Up5 = up_conv(c5, c4)
        self.Att5 = Attention_block(F_g=c4, F_l=c4, F_int=c3)
        self.Up_conv5 = conv_block(c5, c4)

        self.Up4 = up_conv(c4, c3)
        self.Att4 = Attention_block(F_g=c3, F_l=c3, F_int=c2)
        self.Up_conv4 = conv_block(c4, c3)

        self.Up3 = up_conv(c3, c2)
        self.Att3 = Attention_block(F_g=c2, F_l=c2, F_int=c1)
        self.Up_conv3 = conv_block(c3, c2)

        self.Up2 = up_conv(c2, c1)
        self.Att2 = Attention_block(F_g=c1, F_l=c1, F_int=c1 // 2)
        self.Up_conv2 = conv_block(c2, c1)

        # 1x1 projection to the requested number of output channels.
        self.Conv = nn.Conv2d(c1, output_ch, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        """Run the encoder/decoder on ``x`` and return the final 1x1 conv output."""
        # Encoder: keep every level's output for the skip connections.
        skip1 = self.Conv1(x)
        skip2 = self.Conv2(self.Maxpool1(skip1))
        skip3 = self.Conv3(self.Maxpool2(skip2))
        skip4 = self.Conv4(self.Maxpool3(skip3))
        bottom = self.Conv5(self.Maxpool4(skip4))

        # Decoder level 5.
        raised = self.Up5(bottom)
        gated = self.Att5(g=raised, x=skip4)
        fused = self.Up_conv5(torch.cat((gated, raised), dim=1))

        # Decoder level 4.
        raised = self.Up4(fused)
        gated = self.Att4(g=raised, x=skip3)
        fused = self.Up_conv4(torch.cat((gated, raised), dim=1))

        # Decoder level 3.
        raised = self.Up3(fused)
        gated = self.Att3(g=raised, x=skip2)
        fused = self.Up_conv3(torch.cat((gated, raised), dim=1))

        # Decoder level 2.
        raised = self.Up2(fused)
        gated = self.Att2(g=raised, x=skip1)
        fused = self.Up_conv2(torch.cat((gated, raised), dim=1))

        return self.Conv(fused)
@maweil
Copy link

maweil commented May 23, 2020

Hello @jinglescode,
thanks for your great article and this Gist. It really helped me to understand the original paper.

Under which license do you release this Gist?
Would it be okay for you if others use/modify/rewrite it as part of a new open source project if they mention your work and the original authors of the paper?

Thanks and best regards

@jinglescode
Copy link
Author

Hello @jinglescode,
thanks for your great article and this Gist. It really helped me to understand the original paper.

Under which license do you release this Gist?
Would it be okay for you if others use/modify/rewrite it as part of a new open-source project if they mention your work and the original authors of the paper?

Thanks and best regards

Hi @maweil! Thanks for reading those articles; I'm glad they helped in some way. If anything in them could be made clearer, please let me know. Feel free to use/modify/rewrite the code for other projects. Definitely do cite the original authors and their paper. I'd be glad to be mentioned in your work, but it's optional.

@maweil
Copy link

maweil commented May 25, 2020

Hi @jinglescode! Thanks a lot for your answer and your permission to build upon your code. The article is a very nice read and I currently can't think of things that need clarification. If something comes to my mind, I'll let you know.
I'll of course mention the original authors and you when publishing my code.

Best regards

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment