@thunderInfy
Created August 7, 2020 14:12
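For a batch of $N$ queries $q_i$ with positive keys $k_i$ and a queue of $K$ negative keys $u_1, \dots, u_K$ (the rows of `queue`; the symbol $u_j$ is notation introduced here), `loss_function` computes the contrastive loss commonly known as InfoNCE, with temperature $\tau$:

$$
\mathcal{L} = -\frac{1}{N} \sum_{i=1}^{N} \log \frac{\exp(q_i \cdot k_i / \tau)}{\exp(q_i \cdot k_i / \tau) + \sum_{j=1}^{K} \exp(q_i \cdot u_j / \tau)}
$$

The numerator is `pos`, the sum over $j$ is `neg`, and their sum is `denominator`.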

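import torch

τ = 0.05  # temperature

def loss_function(q, k, queue):
    # q: (N, C) queries, k: (N, C) positive keys, queue: (K, C) negative keys
    # N is the batch size
    N = q.shape[0]
    # C is the dimensionality of the representations
    C = q.shape[1]
    # bmm stands for batch matrix multiplication:
    # if mat1 is a b×n×m tensor and mat2 is a b×m×p tensor,
    # the output is a b×n×p tensor.
    # Each (1×C)·(C×1) product is the dot product q_i·k_i, so pos has shape (N, 1).
    pos = torch.exp(torch.div(torch.bmm(q.view(N, 1, C), k.view(N, C, 1)).view(N, 1), τ))
    # Matrix multiplication between the query and queue tensors gives (N, K) logits.
    # keepdim=True keeps the sum as (N, 1) so it lines up with pos;
    # a plain (N,) sum would broadcast against pos into an (N, N) matrix.
    neg = torch.sum(torch.exp(torch.div(torch.mm(q.view(N, C), torch.t(queue)), τ)), dim=1, keepdim=True)
    # The denominator sums over the positive as well as the negative samples.
    denominator = neg + pos
    return torch.mean(-torch.log(torch.div(pos, denominator)))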
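Exponentiating q·k/τ directly can overflow in float32 when the representations are not normalized: with τ = 0.05, any dot product above about 4.4 pushes the exponent past 88.7, and exp then returns inf. A common alternative is to build the logits explicitly and hand them to `torch.nn.functional.cross_entropy`, with the positive key in column 0 for every query. The sketch below assumes q and k have shape (N, C) and queue has shape (K, C); `loss_function_ce` is a hypothetical name, and it computes the same per-sample quantity, -log(pos / (pos + neg)), averaged over the batch.

```python
import torch
import torch.nn.functional as F

τ = 0.05  # temperature, same value as above

def loss_function_ce(q, k, queue):
    # Hypothetical log-space variant of the contrastive loss.
    # q, k: (N, C); queue: (K, C)
    l_pos = torch.sum(q * k, dim=1, keepdim=True)   # (N, 1) positive logits q_i·k_i
    l_neg = q @ queue.t()                           # (N, K) negative logits q_i·u_j
    logits = torch.cat([l_pos, l_neg], dim=1) / τ   # (N, 1 + K)
    # The positive key is always column 0, so every label is 0.
    labels = torch.zeros(q.shape[0], dtype=torch.long, device=q.device)
    # cross_entropy averages -log softmax(logits)[i, labels[i]] over the batch
    return F.cross_entropy(logits, labels)
```

Because `F.cross_entropy` applies `log_softmax` internally, which subtracts each row's maximum before exponentiating, the result stays finite even for large unnormalized inputs where the explicit `torch.exp` version would overflow.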