@axel-angel
Last active May 4, 2018 01:04
Caffe Python layer for Contrastive Loss
import caffe
import numpy as np

# Author: Axel Angel, copyright 2015, license GPLv3.

class OwnContrastiveLossLayer(caffe.Layer):

    def setup(self, bottom, top):
        # check inputs: two feature blobs plus a pair label
        if len(bottom) != 3:
            raise Exception("Need two feature inputs and a label to compute the contrastive loss.")

    def reshape(self, bottom, top):
        # check input dimensions match
        if bottom[0].count != bottom[1].count:
            raise Exception("Inputs must have the same dimension.")
        # difference has the shape of the inputs; per-pair terms have one entry per pair
        self.diff = np.zeros_like(bottom[0].data, dtype=np.float32)
        self.dist_sq = np.zeros(bottom[0].num, dtype=np.float32)
        self.zeros = np.zeros(bottom[0].num)
        self.m = 1.0  # margin
        # loss output is scalar
        top[0].reshape(1)

    def forward(self, bottom, top):
        GW1 = bottom[0].data   # first embedding of each pair
        GW2 = bottom[1].data   # second embedding of each pair
        Y = bottom[2].data     # pair label: 1 = similar, 0 = dissimilar
        self.diff = GW1 - GW2
        self.dist_sq = np.sum(self.diff**2, axis=1)
        # similar pairs are penalized by their squared distance,
        # dissimilar pairs only while they are closer than the margin
        losses = Y * self.dist_sq \
               + (1 - Y) * np.max([self.zeros, self.m - self.dist_sq], axis=0)
        top[0].data[0] = np.sum(losses) / 2.0 / bottom[0].num

    def backward(self, top, propagate_down, bottom):
        Y = bottom[2].data
        # indicator: dissimilar pairs still inside the margin contribute a gradient
        disClose = np.where(self.m - self.dist_sq > 0.0, 1.0, 0.0)
        for i, sign in enumerate([+1, -1]):
            if propagate_down[i]:
                alphas = np.where(Y > 0, +1.0, -1.0) * sign \
                       * top[0].diff[0] / bottom[i].num
                facts = ((1 - Y) * disClose + Y) * alphas
                # broadcast the per-pair factor over the feature dimension
                bottom[i].diff[...] = facts[:, np.newaxis] * self.diff
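
For a quick sanity check without a Caffe build, here is a minimal NumPy-only sketch of the same forward pass and of the gradient the backward pass computes, verified against finite differences. The function names and toy shapes below are illustrative, not part of the gist.

import numpy as np

def contrastive_loss(GW1, GW2, Y, m=1.0):
    # Y[i] = 1 for similar pairs, 0 for dissimilar pairs (gist convention)
    diff = GW1 - GW2
    dist_sq = np.sum(diff**2, axis=1)
    losses = Y * dist_sq + (1 - Y) * np.maximum(0.0, m - dist_sq)
    return np.sum(losses) / 2.0 / GW1.shape[0]

def contrastive_grad_GW1(GW1, GW2, Y, m=1.0):
    # analytic gradient w.r.t. GW1 (the GW2 gradient is its negative)
    diff = GW1 - GW2
    dist_sq = np.sum(diff**2, axis=1)
    dis_close = (m - dist_sq > 0.0).astype(np.float64)
    signs = np.where(Y > 0, 1.0, -1.0)
    facts = signs * ((1 - Y) * dis_close + Y) / GW1.shape[0]
    return facts[:, None] * diff

rng = np.random.RandomState(0)
GW1 = rng.randn(4, 2)
GW2 = rng.randn(4, 2)
Y = np.array([1.0, 0.0, 1.0, 0.0])

analytic = contrastive_grad_GW1(GW1, GW2, Y)
numeric = np.zeros_like(GW1)
eps = 1e-6
for idx in np.ndindex(*GW1.shape):
    plus, minus = GW1.copy(), GW1.copy()
    plus[idx] += eps
    minus[idx] -= eps
    numeric[idx] = (contrastive_loss(plus, GW2, Y)
                    - contrastive_loss(minus, GW2, Y)) / (2 * eps)

print(np.allclose(analytic, numeric, atol=1e-5))  # expected: True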
pvskand commented Jun 29, 2017

You have calculated the loss as losses = Y * self.dist_sq + (1-Y) * np.max([self.zeros, self.m - self.dist_sq], axis=0),
but according to the [paper](http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf) shouldn't it be
losses = (1-Y) * self.dist_sq + Y * np.max([self.zeros, self.m - self.dist_sq], axis=0), i.e. reversing Y and (1-Y)?

@bhavyagoyal

@pvskand The difference is because this code assumes Y[i]=1 for similar pairs (and Y[i]=0 for dissimilar pairs) whereas the paper uses the reverse notation.
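
A small NumPy check of this point (illustrative only): applying the same per-pair formula under both conventions, with the labels flipped for the paper-style one, yields identical losses.

import numpy as np

rng = np.random.RandomState(1)
dist_sq = rng.rand(6)                        # toy squared distances
Y_gist = np.array([1., 0., 1., 1., 0., 0.])  # gist: 1 = similar pair
Y_paper = 1.0 - Y_gist                       # paper: 1 = dissimilar pair
m = 1.0

loss_gist = Y_gist * dist_sq + (1 - Y_gist) * np.maximum(0.0, m - dist_sq)
loss_paper = (1 - Y_paper) * dist_sq + Y_paper * np.maximum(0.0, m - dist_sq)
print(np.allclose(loss_gist, loss_paper))    # True: the two differ only in notation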
