Skip to content

Instantly share code, notes, and snippets.

@AlessandroMondin
Created July 14, 2022 21:38
Show Gist options
  • Save AlessandroMondin/1824a0ad92fb6b9bd9e7eee7a5ade7cc to your computer and use it in GitHub Desktop.
Save AlessandroMondin/1824a0ad92fb6b9bd9e7eee7a5ade7cc to your computer and use it in GitHub Desktop.
class CNNBlocks(nn.Module):
    """Stack of ``n_conv`` ``CNNBlock`` modules applied one after another.

    The first block maps ``in_channels`` to ``out_channels``; every
    following block maps ``out_channels`` to ``out_channels``.

    Parameters:
        n_conv (int): number of ``CNNBlock`` modules to create. When it is
            zero (or negative) no block is created and ``forward`` returns
            its input unchanged.
        in_channels (int): number of input channels of the first block.
        out_channels (int): number of output channels of every block.
        padding: forwarded unchanged to each ``CNNBlock`` as its ``padding``
            keyword argument.
    """

    def __init__(self, n_conv: int, in_channels: int, out_channels: int, padding):
        super().__init__()
        self.layers = nn.ModuleList()
        for _ in range(n_conv):
            self.layers.append(
                CNNBlock(in_channels, out_channels, padding=padding)
            )
            # The output of one block feeds the next, so from the second
            # block onwards the input width equals out_channels.
            in_channels = out_channels

    def forward(self, x):
        """Pass ``x`` through every block in creation order and return the result."""
        for layer in self.layers:
            x = layer(x)
        return x
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment