# Created March 29, 2022 12:31.
# Gist: mihaidusmanu/440cb01e80b38449a3b1e663d121d5a8 (can be saved locally or
# opened in GitHub Desktop).
# Note: this file may contain hidden or bidirectional Unicode text that could be
# interpreted or compiled differently than it appears. To review it, open the
# file in an editor that reveals hidden Unicode characters; see GitHub's
# documentation on bidirectional Unicode characters to learn more.
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
class GeM(nn.Module):
    """Generalized-mean (GeM) pooling over all axes after batch and channels.

    Based on https://arxiv.org/abs/1711.02512. For an input of shape
    ``(B, C, *rest)`` the output has shape ``(B, C)``; each entry is
    ``mean(clamp(x, min=eps) ** p) ** (1 / p)`` over the ``rest`` axes.
    ``p = 1`` reduces to average pooling of the clamped input.

    Args:
        p: Pooling exponent. Stored as a frozen ``nn.Parameter`` (so it is
            part of the state dict and follows ``.to(device)``) but it does
            not receive gradients.
        eps: Lower clamp applied before ``pow`` so the power is only ever
            taken of strictly positive values.
    """

    def __init__(self, p=3, eps=1e-6):
        super().__init__()
        self.p = nn.parameter.Parameter(torch.ones(1) * p, requires_grad=False)
        self.eps = eps

    def forward(self, x):
        """Pool ``x`` of shape ``(B, C, *rest)`` down to shape ``(B, C)``.

        Raises:
            ValueError: if ``x`` has fewer than 3 dimensions, i.e. there is no
                axis left to pool over.
        """
        if x.dim() < 3:
            raise ValueError(
                "GeM expects input of shape (B, C, ...) with at least one "
                f"pooled axis, got shape {tuple(x.shape)}"
            )
        # Reduce every axis after batch (0) and channels (1).
        dims = tuple(range(2, x.dim()))
        x = x.clamp(min=self.eps).pow(self.p)
        return x.mean(dims).pow(1.0 / self.p)
class ArcClassifier(nn.Module):
    """ArcFace (additive angular margin) classification head.

    Based on https://arxiv.org/abs/1801.07698. Computes logits between the
    input descriptors and the L2-normalized class weights. It adds an angular
    ``margin`` to the target class only, then scales all logits by ``gamma``.

    Args:
        dim: Descriptor size (number of columns of the weight matrix).
        num_classes: Number of classes (number of rows of the weight matrix).
        margin: Additive angle (radians) applied to the target-class logit.
        gamma: Initial logit scale.
        trainable_gamma: Whether ``gamma`` receives gradients.
        eps: Keeps cosines strictly inside (-1, 1) before ``acos``. The
            derivative of ``acos`` is infinite at +-1.
    """

    def __init__(self, dim, num_classes, margin=0.1, gamma=1.0,
                 trainable_gamma=True, eps=1e-7):
        super().__init__()
        self.weight = nn.parameter.Parameter(torch.empty([num_classes, dim]))
        nn.init.xavier_uniform_(self.weight)
        self.margin = margin
        self.eps = eps
        # Assigning a Parameter attribute registers it with the module, so no
        # separate register_parameter call is needed.
        self.gamma = nn.parameter.Parameter(
            torch.ones(1) * gamma, requires_grad=trainable_gamma
        )

    def forward(self, x, labels):
        """Return margin-adjusted, scaled logits of shape ``(B, num_classes)``.

        Args:
            x: Descriptors of shape ``(B, dim)``. The head does not normalize
                ``x``. Callers are expected to pass L2-normalized descriptors,
                so that the raw logits are cosine similarities.
            labels: Integer class indices of shape ``(B,)``.
        """
        raw_logits = F.linear(x, F.normalize(self.weight))
        theta = torch.acos(raw_logits.clamp(-1 + self.eps, 1 - self.eps))
        # Only apply the margin if it lowers the logit. Once theta + margin
        # passes pi, cos starts increasing again.
        marginal_target_logits = torch.min(
            torch.cos(theta + self.margin), raw_logits
        )
        one_hot = F.one_hot(labels, num_classes=raw_logits.size(1)).bool()
        final_logits = torch.where(one_hot, marginal_target_logits, raw_logits)
        # Out-of-place scaling avoids mutating a tensor in the autograd graph.
        return final_logits * self.gamma
class Network(nn.Module):
    """Template network: GNN -> GeM pooling -> L2 norm -> ArcFace head.

    The GNN is a placeholder (``nn.Identity``) that must be replaced with a
    real architecture. ``forward`` expects the GNN output to have shape
    ``(B, C, N)`` (batch size, descriptor size, number of nodes), with
    ``C == output_dim``, because the head's weight matrix has ``output_dim``
    columns.

    Args:
        output_dim: Descriptor size ``C`` fed to the classification head.
        num_classes: Number of classes for the ArcFace head.
    """

    def __init__(self, output_dim=None, num_classes=None):
        super().__init__()
        # Submodules are built only when both sizes are given. This keeps the
        # two-step form `Network()` followed by `_init(output_dim, num_classes)`
        # working for existing callers.
        if output_dim is not None and num_classes is not None:
            self._init(output_dim, num_classes)

    def _init(self, output_dim, num_classes):
        """Create the GNN, pooling and classification submodules."""
        # Define the GNN architecture here.
        self.gnn = nn.Identity()
        # Aggregation layer - generalized mean pooling.
        self.pool = GeM()
        # Classification head.
        self.classification_head = ArcClassifier(output_dim, num_classes)

    def forward(self, batch):
        """Compute global descriptors and, when labels are present, logits.

        Args:
            batch: Input passed to ``self.gnn``. It must also support
                ``'label' in batch`` and ``batch['label']`` (for example, a
                dict). NOTE(review): with the placeholder ``nn.Identity`` GNN
                the batch itself would be pooled, so replace the GNN before
                real use.

        Returns:
            dict with ``'global_descriptor'`` (shape ``(B, C)``, L2-normalized)
            and, if ``batch`` has a ``'label'`` entry, ``'logits'``.
        """
        pred = {}
        # Define the GNN forward pass here.
        gnn_output = self.gnn(batch)
        # Generalized-mean pooling of the GNN output: (B, C, N) -> (B, C).
        global_desc = self.pool(gnn_output)
        # L2-normalize global descriptors.
        global_desc = F.normalize(global_desc)
        pred['global_descriptor'] = global_desc
        if 'label' in batch:
            # ArcFace classifier.
            pred['logits'] = self.classification_head(global_desc, batch['label'])
        return pred

    # Alias kept for callers that use the old method name.
    _forward = forward
def classification_loss(self, batch):
    """Return the mean cross-entropy between ``self.net``'s logits and labels.

    The body reads ``self.net``, so this takes ``self``. It is meant to be used
    as a method of a training wrapper that owns the model as ``self.net``.

    Args:
        batch: Model input. Must contain ``'label'``, which the network needs
            in order to emit ``'logits'``.

    Returns:
        Scalar tensor with the cross-entropy loss.

    Raises:
        KeyError: if ``batch`` has no ``'label'`` entry.
    """
    if 'label' not in batch:
        raise KeyError("classification_loss requires batch['label']")
    pred = self.net(batch)
    return F.cross_entropy(pred['logits'], batch['label'])
# Sign up for free to join this conversation on GitHub. Already have an
# account? Sign in to comment.