Skip to content

Instantly share code, notes, and snippets.

@AlessandroMondin
Created December 8, 2022 12:45
Show Gist options
  • Save AlessandroMondin/1e0fa411afe0aef00668ec7faf41b953 to your computer and use it in GitHub Desktop.
Save AlessandroMondin/1e0fa411afe0aef00668ec7faf41b953 to your computer and use it in GitHub Desktop.
class C3(nn.Module):
    """CSP-style "C3" block: two parallel 1x1 branches, one refined by a stack
    of ``depth`` sub-blocks, concatenated and fused by a final 1x1 ``CBL``.

    Parameters:
        in_channels (int): number of channels of the input tensor.
        out_channels (int): number of channels of the output tensor.
        width_multiple (float): scales the hidden width
            ``c_ = int(width_multiple * in_channels)``. ``c_`` is the output
            width of the two entry 1x1 convolutions (``c1`` and
            ``c_skipped``), the width of every sub-block in ``self.seq``, and,
            as ``2 * c_``, the input width of the final 1x1 convolution.
            The closer to 0, the simpler the model; the closer to 1, the more
            complex the model.
        depth (int): number of times the sub-block (bottleneck) is repeated
            within the C3 block.
        backbone (bool): if True, ``self.seq`` is built from ``Bottleneck``
            (residual) blocks; if False, each repeat is a plain 1x1 -> 3x3
            ``CBL`` pair with no skip connection (see the image linked below).
            https://user-images.githubusercontent.com/31005897/172404576-c260dcf9-76bb-4bc8-b6a9-f2d987792583.png
    """

    def __init__(self, in_channels, out_channels, width_multiple=1, depth=1, backbone=True):
        super().__init__()
        c_ = int(width_multiple * in_channels)
        # Entry branch whose output is refined by self.seq.
        self.c1 = CBL(in_channels, c_, kernel_size=1, stride=1, padding=0)
        # Parallel branch that bypasses self.seq and is concatenated at the end.
        self.c_skipped = CBL(in_channels, c_, kernel_size=1, stride=1, padding=0)
        if backbone:
            self.seq = nn.Sequential(
                *[Bottleneck(c_, c_, width_multiple=1) for _ in range(depth)]
            )
        else:
            # Positional args presumably follow the keyword calls above:
            # (in, out, kernel_size, stride, padding) -> a 1x1 followed by a
            # padded 3x3, with no residual connection.
            self.seq = nn.Sequential(
                *[
                    nn.Sequential(
                        CBL(c_, c_, 1, 1, 0),
                        CBL(c_, c_, 3, 1, 1),
                    )
                    for _ in range(depth)
                ]
            )
        # Each branch contributes c_ channels, hence 2 * c_ at the fusion conv.
        self.c_out = CBL(c_ * 2, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        """Run both branches on ``x``, concatenate them along dim 1, then fuse.

        ``dim=1`` is presumably the channel axis (NCHW input) — TODO confirm.
        """
        refined = self.seq(self.c1(x))
        bypass = self.c_skipped(x)
        return self.c_out(torch.cat([refined, bypass], dim=1))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment