Created May 28, 2024 13:08
Save masuidrive/d8e029e479c262f2864b79ae17ca75be to your computer and use it in GitHub Desktop.
MagFaceの実装例 動いているけどアルゴリズム的にあっているかはわからないw
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import torch | |
import torch.nn.functional as F | |
from pytorch_metric_learning.utils import common_functions as c_f | |
from pytorch_metric_learning.losses.large_margin_softmax_loss import LargeMarginSoftmaxLoss | |
class MagFaceLoss(LargeMarginSoftmaxLoss):
    """
    MagFace-style loss: ArcFace with an angular margin that grows with embedding magnitude.

    References:
        MagFace (Meng et al., CVPR 2021): https://arxiv.org/abs/2103.06627
        ArcFace (Deng et al., CVPR 2019), which this extends: https://arxiv.org/abs/1801.07698

    NOTE(review): this is a simplified variant, not a faithful reproduction of
    the paper. The paper clamps the magnitude to [l_a, u_a], interpolates the
    margin linearly between l_m and u_m, and adds a separate magnitude
    regulariser to the loss. Here the margin is simply ``margin * ||f||`` and
    the regulariser is replaced by subtracting ``margin_a * (||f|| - margin_s)``
    from the target-class logit.

    Args:
        num_classes: number of classes.
        embedding_size: dimensionality of the input embeddings.
        margin: angular margin per unit of embedding L2 norm, in DEGREES;
            stored in radians after construction.
        scale: factor the logits are multiplied by before cross-entropy.
        margin_a: weight of the magnitude term subtracted from the target logit.
        margin_s: magnitude offset inside that term.
        **kwargs: forwarded unchanged to ``LargeMarginSoftmaxLoss``.
    """

    def __init__(self, num_classes, embedding_size, margin=0.5, scale=64, margin_a=0.5, margin_s=0.2, **kwargs):
        super().__init__(
            num_classes=num_classes,
            embedding_size=embedding_size,
            margin=margin,
            scale=scale,
            **kwargs,
        )
        self.margin_a = margin_a
        self.margin_s = margin_s
        # Derive the radian margin from the constructor argument instead of
        # calling init_margin() again. init_margin() converts self.margin in
        # place, and the base constructor already invokes it, so a second call
        # would apply np.radians twice. This assignment is correct either way.
        self.margin = np.radians(margin)

    def init_margin(self):
        """Convert ``self.margin`` from degrees to radians in place (not idempotent)."""
        self.margin = np.radians(self.margin)

    def cast_types(self, dtype, device):
        """Move the class-weight parameter ``W`` to ``device`` with ``dtype``."""
        self.W.data = c_f.to_device(self.W.data, device=device, dtype=dtype)

    def modify_cosine_of_target_classes(self, cosine_of_target_classes, embeddings):
        """
        Apply the magnitude-dependent margin to the target-class cosines.

        Args:
            cosine_of_target_classes: cosine between each embedding and its
                target class, shape (batch, 1) as produced by ``_mean_loss``.
            embeddings: the embeddings passed to ``forward``; their L2 norm
                (per row) is the magnitude ``||f||``.

        Returns:
            Unscaled target logits, same shape as ``cosine_of_target_classes``.
        """
        angles = self.get_angles(cosine_of_target_classes)
        magnitudes = embeddings.norm(dim=1, keepdim=True)
        # Per-sample additive angular margin m(a) = margin * ||f|| (radians).
        margins = self.margin * magnitudes

        cos_theta = torch.cos(angles)
        cos_theta_plus_margin = torch.cos(angles + margins)
        # ArcFace trick to keep the target logit monotonically decreasing in
        # theta: once theta + m would pass pi, switch to cos(theta) - m*sin(m).
        # Both the threshold and the penalty must use the per-sample margin
        # that is actually applied, not the base margin.
        fallback = cos_theta - margins * torch.sin(margins)
        margin_logits = torch.where(angles <= np.pi - margins, cos_theta_plus_margin, fallback)

        # Magnitude adjustment subtracted from the target logit.
        magface_adjustment = self.margin_a * (magnitudes - self.margin_s)
        return margin_logits - magface_adjustment

    def scale_logits(self, logits, *_):
        """Multiply logits by the fixed ``scale`` (extra positional args are ignored)."""
        return logits * self.scale

    def _mean_loss(self, embeddings, labels):
        """Return the batch-mean cross-entropy of the margin-adjusted, scaled logits."""
        # Keep W on the embeddings' device/dtype, as the base compute_loss does.
        self.cast_types(embeddings.dtype, embeddings.device)
        cosine = self.get_cosine(embeddings)  # (batch, num_classes)
        target_index = labels.view(-1, 1)
        cosine_of_target_classes = cosine.gather(1, target_index)  # (batch, 1)
        modified = self.modify_cosine_of_target_classes(cosine_of_target_classes, embeddings)
        # Out-of-place scatter replaces only each row's target-class entry;
        # the other entries keep the plain cosine and stay in the autograd graph.
        logits = cosine.scatter(1, target_index, modified.to(cosine.dtype))
        return F.cross_entropy(self.scale_logits(logits), labels)

    def forward(self, embeddings, labels, indices_tuple=None, ref_emb=None, ref_labels=None):
        """
        Compute the loss for a batch.

        ``indices_tuple``, ``ref_emb`` and ``ref_labels`` are accepted for
        interface compatibility but ignored.

        Returns:
            Scalar tensor: mean cross-entropy over the batch.
        """
        return self._mean_loss(embeddings, labels)

    def compute_loss(self, embeddings, labels, indices_tuple, ref_emb, ref_labels):
        """
        Return the loss in pytorch_metric_learning's loss-dict format.

        The loss is already averaged over the batch, so it is reported with
        ``reduction_type="already_reduced"``. ``indices_tuple``, ``ref_emb``
        and ``ref_labels`` are ignored.
        """
        return {
            "loss": {
                "losses": self._mean_loss(embeddings, labels),
                "indices": None,
                "reduction_type": "already_reduced",
            }
        }
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.