@calebrob6
Created March 13, 2025 03:26
A subclass of torchgeo's RCF (random convolutional features) model that averages each feature map over the pixels selected by a binary mask, instead of over the whole input.
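import torch
import torch.nn.functional as F
from torchgeo.models import RCF


class RCFSegmentationFeatures(RCF):
    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Compute RCF features averaged over the pixels where ``y == 1``.

        Args:
            x: a tensor with shape (C, H, W)
            y: a binary mask tensor with shape (H, W)

        Returns:
            a tensor of size (``self.num_features``); the values are NaN if
            no pixel of the cropped mask equals 1
        """
        # Positive and negative ReLU responses of the random filters. The
        # convolution is unpadded, so the output is smaller than the input.
        x1a = F.relu(
            F.conv2d(x, self.weights, bias=self.biases, stride=1, padding=0),
            inplace=True,
        )
        x1b = F.relu(
            -F.conv2d(x, self.weights, bias=self.biases, stride=1, padding=0),
            inplace=False,
        )

        # Crop the mask to the output grid. Slicing by the output size, rather
        # than y[padding:-padding], also works for a 1x1 kernel (padding == 0).
        padding = self.weights.shape[-1] // 2
        out_h, out_w = x1a.shape[-2:]
        y = y[padding : padding + out_h, padding : padding + out_w]

        # Average each feature map over the masked pixels only.
        mask = y == 1
        x1a = x1a[:, mask].mean(dim=1)
        x1b = x1b[:, mask].mean(dim=1)
        return torch.cat((x1a, x1b), dim=0)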
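The masked averaging in the class above can also be sketched without torchgeo or PyTorch, which makes the shapes and the mask-cropping alignment easy to inspect. This is a minimal, hypothetical NumPy re-implementation for a single image: the helper names `valid_cross_correlation` and `masked_rcf_features` are illustrative and not part of torchgeo, the filters are plain Gaussian noise, the bias is an arbitrary constant, and torchgeo's own filter initialization may differ.

```python
import numpy as np


def valid_cross_correlation(x: np.ndarray, w: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Unpadded, stride-1 multi-channel cross-correlation (what F.conv2d computes).

    x: (C, H, W) image, w: (F, C, k, k) filters, b: (F,) biases.
    Returns an array of shape (F, H - k + 1, W - k + 1).
    """
    k = w.shape[-1]
    out_h, out_w = x.shape[1] - k + 1, x.shape[2] - k + 1
    out = np.empty((w.shape[0], out_h, out_w))
    for i in range(out_h):
        for j in range(out_w):
            patch = x[:, i : i + k, j : j + k]  # (C, k, k) window
            # Dot every filter with the window at once, then add its bias.
            out[:, i, j] = np.tensordot(w, patch, axes=([1, 2, 3], [0, 1, 2])) + b
    return out


def masked_rcf_features(x: np.ndarray, y: np.ndarray, w: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Average the positive and negative ReLU responses over pixels where y == 1."""
    z = valid_cross_correlation(x, w, b)
    pos = np.maximum(z, 0.0)
    neg = np.maximum(-z, 0.0)
    # Crop the mask to the output grid: for an odd kernel k = 2p + 1, output
    # pixel (i, j) is centred on input pixel (i + p, j + p).
    p = w.shape[-1] // 2
    mask = y[p : p + z.shape[1], p : p + z.shape[2]] == 1
    return np.concatenate([pos[:, mask].mean(axis=1), neg[:, mask].mean(axis=1)])


# Usage with made-up data: a 4-band 8x8 image, 8 random 3x3 filters, and a
# square mask. The bias value -1.0 is an arbitrary illustrative choice.
rng = np.random.default_rng(0)
x = rng.normal(size=(4, 8, 8))
w = rng.normal(size=(8, 4, 3, 3))
b = np.full(8, -1.0)
y = np.zeros((8, 8), dtype=int)
y[2:6, 2:6] = 1
feats = masked_rcf_features(x, y, w, b)
print(feats.shape)  # (16,)
```

Passing an all-ones mask averages over the entire output grid, which corresponds to the whole-input average that the stock RCF forward pass computes; a mask with no selected pixels after cropping yields NaN, as in the PyTorch version.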