Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save ivanpanshin/1f31980e9c0ede334f91c8dcb13e8402 to your computer and use it in GitHub Desktop.
Save ivanpanshin/1f31980e9c0ede334f91c8dcb13e8402 to your computer and use it in GitHub Desktop.
class ArcModule(nn.Module):
    """Additive angular margin (ArcFace) classification head.

    Produces logits ``s * cos(theta_j)`` for every class ``j``, except for the
    target class, which gets the margin-penalised ``s * cos(theta_y + m)``.

    Args:
        in_features: Dimensionality of the input embeddings.
        out_features: Number of classes.
        s: Scale applied to the cosine logits.
        m: Additive angular margin, in radians.
    """

    # Lower bound for 1 - cos^2 before the sqrt: sqrt'(0) is infinite and would
    # turn the zero gradient of non-target classes into NaN (0 * inf).
    _EPS = 1e-7

    def __init__(self, in_features, out_features, s=10, m=0.5):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.s = s
        self.m = m
        # One class-centre row per class; rows are L2-normalised in forward().
        self.weight = nn.Parameter(torch.FloatTensor(out_features, in_features))
        nn.init.xavier_normal_(self.weight)
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)
        # Plain Python floats rather than (half-precision, CPU) tensors: they
        # adapt to whatever device/dtype cos_th has and, like before, are not
        # part of the state_dict, so existing checkpoints still load.
        # cos_th > th  <=>  theta + m < pi. Past that point cos(theta + m) is
        # no longer monotonic in theta, so forward() falls back to
        # cos(theta) - mm instead.
        self.th = math.cos(math.pi - m)
        self.mm = math.sin(math.pi - m) * m

    def forward(self, inputs, labels):
        """Return margin-adjusted, scaled logits.

        Args:
            inputs: Embeddings of shape ``(batch, in_features)``. They are
                L2-normalised here, which leaves already-normalised inputs
                unchanged.
            labels: Integer class indices of shape ``(batch,)`` or
                ``(batch, 1)``.

        Returns:
            Tensor of shape ``(batch, out_features)``.
        """
        # cos(theta) between each embedding and each class centre; both sides
        # must be unit-norm for the dot product to be a cosine.
        cos_th = F.linear(F.normalize(inputs), F.normalize(self.weight))
        cos_th = cos_th.clamp(-1, 1)
        sin_th = torch.sqrt((1.0 - cos_th.pow(2)).clamp(min=self._EPS))
        # cos(theta + m) via the angle-addition identity, kept in cos_th's
        # dtype (no forced .half(), which cost precision in fp32 training).
        cos_th_m = cos_th * self.cos_m - sin_th * self.sin_m
        cos_th_m = torch.where(cos_th > self.th, cos_th_m, cos_th - self.mm)

        if labels.dim() == 1:
            labels = labels.unsqueeze(-1)
        # Build the one-hot mask on the logits' own device rather than via a
        # hard-coded .cuda() (which failed on CPU and ignored the model's GPU).
        onehot = torch.zeros(cos_th.size(), device=cos_th.device)
        onehot.scatter_(1, labels.long().to(cos_th.device), 1.0)
        # Margin only on the target class; plain cosine everywhere else.
        outputs = onehot * cos_th_m + (1.0 - onehot) * cos_th
        return outputs * self.s
class ArcFaceModel(nn.Module):
    """Embedding model with an ArcFace margin head.

    Holds a timm ResNet-18 backbone, an embedding head (``bn1``, ``dropout``,
    ``fc1``, ``bn2``) and an ``ArcModule`` margin head.

    Args:
        embed_size: Width of the embedding that is passed to the margin head.
        num_classes: Number of target classes.
        dropout_rate: Probability used by ``nn.Dropout2d``, which zeroes
            whole channels.
    """

    def __init__(self, embed_size, num_classes, dropout_rate):
        super().__init__()
        # Randomly initialised ResNet-18; only the input width of its final
        # ``fc`` layer is read here, to size the layers below.
        self.backbone = timm.create_model('resnet18', pretrained=False)
        channels = self.backbone.fc.in_features
        self.in_features = channels
        self.embed_size = embed_size
        self.num_classes = num_classes

        self.margin = ArcModule(in_features=embed_size, out_features=num_classes)
        self.bn1 = nn.BatchNorm2d(channels)
        self.dropout = nn.Dropout2d(dropout_rate, inplace=True)
        # NOTE(review): the 16 * 16 factor presumably corresponds to a 16x16
        # backbone feature map being flattened (e.g. 512x512 inputs to a
        # stride-32 ResNet) — confirm against the forward pass.
        flat_dim = channels * 16 * 16
        self.fc1 = nn.Linear(flat_dim, embed_size)
        self.bn2 = nn.BatchNorm1d(embed_size)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment