Skip to content

Instantly share code, notes, and snippets.

@AlessandroMondin
Last active July 20, 2022 12:56
Show Gist options
  • Save AlessandroMondin/fe2d7d5a814dff4a45072942958f6c7c to your computer and use it in GitHub Desktop.
Save AlessandroMondin/fe2d7d5a814dff4a45072942958f6c7c to your computer and use it in GitHub Desktop.
class UNET(nn.Module):
    """Model that chains an ``Encoder`` and a ``Decoder``.

    The input is passed through ``self.encoder``, which returns a pair
    ``(enc_out, routes)``; both values are then handed to ``self.decoder``,
    whose result is the output of the model.

    NOTE(review): ``Encoder`` and ``Decoder`` are defined outside this block.
    ``routes`` presumably carries the encoder's intermediate feature maps
    (skip connections) -- confirm against those classes.

    Args:
        in_channels: Forwarded to ``Encoder`` as its first positional argument.
        first_out_channels: Forwarded to ``Encoder`` as its second positional
            argument; also the base value from which the two channel counts
            given to ``Decoder`` are derived.
        exit_channels: Forwarded to ``Decoder`` as its third positional
            argument.
        downhill: Forwarded to ``Encoder`` as ``downhill`` and to ``Decoder``
            as ``uphill``, so both halves are configured with the same value.
            It is also the exponent in the channel arithmetic below; with
            ``downhill=0`` the factor ``2 ** (downhill - 1)`` evaluates to
            ``0.5``, so the second ``Decoder`` argument becomes a float.
        padding: Forwarded unchanged to both ``Encoder`` and ``Decoder``.
            Defaults to ``0``.
    """

    def __init__(self,
                 in_channels,
                 first_out_channels,
                 exit_channels,
                 downhill,
                 padding=0
                 ):
        super().__init__()
        self.encoder = Encoder(in_channels, first_out_channels, padding=padding, downhill=downhill)
        # The decoder gets first_out_channels scaled by 2**downhill and by
        # 2**(downhill - 1) as its first two arguments.
        # NOTE(review): presumably the encoder's final channel count and half
        # of it -- verify against Encoder/Decoder.
        self.decoder = Decoder(first_out_channels * (2 ** downhill),
                               first_out_channels * (2 ** (downhill - 1)),
                               exit_channels, padding=padding, uphill=downhill)

    def forward(self, x):
        """Run ``x`` through the encoder, then the decoder.

        Args:
            x: Input passed directly to ``self.encoder``.

        Returns:
            Whatever ``self.decoder(enc_out, routes)`` returns, where
            ``enc_out`` and ``routes`` are the two values produced by the
            encoder.
        """
        enc_out, routes = self.encoder(x)
        return self.decoder(enc_out, routes)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment