
@AlessandroMondin
Last active March 18, 2023 11:48
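import torch
import torch.nn as nn
# Assumption: center_crop is torchvision's; it takes an int output size and
# crops the last two (H, W) dimensions of a tensor.
from torchvision.transforms.functional import center_crop

# CNNBlocks (n_conv convolutions, each followed by ReLU) comes from an earlier
# part of the tutorial and is not defined in this gist.


class Decoder(nn.Module):
    """
    Parameters:
        in_channels (int): number of in_channels of the first ConvTranspose2d
        out_channels (int): number of out_channels of the first ConvTranspose2d
        exit_channels (int): number of output channels of the final 1x1 Conv2d
        padding (int): padding applied in each convolution of the CNNBlocks
        uphill (int): number of times a ConvTranspose2d + CNNBlocks pair is applied
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 exit_channels,
                 padding,
                 uphill=4):
        super().__init__()
        self.exit_channels = exit_channels
        self.layers = nn.ModuleList()
        for _ in range(uphill):
            self.layers += [
                # upsampling: halves the channels, doubles H and W
                nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2),
                # its input is the upsampled x concatenated with the skip route
                CNNBlocks(n_conv=2, in_channels=in_channels,
                          out_channels=out_channels, padding=padding),
            ]
            in_channels //= 2
            out_channels //= 2
        # This cannot be a CNNBlocks because that block includes a ReLU.
        # nn.Sigmoid is not appended either: a sigmoid output would call for
        # nn.BCELoss, which raises the AMP error "... are unsafe to autocast".
        # Return logits and train with nn.BCEWithLogitsLoss instead.
        # A 1x1 kernel needs no padding; padding > 0 would enlarge the output.
        self.layers.append(
            nn.Conv2d(in_channels, exit_channels, kernel_size=1, padding=0),
        )

    def forward(self, x, routes_connection):
        # work on a shallow copy so the caller's list is left intact
        routes_connection = list(routes_connection)
        # pop the last element of the list since
        # it's not used for concatenation
        routes_connection.pop(-1)
        for layer in self.layers:
            if isinstance(layer, CNNBlocks):
                # center-crop the route tensor so its height and width match x
                routes_connection[-1] = center_crop(routes_connection[-1], x.shape[2])
                # concatenate along the channel dimension
                x = torch.cat([x, routes_connection.pop(-1)], dim=1)
            x = layer(x)
        return x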
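To make the channel bookkeeping in `__init__` concrete, here is a torch-free sketch that lists the layers the loop would build. `decoder_layer_plan` is a hypothetical helper, strings stand in for the real modules, and `CNNBlocks` is only named, not implemented. The values 1024/512 in the usage line assume the classic UNet configuration.

```python
def decoder_layer_plan(in_channels, out_channels, exit_channels, uphill=4):
    """Hypothetical helper: the (layer, in, out) triples Decoder.__init__ builds.

    Strings stand in for the real modules, so no torch is needed.
    """
    plan = []
    for _ in range(uphill):
        # upsampling step: in_channels -> out_channels
        plan.append(("ConvTranspose2d", in_channels, out_channels))
        # x (out_channels) is concatenated with a skip route before CNNBlocks;
        # the block's in_channels only matches if the route also carries
        # out_channels channels, i.e. in_channels == 2 * out_channels
        plan.append(("CNNBlocks", in_channels, out_channels))
        in_channels //= 2
        out_channels //= 2
    # final 1x1 head mapping the last feature maps to exit_channels
    plan.append(("Conv2d", in_channels, exit_channels))
    return plan


for step in decoder_layer_plan(1024, 512, exit_channels=2):
    print(step)
```

With `in_channels=1024` and `out_channels=512`, the four uphill steps end with the 1x1 head receiving 64 channels; `exit_channels=2` is only an example value.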
@shahpnmlab

Hey! Thanks for writing this tutorial on Medium. I found it more informative than any other tutorial on the same topic. Just a clarification question, though: wouldn't the Decoder block also need a forward method defined on it?

@AlessandroMondin
Author

AlessandroMondin commented Mar 18, 2023

Hi @shahpnmlab! Sure it does. I'm not sure whether I left it out to keep the gist more "compact" on Medium or whether it was a typo.
Anyway, I have updated it now 😀
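Since the thread is about `forward`, here is a torch-free sketch of how it consumes `routes_connection`, tracking each tensor as a `(channels, size)` pair. `simulate_forward` is a hypothetical helper, and it assumes that the encoder appends its bottleneck output as the last route, that each `ConvTranspose2d` halves the channels and doubles H and W, and that `CNNBlocks` applies two 3x3 convolutions. The route sizes in the usage line are those of the unpadded UNet with a 572x572 input.

```python
def simulate_forward(x, routes, padding=0):
    """Replay Decoder.forward's route bookkeeping on (channels, size) pairs.

    Hypothetical helper: x is the bottleneck as (channels, size) and routes
    lists the encoder outputs shallowest first, bottleneck last (assumed).
    Returns the final (channels, size) and one (route_size, target, offset)
    entry per skip connection.
    """
    channels, size = x
    routes = list(routes)  # copy so the caller's list is left intact
    routes.pop(-1)         # bottleneck entry: already the decoder input x
    crops = []
    while routes:
        # ConvTranspose2d(kernel_size=2, stride=2): half the channels, double H/W
        channels, size = channels // 2, size * 2
        route_channels, route_size = routes.pop(-1)  # deepest remaining skip
        offset = (route_size - size) // 2            # top/left offset of a centered crop
        crops.append((route_size, size, offset))
        channels += route_channels                   # torch.cat(..., dim=1)
        # CNNBlocks, assumed to be two 3x3 convs: halve the channels,
        # and each conv changes H/W by 2 * padding - 2
        channels //= 2
        size += 2 * (2 * padding - 2)
    return (channels, size), crops


encoder_routes = [(64, 568), (128, 280), (256, 136), (512, 64), (1024, 28)]
print(simulate_forward((1024, 28), encoder_routes))
```

With `padding=0` the routes are cropped 64→56, 136→104, 280→200 and 568→392, and the decoder ends at 64 channels and 388x388, the output size of the original UNet before its final 1x1 convolution. With `padding=1` and matching route sizes, every crop offset is 0.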
