import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision


class ChannelAttentionGate(nn.Module):
    # Channel attention (squeeze-and-excitation): global average pooling
    # followed by a bottleneck MLP that yields one gate per channel.
    def __init__(self, channel, reduction=16):
        super(ChannelAttentionGate, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)   # squeeze: (B, C)
        y = self.fc(y).view(b, c, 1, 1)   # excite: per-channel weights in [0, 1]
        return y
class SpatialAttentionGate(nn.Module):
    def __init__(self, channel, reduction=16):
        super(SpatialAttentionGate, self).__init__()
        self.fc1 = nn.Conv2d(channel, reduction, kernel_size=1, padding=0)
        self.fc2 = nn.Conv2d(reduction, 1, kernel_size=1, padding=0)

    def forward(self, x):
        x = self.fc1(x)
        x = F.relu(x, inplace=True)
        x = self.fc2(x)
        x = torch.sigmoid(x)
        return x
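
# ConvBn2d is used by Decoder below but is not defined in this gist.
# A minimal sketch is assumed here: the usual Conv2d + BatchNorm2d pairing,
# with the conv bias dropped since the following BatchNorm makes it redundant.
class ConvBn2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, padding=1, stride=1):
        super(ConvBn2d, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size,
                              stride=stride, padding=padding, bias=False)
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        return self.bn(self.conv(x))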
class Decoder(nn.Module):
    def __init__(self, in_channels, channels, out_channels):
        super(Decoder, self).__init__()
        self.conv1 = ConvBn2d(in_channels, channels, kernel_size=3, padding=1)
        self.conv2 = ConvBn2d(channels, out_channels, kernel_size=3, padding=1)
        self.cg = ChannelAttentionGate(out_channels)
        self.sg = SpatialAttentionGate(out_channels)

    def forward(self, x):
        # F.upsample is deprecated; F.interpolate is the drop-in replacement
        x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True)
        x = F.relu(self.conv1(x), inplace=True)
        x = F.relu(self.conv2(x), inplace=True)
        g1 = self.sg(x)        # spatial gate: (B, 1, H, W)
        g2 = self.cg(x)        # channel gate: (B, C, 1, 1)
        x = g1 * x + g2 * x    # concurrent spatial and channel attention (scSE)
        return x
class UNetScseHypercol(nn.Module):
    def __init__(self):
        super().__init__()
        self.resnet = torchvision.models.resnet34(pretrained=True)
        self.conv1 = nn.Sequential(
            self.resnet.conv1,
            self.resnet.bn1,
            self.resnet.relu,
        )  # 64
        self.encoder2 = self.resnet.layer1  # 64
        self.encoder3 = self.resnet.layer2  # 128
        self.encoder4 = self.resnet.layer3  # 256
        self.encoder5 = self.resnet.layer4  # 512

        self.center = nn.Sequential(
            nn.Conv2d(512, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )

        self.decoder5 = Decoder(512 + 64, 512, 64)
        self.decoder4 = Decoder(64 + 256, 256, 64)
        self.decoder3 = Decoder(64 + 128, 128, 64)
        self.decoder2 = Decoder(64 + 64, 64, 64)

        # 4 decoder outputs x 64 channels each = 256-channel hypercolumn
        self.logit = nn.Sequential(
            nn.Conv2d(256, 32, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 1, kernel_size=1, padding=0),
        )
    def forward(self, x):
        x = self.conv1(x)        # (B,  64, H/2,  W/2)
        e2 = self.encoder2(x)    # (B,  64, H/2,  W/2)
        e3 = self.encoder3(e2)   # (B, 128, H/4,  W/4)
        e4 = self.encoder4(e3)   # (B, 256, H/8,  W/8)
        e5 = self.encoder5(e4)   # (B, 512, H/16, W/16)

        f = self.center(e5)      # (B, 64, H/16, W/16)

        d5 = self.decoder5(torch.cat([f, e5], 1))   # (B, 64, H/8, W/8)
        d4 = self.decoder4(torch.cat([d5, e4], 1))  # (B, 64, H/4, W/4)
        d3 = self.decoder3(torch.cat([d4, e3], 1))  # (B, 64, H/2, W/2)
        d2 = self.decoder2(torch.cat([d3, e2], 1))  # (B, 64, H,   W)

        # hypercolumn: bring every decoder output to full resolution
        # and concatenate before the classification head
        f = torch.cat((
            d2,
            F.interpolate(d3, scale_factor=2, mode='bilinear', align_corners=False),
            F.interpolate(d4, scale_factor=4, mode='bilinear', align_corners=False),
            F.interpolate(d5, scale_factor=8, mode='bilinear', align_corners=False),
        ), 1)
        logit = self.logit(f)
        return logit


model = UNetScseHypercol()
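
# A minimal smoke-test sketch (assumptions: 3-channel RGB input, since the
# ResNet-34 stem expects it, and spatial dims divisible by 16 so the
# encoder/decoder scales line up).
if __name__ == '__main__':
    model.eval()
    dummy = torch.randn(2, 3, 128, 128)
    with torch.no_grad():
        out = model(dummy)
    print(out.shape)  # expected: torch.Size([2, 1, 128, 128])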