@kongzii
Created November 5, 2020 14:38
Contrastive Loss function in PyTorch
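For reference, this is the quantity `criterion` computes, written out as a formula. Here $N$ is the batch size and $m$ is `margin`. The pair label $y_i \in \{0, 1\}$ follows the convention implied by the code: 0 marks a similar pair and 1 marks a dissimilar pair. $d_i$ is the Euclidean distance returned by `torch.nn.functional.pairwise_distance`. That function also adds a small `eps` (default `1e-6`) to the difference for numerical stability, which the formula omits.

$$
d_i = \lVert x^{(1)}_i - x^{(2)}_i \rVert_2, \qquad
\mathcal{L} = \frac{1}{N} \sum_{i=1}^{N} \Big[ (1 - y_i)\, d_i^2 + y_i \max(0,\, m - d_i)^2 \Big]
$$

Similar pairs are penalised by their squared distance. Dissimilar pairs are penalised only while they are closer than $m$.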
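```python
import torch
import torch.nn.functional as F


def criterion(x1, x2, label, margin: float = 1.0):
    """
    Computes Contrastive Loss.

    label: 0 for similar pairs, 1 for dissimilar pairs (tensor of shape [N]).
    """
    # Euclidean distance between each pair of embeddings, shape [N]
    dist = F.pairwise_distance(x1, x2)
    # Similar pairs are pulled together; dissimilar pairs are pushed apart
    # until they are at least `margin` away (no penalty beyond that).
    loss = (1 - label) * torch.pow(dist, 2) \
        + label * torch.pow(torch.clamp(margin - dist, min=0.0), 2)
    return torch.mean(loss)
```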
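Here is a minimal usage sketch, assuming PyTorch is installed. To keep the snippet runnable on its own, it repeats `criterion` in compact form. The first part checks three cases whose loss can be worked out by hand. The second part runs one optimisation step of a Siamese-style setup. The encoder (`torch.nn.Linear(8, 4)`), batch size, learning rate and random pairs are hypothetical choices for illustration, not part of the gist.

```python
import torch
import torch.nn.functional as F


def criterion(x1, x2, label, margin: float = 1.0):
    # Same contrastive loss as above: label 0 = similar pair, 1 = dissimilar pair.
    dist = F.pairwise_distance(x1, x2)
    loss = (1 - label) * dist.pow(2) \
        + label * torch.clamp(margin - dist, min=0.0).pow(2)
    return loss.mean()


# Sanity checks on hand-computable cases (margin = 1.0).
z = torch.zeros(3, 4)
far = torch.ones(3, 4)  # Euclidean distance 2 from z, i.e. beyond the margin
same_similar = criterion(z, z, torch.zeros(3))       # ~0: similar pairs already coincide
same_dissimilar = criterion(z, z, torch.ones(3))     # ~margin**2: dissimilar pairs too close
far_dissimilar = criterion(z, far, torch.ones(3))    # 0: already farther apart than margin

# One training step with a hypothetical shared encoder (Siamese setup).
torch.manual_seed(0)
encoder = torch.nn.Linear(8, 4)
optimizer = torch.optim.SGD(encoder.parameters(), lr=0.1)

a = torch.randn(16, 8)                       # first item of each pair
b = torch.randn(16, 8)                       # second item of each pair
label = torch.randint(0, 2, (16,)).float()   # random similar/dissimilar labels

loss = criterion(encoder(a), encoder(b), label)
optimizer.zero_grad()
loss.backward()
optimizer.step()
```

Both embeddings come from the same encoder, so the gradient of the loss updates one set of weights from both sides of each pair.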