@mihaidusmanu
Created March 29, 2022 12:31
import torch
import torch.nn as nn
import torch.nn.functional as F

class GeM(nn.Module):
    # Generalized-mean (GeM) pooling, based on https://arxiv.org/abs/1711.02512:
    # GeM(x) = mean(x ** p) ** (1 / p), with p fixed at construction here.
    def __init__(self, p=3, eps=1e-6):
        super().__init__()
        self.p = nn.parameter.Parameter(torch.ones(1) * p)
        self.p.requires_grad = False
        self.eps = eps

    def forward(self, x):
        x = x.clamp(min=self.eps).pow(self.p)
        dims = tuple(range(2, x.dim()))  # reduce over all dims but batch and channels
        return x.mean(dims).pow(1. / self.p)
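
# A minimal shape check for GeM (a sketch; the 2x128x500 sizes below are
# made up, not from the original gist): a BxCxN input pools down to BxC.
if __name__ == '__main__':
    x = torch.rand(2, 128, 500)
    pool = GeM(p=3)
    assert pool(x).shape == (2, 128)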

class ArcClassifier(nn.Module):
    # ArcFace classification head, based on https://arxiv.org/abs/1801.07698.
    def __init__(self, dim, num_classes, margin=0.1, gamma=1.0,
                 trainable_gamma=True, eps=1e-7):
        super().__init__()
        self.weight = nn.parameter.Parameter(torch.empty([num_classes, dim]))
        nn.init.xavier_uniform_(self.weight)
        self.margin = margin
        self.eps = eps
        self.gamma = nn.parameter.Parameter(torch.ones(1) * gamma)
        if not trainable_gamma:
            self.gamma.requires_grad = False

    def forward(self, x, labels):
        # x is expected to be L2-normalized, so this computes cosine
        # similarities against the normalized class weights.
        raw_logits = F.linear(x, F.normalize(self.weight))
        theta = torch.acos(raw_logits.clamp(-1 + self.eps, 1 - self.eps))
        # Only apply the margin if it lowers the logit.
        marginal_target_logits = torch.min(
            torch.cos(theta + self.margin), raw_logits)
        one_hot = F.one_hot(labels, num_classes=raw_logits.size(1)).bool()
        final_logits = torch.where(one_hot, marginal_target_logits, raw_logits)
        final_logits *= self.gamma
        return final_logits
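
# A quick sketch of the ArcFace head on random data (all dimensions here are
# assumptions): descriptors are L2-normalized first, and the margin only
# ever lowers the true-class logit relative to the plain cosine logit.
if __name__ == '__main__':
    clf = ArcClassifier(dim=128, num_classes=10)
    desc = F.normalize(torch.rand(4, 128))
    labels = torch.randint(0, 10, (4,))
    logits = clf(desc, labels)
    assert logits.shape == (4, 10)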

class Network(nn.Module):
    def __init__(self, output_dim, num_classes):
        super().__init__()
        # Define the GNN architecture here.
        self.gnn = nn.Identity()
        # Aggregation layer - generalized-mean pooling.
        self.pool = GeM()
        # Classification head.
        self.classification_head = ArcClassifier(output_dim, num_classes)

    def forward(self, batch):
        pred = {}
        # Define the GNN forward pass here.
        gnn_output = self.gnn(batch)
        # GeM pooling of the GNN output.
        # Input should be BxCxN.
        # B - batch size, C - descriptor size, N - number of nodes.
        # Output is BxC.
        global_desc = self.pool(gnn_output)
        # L2-normalize global descriptors.
        global_desc = F.normalize(global_desc)
        pred['global_descriptor'] = global_desc
        if 'label' in batch:
            # ArcFace classifier.
            logits = self.classification_head(global_desc, batch['label'])
            pred['logits'] = logits
        return pred

def classification_loss(net, batch):
    # Cross-entropy over the ArcFace logits of a `Network` forward pass.
    pred = net(batch)
    loss = F.cross_entropy(pred['logits'], batch['label'])
    return loss
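
# End-to-end sketch: the toy GNN and random batch below are hypothetical
# stand-ins for a real graph model and data pipeline, which this gist
# deliberately leaves unspecified.
class ToyGNN(nn.Module):
    def forward(self, batch):
        # Pretend the batch already carries BxCxN node features.
        return batch['node_features']

if __name__ == '__main__':
    net = Network(output_dim=128, num_classes=10)
    net.gnn = ToyGNN()
    batch = {'node_features': torch.rand(4, 128, 500),
             'label': torch.randint(0, 10, (4,))}
    loss = classification_loss(net, batch)
    loss.backward()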