Skip to content

Instantly share code, notes, and snippets.

@nkt1546789
Last active August 29, 2016 14:29
Show Gist options
  • Save nkt1546789/be19e0fe94e3599df81f83e78482786b to your computer and use it in GitHub Desktop.
Save nkt1546789/be19e0fe94e3599df81f83e78482786b to your computer and use it in GitHub Desktop.
(low-rank kernel) Information-theoretic metric learning
import numpy as np
import scipy.sparse.linalg as sla
def KITML_lr(K0, constraints, dm=None, dc=None, gamma=1., max_iter=1000,
             stop_threshold=1e-3, max_k=None, energy=0.99):
    """Learn a low-rank kernel matrix by running ITML in an explicit feature space.

    The top ``max_k`` eigenpairs of ``K0`` are computed with
    ``scipy.sparse.linalg.eigsh``. The smallest rank ``k`` whose leading
    eigenvalues carry at least ``energy`` of the sum of those ``max_k``
    eigenvalues is kept, giving features ``Phi = U_k * sqrt(S_k)``. ITML is then
    run on ``Phi`` and the learned kernel ``Phi A Phi^T`` is returned.

    Parameters
    ----------
    K0 : (n, n) matrix
        Initial kernel matrix (anything ``eigsh`` accepts).
        NOTE(review): assumed symmetric, as ``eigsh`` requires; this is not
        validated here.
    constraints, dm, dc, gamma, max_iter, stop_threshold
        Passed through unchanged to ``ITML``.
    max_k : int, optional
        Number of eigenpairs to compute; defaults to ``n - 1`` (``eigsh``
        requires ``0 < max_k < n``).
    energy : float
        Fraction of the computed spectrum that the kept rank must cover
        (default 0.99).

    Returns
    -------
    ndarray, shape (n, n)
        The learned kernel ``Phi A Phi^T``.

    Raises
    ------
    ValueError
        If none of the computed eigenvalues is positive.
    """
    if max_k is None:
        max_k = K0.shape[0] - 1
    S, U = sla.eigsh(K0, k=max_k)

    # Sort eigenpairs by decreasing eigenvalue instead of relying on eigsh's
    # output order, and clip tiny negative eigenvalues (numerical noise or an
    # indefinite K0) so the square root below cannot produce NaNs.
    order = np.argsort(S)[::-1]
    S = np.maximum(S[order], 0.0)
    U = U[:, order]

    total = np.sum(S)
    if total <= 0:
        raise ValueError("K0 has no positive eigenvalues among the top max_k")

    # Smallest k with cumulative energy >= `energy`; if the threshold is never
    # reached, keep every computed component (the old loop dropped the last).
    ratio = np.cumsum(S) / total
    k = min(int(np.searchsorted(ratio, energy)) + 1, len(S))

    # Broadcasting scales column c of U by sqrt(S[c]), like U.dot(diag(sqrt(S))).
    Phi = U[:, :k] * np.sqrt(S[:k])
    print("k: {0}".format(k))
    A = ITML(Phi, constraints, dm=dm, dc=dc, gamma=gamma, max_iter=max_iter,
             stop_threshold=stop_threshold)
    return Phi.dot(A.dot(Phi.T))
def ITML(X, constraints, dm=None, dc=None, gamma=1.0, max_iter=1000,
         stop_threshold=1e-3, verbose=True):
    """Learn a Mahalanobis matrix A with Information-Theoretic Metric Learning.

    Starting from the identity, the function sweeps cyclic projections over
    the constraints, using the squared distance ``(x_i - x_j)^T A (x_i - x_j)``.

    Parameters
    ----------
    X : ndarray, shape (n, d)
        Data matrix; rows are samples.
    constraints : iterable of (delta, i, j)
        ``delta == 1`` is a must-link pair, whose distance is pushed down to
        ``dm``. ``delta == -1`` is a cannot-link pair, whose distance is pushed
        up to ``dc``. ``i`` and ``j`` are row indices into ``X``. The iterable
        is read once, so a generator is accepted.
    dm, dc : float, optional
        Target squared distances for must-link and cannot-link pairs.
        Defaults are ``min(0.05, 1st percentile)`` and
        ``max(1.95, 99th percentile)`` of the pairwise squared Euclidean
        distances in ``X``.
    gamma : float
        Slack parameter used in the projection updates.
    max_iter : int
        Maximum number of sweeps over the constraints.
    stop_threshold : float
        Stop once the Frobenius norm of the change in ``A`` over one sweep
        falls below this value.
    verbose : bool
        Print thresholds, per-sweep update counts and convergence (default
        True, as before).

    Returns
    -------
    ndarray, shape (d, d)
        The learned matrix ``A``.

    Raises
    ------
    ValueError
        If a constraint's ``delta`` is neither 1 nor -1.
    """
    n, d = X.shape
    if dm is None or dc is None:
        # Pairwise squared Euclidean distances, strict lower triangle only.
        sq_norms = np.c_[np.sum(X ** 2, axis=1)]
        dist2 = sq_norms + sq_norms.T - 2 * X.dot(X.T)
        dist2 = dist2[np.tril_indices(n, -1)]
        if dm is None:
            dm = np.min([0.05, np.percentile(dist2, 1)])
        if dc is None:
            dc = np.max([1.95, np.percentile(dist2, 99)])
    if verbose:
        print("dm:{0}, dc:{1}".format(dm, dc))

    # Materialise the constraints once, because they are swept on every pass.
    # Reject unknown labels: before, any delta other than 1 was silently
    # treated as cannot-link, and delta == 0 did nothing at all.
    parsed = []
    for delta, i, j in constraints:
        if delta == 1:
            target = dm
        elif delta == -1:
            target = dc
        else:
            raise ValueError(
                "delta must be 1 (must-link) or -1 (cannot-link), got {0!r}".format(delta))
        parsed.append((delta, int(i), int(j), target))

    # Slack targets (xi) and dual variables (lam) are kept per constraint.
    # Keying them by (i, j) would merge a pair listed with conflicting labels.
    xi = [c[3] for c in parsed]
    lam = [0.0] * len(parsed)

    A = np.identity(d)
    for iteration in range(max_iter):
        A_old = A.copy()
        updates = 0  # number of projections this sweep; tends to 0 at convergence
        for c, (delta, i, j, _) in enumerate(parsed):
            x = X[i] - X[j]
            Ax = A.dot(x)
            p = x.dot(Ax)  # current squared Mahalanobis distance of the pair
            if delta == 1 and p <= xi[c]:
                continue  # must-link already satisfied
            if delta == -1 and p >= xi[c]:
                continue  # cannot-link already satisfied
            if p == 0:
                continue  # identical points: the projection is undefined
            alpha = min(lam[c], (delta / 2.0) * (1.0 / p - gamma / xi[c]))
            if alpha == 0:
                continue
            updates += 1
            beta = (delta * alpha) / (1.0 - delta * alpha * p)
            xi[c] = (gamma * xi[c]) / (gamma + delta * alpha * xi[c])
            lam[c] -= alpha
            # Rank-one update A + beta * A x x^T A. A stays symmetric, so
            # A x x^T A == outer(Ax, Ax): O(d^2) instead of two O(d^3) products.
            A = A + beta * np.outer(Ax, Ax)
        if verbose:
            print("number of updates: {0}".format(updates))
        diff = np.sqrt(np.sum((A - A_old) ** 2))
        if diff < stop_threshold:
            if verbose:
                print("converged at {0} steps".format(iteration))
            break
    return A
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment