Skip to content

Instantly share code, notes, and snippets.

@AlessandroMondin
Last active December 10, 2022 14:45
Show Gist options
  • Save AlessandroMondin/a9ab6d2c8c0fd86c029dd8998c81e025 to your computer and use it in GitHub Desktop.
Save AlessandroMondin/a9ab6d2c8c0fd86c029dd8998c81e025 to your computer and use it in GitHub Desktop.
def forward(self, x):
    """Run ``x`` through the backbone, the neck and the head.

    Feature maps are carried between stages as skip connections:

    * the outputs of backbone layers 4 and 6 are saved (per the original
      author, the 2nd and 3rd C3 blocks). They are popped last-in-first-out
      and concatenated onto the 2x-upsampled outputs of neck layers 0 and 2.
      Neck layer 0 receives backbone layer 6; neck layer 2 receives backbone
      layer 4.
    * the outputs of neck layers 0 and 2 are saved *before* upsampling. They
      are popped last-in-first-out and concatenated onto the outputs of neck
      layers 4 and 6. Neck layer 4 receives neck layer 2; neck layer 6
      receives neck layer 0.
    * every other neck layer with index > 2 that is a ``C3`` contributes its
      output to the list handed to ``self.head``.

    All concatenations are along dim 1 (presumably the channel dimension of
    an ``(N, C, H, W)`` tensor -- TODO confirm against callers).

    Args:
        x: Input tensor. Dims 2 and 3 are treated as height and width and
            must both be divisible by 32.

    Returns:
        Whatever ``self.head`` returns when given the list of collected neck
        outputs.

    Raises:
        ValueError: If dim 2 or dim 3 of ``x`` is not divisible by 32.
    """
    # Explicit check rather than ``assert``: asserts are stripped under
    # ``python -O``. Without the check, a bad shape would only surface later
    # as an opaque size mismatch in ``torch.cat``.
    if x.shape[2] % 32 != 0 or x.shape[3] % 32 != 0:
        raise ValueError("Width and Height aren't divisible by 32!")

    backbone_connection = []
    neck_connection = []
    outputs = []

    for idx, layer in enumerate(self.backbone):
        x = layer(x)
        # Keep the outputs of the 2nd and 3rd C3 blocks as skip connections.
        if idx in (4, 6):
            backbone_connection.append(x)

    for idx, layer in enumerate(self.neck):
        # Every neck layer is applied first; the branches only decide what
        # happens to its output.
        x = layer(x)
        if idx in (0, 2):
            neck_connection.append(x)
            # 2x nearest-neighbour upsampling of the two spatial dims. This is
            # the same operation torchvision's Resize(NEAREST) performs on
            # tensors, without building a transform object on every call.
            x = torch.nn.functional.interpolate(
                x, size=(x.shape[2] * 2, x.shape[3] * 2), mode="nearest"
            )
            x = torch.cat([x, backbone_connection.pop()], dim=1)
        elif idx in (4, 6):
            x = torch.cat([x, neck_connection.pop()], dim=1)
        elif idx > 2 and isinstance(layer, C3):
            outputs.append(x)

    return self.head(outputs)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment