-
-
Save jinglescode/d5296e0bf3ee9b4b135b6f8d9224e1b0 to your computer and use it in GitHub Desktop.
import torch | |
import torch.nn as nn | |
class conv_block(nn.Module):
    """Two consecutive Conv2d -> BatchNorm2d -> ReLU stages.

    Both convolutions use a 3x3 kernel with stride 1 and padding 1, so the
    spatial size of the input is preserved; only the channel count changes
    (``in_ch`` -> ``out_ch`` in the first stage, ``out_ch`` -> ``out_ch`` in
    the second).

    Args:
        in_ch: Number of channels of the incoming feature map.
        out_ch: Number of channels produced by each of the two stages.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()

        stages = []
        # First stage maps in_ch -> out_ch, second stage keeps out_ch.
        for stage_in in (in_ch, out_ch):
            stages.append(
                nn.Conv2d(stage_in, out_ch, kernel_size=3, stride=1, padding=1, bias=True)
            )
            stages.append(nn.BatchNorm2d(out_ch))
            stages.append(nn.ReLU(inplace=True))

        # Registered under ``conv`` so parameter names stay ``conv.<index>.*``.
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through both convolution stages and return the result."""
        return self.conv(x)
class up_conv(nn.Module):
    """Upsample by a factor of two, then Conv2d -> BatchNorm2d -> ReLU.

    The upsampling layer is ``nn.Upsample(scale_factor=2)`` with its default
    interpolation mode. The following 3x3 convolution (stride 1, padding 1)
    keeps the upsampled spatial size and maps ``in_ch`` to ``out_ch`` channels.

    Args:
        in_ch: Number of channels of the incoming feature map.
        out_ch: Number of channels of the returned feature map.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()

        enlarge = nn.Upsample(scale_factor=2)
        project = nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=True)
        normalize = nn.BatchNorm2d(out_ch)
        activate = nn.ReLU(inplace=True)

        # Registered under ``up`` so parameter names stay ``up.<index>.*``.
        self.up = nn.Sequential(enlarge, project, normalize, activate)

    def forward(self, x):
        """Return the upsampled and re-projected version of ``x``."""
        return self.up(x)
class Attention_block(nn.Module):
    """Additive attention gate that re-weights a skip connection.

    The gating signal ``g`` and the skip feature map ``x`` are each projected
    to ``F_int`` channels with a 1x1 convolution plus batch norm, summed,
    passed through ReLU, and reduced to a single-channel map by another 1x1
    convolution, batch norm and sigmoid. That map (values in [0, 1]) is
    multiplied onto ``x``.

    Args:
        F_g: Number of channels of the gating signal ``g``.
        F_l: Number of channels of the skip feature map ``x``.
        F_int: Number of channels of the intermediate projection.
    """

    def __init__(self, F_g, F_l, F_int):
        super().__init__()

        # W_g is applied to ``g`` in forward(), so its input width must be F_g.
        # (Previously F_g and F_l were swapped here, which only worked when
        # both happened to be equal.)
        self.W_g = nn.Sequential(
            nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int),
        )

        # W_x is applied to ``x`` in forward(), so its input width must be F_l.
        self.W_x = nn.Sequential(
            nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int),
        )

        # Collapse the joint features to one attention coefficient per pixel.
        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(1),
            nn.Sigmoid(),
        )

        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        """Return ``x`` scaled by the attention coefficients derived from ``g`` and ``x``.

        ``g`` and ``x`` are summed after projection, so they must have the
        same batch and spatial dimensions. The single-channel attention map
        is broadcast across the channels of ``x``.
        """
        g1 = self.W_g(g)
        x1 = self.W_x(x)
        psi = self.relu(g1 + x1)
        psi = self.psi(psi)
        out = x * psi
        return out
class UNet_Attention(nn.Module):
    """U-Net with attention gates on every skip connection.

    The encoder has five ``conv_block`` levels (64, 128, 256, 512 and 1024
    channels) separated by 2x2 max pooling. Each of the four decoder stages
    upsamples with ``up_conv``, gates the matching encoder output through an
    ``Attention_block``, concatenates the gated skip with the upsampled map
    along the channel axis, and fuses them with a ``conv_block``. A final 1x1
    convolution maps to ``output_ch`` channels; no activation is applied to
    the output.

    Args:
        img_ch: Number of channels of the input image.
        output_ch: Number of channels of the returned map.
    """

    def __init__(self, img_ch=3, output_ch=1):
        super().__init__()

        base = 64
        # Channel widths per level: 64, 128, 256, 512, 1024.
        c1, c2, c3, c4, c5 = (base * 2 ** level for level in range(5))

        self.Maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder.
        self.Conv1 = conv_block(img_ch, c1)
        self.Conv2 = conv_block(c1, c2)
        self.Conv3 = conv_block(c2, c3)
        self.Conv4 = conv_block(c3, c4)
        self.Conv5 = conv_block(c4, c5)

        # Decoder: each stage is (upsample, attention gate, fusing conv block).
        # The fusing block takes twice the stage width because the gated skip
        # and the upsampled map are concatenated.
        self.Up5 = up_conv(c5, c4)
        self.Att5 = Attention_block(F_g=c4, F_l=c4, F_int=c3)
        self.Up_conv5 = conv_block(c5, c4)

        self.Up4 = up_conv(c4, c3)
        self.Att4 = Attention_block(F_g=c3, F_l=c3, F_int=c2)
        self.Up_conv4 = conv_block(c4, c3)

        self.Up3 = up_conv(c3, c2)
        self.Att3 = Attention_block(F_g=c2, F_l=c2, F_int=c1)
        self.Up_conv3 = conv_block(c3, c2)

        self.Up2 = up_conv(c2, c1)
        self.Att2 = Attention_block(F_g=c1, F_l=c1, F_int=32)
        self.Up_conv2 = conv_block(c2, c1)

        self.Conv = nn.Conv2d(c1, output_ch, kernel_size=1, stride=1, padding=0)

    def _decode(self, upsample, gate, fuse, coarse, skip):
        """Run one decoder stage: upsample ``coarse``, gate ``skip``, concatenate, fuse."""
        upsampled = upsample(coarse)
        gated_skip = gate(g=upsampled, x=skip)
        merged = torch.cat((gated_skip, upsampled), dim=1)
        return fuse(merged)

    def forward(self, x):
        """Return the ``output_ch``-channel map computed from image batch ``x``.

        NOTE(review): the concatenations presumably require the height and
        width of ``x`` to be multiples of 16 (four poolings followed by four
        2x upsamplings) -- confirm against callers.
        """
        # Contracting path.
        enc1 = self.Conv1(x)
        enc2 = self.Conv2(self.Maxpool1(enc1))
        enc3 = self.Conv3(self.Maxpool2(enc2))
        enc4 = self.Conv4(self.Maxpool3(enc3))
        enc5 = self.Conv5(self.Maxpool4(enc4))

        # Expanding path, deepest level first.
        dec = self._decode(self.Up5, self.Att5, self.Up_conv5, enc5, enc4)
        dec = self._decode(self.Up4, self.Att4, self.Up_conv4, dec, enc3)
        dec = self._decode(self.Up3, self.Att3, self.Up_conv3, dec, enc2)
        dec = self._decode(self.Up2, self.Att2, self.Up_conv2, dec, enc1)

        return self.Conv(dec)
Hello @jinglescode,
thanks for your great article and this Gist. It really helped me to understand the original paper. Under which license do you release this Gist?
Would it be okay for you if others use/modify/rewrite it as part of a new open-source project if they mention your work and the original authors of the paper? Thanks and best regards
Hi @maweil! Thanks for reading those articles, I'm glad it helps in some way. If there is anything in there to improve in clarity, please let me know. Feel free to use/modify/rewrite the code for other projects. Definitely do cite the original authors and their paper. I'll be glad to be mentioned in your work, but it's optional.
Hi @jinglescode! Thanks a lot for your answer and your permission to build upon your code. The article is a very nice read and I currently can't think of things that need clarification. If something comes to my mind, I'll let you know.
I'll of course mention the original authors and you when publishing my code.
Best regards
Hello @jinglescode,
thanks for your great article and this Gist. It really helped me to understand the original paper.
Under which license do you release this Gist?
Would it be okay for you if others use/modify/rewrite it as part of a new open source project if they mention your work and the original authors of the paper?
Thanks and best regards