Skip to content

Instantly share code, notes, and snippets.

@masuidrive
Created May 28, 2024 13:08
Show Gist options
  • Save masuidrive/d8e029e479c262f2864b79ae17ca75be to your computer and use it in GitHub Desktop.
Save masuidrive/d8e029e479c262f2864b79ae17ca75be to your computer and use it in GitHub Desktop.
MagFaceの実装例 動いているけどアルゴリズム的にあっているかはわからないw (MagFace implementation example: it runs, but I'm not sure whether it is algorithmically correct.)
import numpy as np
import torch
import torch.nn.functional as F
from pytorch_metric_learning.utils import common_functions as c_f
from pytorch_metric_learning.losses.large_margin_softmax_loss import LargeMarginSoftmaxLoss
class MagFaceLoss(LargeMarginSoftmaxLoss):
    """MagFace-style additive angular margin loss (an extension of ArcFace).

    References:
        MagFace: https://arxiv.org/abs/2103.06627
        ArcFace (which this builds on): https://arxiv.org/abs/1801.07698

    For every sample, the target-class cosine ``cos(theta)`` is modified as
    follows:

    1. It is replaced by ``cos(theta + m * ||f||)``, where ``||f||`` is the
       L2 norm of the embedding. When ``theta > pi - m * ||f||``, the
       ArcFace fallback ``cos(theta) - m*||f|| * sin(m*||f||)`` is used
       instead.
    2. It is shifted down by ``margin_a * (||f|| - margin_s)``.

    Afterwards, all logits are multiplied by ``scale`` and fed to
    cross-entropy.

    NOTE(review): this is a simplified variant, not a faithful reproduction
    of the paper. In the paper:

    - the magnitude is bounded to ``[l_a, u_a]``;
    - the margin ``m(a)`` interpolates linearly between ``l_m`` and ``u_m``;
    - a separate magnitude regularizer ``g(a)`` is added to the loss.

    Here the margin is the unclipped product ``margin * ||f||``, so for
    large magnitudes it can exceed pi. The regularizer is replaced by a
    linear shift of the target logit. Verify against the paper before
    relying on it.

    Args:
        num_classes: Number of classes.
        embedding_size: Dimensionality of the embeddings.
        margin: Angular margin per unit of embedding magnitude, in DEGREES.
            ``init_margin`` converts it to radians; this follows the
            ``ArcFaceLoss`` convention of pytorch-metric-learning.
        scale: Multiplier applied to all logits before cross-entropy.
        margin_a: Slope of the magnitude-dependent shift of the target logit.
        margin_s: Magnitude at which that shift is zero.
        **kwargs: Forwarded unchanged to ``LargeMarginSoftmaxLoss``.
    """

    def __init__(self, num_classes, embedding_size, margin=0.5, scale=64, margin_a=0.5, margin_s=0.2, **kwargs):
        # The base __init__ stores `margin` and calls self.init_margin(),
        # which dispatches to the override below. The degree -> radian
        # conversion is therefore already done when super().__init__ returns.
        # Do NOT call init_margin() again: that would convert the
        # already-converted value a second time.
        super().__init__(
            num_classes=num_classes,
            embedding_size=embedding_size,
            margin=margin,
            scale=scale,
            **kwargs,
        )
        self.margin_a = margin_a
        self.margin_s = margin_s

    def init_margin(self):
        """Convert ``self.margin`` from degrees to radians (in place)."""
        self.margin = np.radians(self.margin)

    def cast_types(self, dtype, device):
        """Cast the class-weight matrix ``self.W`` to ``dtype`` on ``device``."""
        self.W.data = c_f.to_device(self.W.data, device=device, dtype=dtype)

    def modify_cosine_of_target_classes(self, cosine_of_target_classes, embeddings):
        """Apply the magnitude-dependent angular margin to the target cosines.

        Args:
            cosine_of_target_classes: Cosine between each embedding and its
                own class weight. ``_get_scaled_logits`` passes it with shape
                ``(batch, 1)``.
            embeddings: Embeddings whose L2 norms ``||f||`` set the
                per-sample margin.

        Returns:
            Modified, still unscaled target logits, with the same shape as
            ``cosine_of_target_classes``.
        """
        angles = self.get_angles(cosine_of_target_classes)
        magnitudes = embeddings.norm(dim=1, keepdim=True)  # (batch, 1)

        # Per-sample margin in radians. The tensor is kept on the left so
        # the product is computed by torch.
        margins = magnitudes * self.margin

        # Beyond theta = pi - m, cos(theta + m) starts increasing again. Use
        # the ArcFace linear fallback there so the target logit keeps
        # decreasing in theta. Both the threshold and the fallback must use
        # the per-sample margin, not the base margin.
        unscaled_logits = torch.where(
            angles <= np.pi - margins,
            torch.cos(angles + margins),
            torch.cos(angles) - margins * torch.sin(margins),
        )

        # Linear magnitude term: lowers the target logit by margin_a per unit
        # of magnitude above margin_s, and raises it below margin_s.
        magface_adjustment = self.margin_a * (magnitudes - self.margin_s)
        return unscaled_logits - magface_adjustment

    def scale_logits(self, logits, *_):
        """Multiply the logits by ``self.scale``; extra positional args are ignored."""
        return logits * self.scale

    def _get_scaled_logits(self, embeddings, labels):
        """Return scaled logits ``(batch, num_classes)``, with the margin applied to each sample's target class."""
        cosine = self.get_cosine(embeddings)  # (batch, num_classes)
        target_index = labels.view(-1, 1)  # one class index per row
        cosine_of_target_classes = cosine.gather(1, target_index)  # (batch, 1)
        modified_cosine_of_target_classes = self.modify_cosine_of_target_classes(
            cosine_of_target_classes, embeddings
        )
        # Out-of-place scatter: only the target column is replaced, and
        # autograd flows to both the untouched cosines and the modified
        # target logits.
        logits = cosine.scatter(1, target_index, modified_cosine_of_target_classes)
        return self.scale_logits(logits)

    def forward(self, embeddings, labels, indices_tuple=None, ref_emb=None, ref_labels=None):
        """Return the batch-mean MagFace cross-entropy loss.

        ``indices_tuple``, ``ref_emb`` and ``ref_labels`` are accepted for
        signature compatibility with pytorch-metric-learning losses, but they
        are ignored. In particular, miner output is NOT applied.

        Args:
            embeddings: Batch of embeddings, one row per sample.
            labels: Integer class index per sample, shape ``(batch,)``.

        Returns:
            Scalar loss tensor.
        """
        scaled_logits = self._get_scaled_logits(embeddings, labels)
        return F.cross_entropy(scaled_logits, labels)

    def compute_loss(self, embeddings, labels, indices_tuple, ref_emb, ref_labels):
        """Return per-sample losses in pytorch-metric-learning's loss-dict format.

        The dict has the form
        ``{"loss": {"losses", "indices", "reduction_type"}}``, which
        pytorch-metric-learning reducers consume. ``forward`` is overridden
        and does not call this method. ``indices_tuple``, ``ref_emb`` and
        ``ref_labels`` are ignored, as in ``forward``.
        """
        losses = F.cross_entropy(
            self._get_scaled_logits(embeddings, labels), labels, reduction="none"
        )
        return {
            "loss": {
                "losses": losses,
                "indices": torch.arange(embeddings.size(0), device=embeddings.device),
                "reduction_type": "element",
            }
        }
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment