Code for comparing two implementations of the gradient for MLKR
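For context, both implementations compute the MLKR leave-one-out kernel-regression cost and its gradient with respect to the linear transformation A (Weinberger & Tesauro, 2007). As a reading aid for the code, a sketch of the standard formulation in LaTeX (the normalized kernel corresponds to the `softmax` variable, the prediction to `yhat`):

\hat{y}_i = \sum_{j \neq i} \hat{k}_{ij} y_j,
\qquad
\hat{k}_{ij} = \frac{\exp(-\lVert A(x_i - x_j)\rVert^2)}{\sum_{l \neq i} \exp(-\lVert A(x_i - x_l)\rVert^2)},
\qquad
\mathcal{L}(A) = \sum_i (\hat{y}_i - y_i)^2

\frac{\partial \mathcal{L}}{\partial A}
= 4 A \sum_{i} (\hat{y}_i - y_i) \sum_{j \neq i} \hat{k}_{ij} (\hat{y}_i - y_j)\, (x_i - x_j)(x_i - x_j)^\top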
Comparison script:

import numpy as np
from sklearn.datasets import make_regression
from sklearn.utils import check_random_state

from losses import _loss_non_optimized, _loss_optimized

# Compare the two implementations on random matrices A: the costs should
# match exactly and the gradients up to floating-point precision.
for n_features in [5, 100]:
    print('n_features={}'.format(n_features))
    X, y = make_regression(n_features=n_features)
    for seed in range(5):
        rng = check_random_state(seed)
        A = rng.randn(X.shape[0], X.shape[0])
        print('gradient differences:')
        print(np.linalg.norm(_loss_optimized(A, X, y)[1]
                             - _loss_non_optimized(A, X, y)[1]))
        print('loss differences:')
        print(_loss_optimized(A, X, y)[0]
              - _loss_non_optimized(A, X, y)[0])

# Print the full (cost, gradient) values (fewer features for readability).
X, y = make_regression(n_features=5)
for seed in range(5):
    rng = check_random_state(seed)
    A = rng.randn(2, 5)
    for loss in [_loss_non_optimized, _loss_optimized]:
        print(loss(A, X, y))
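As an extra sanity check (not in the original gist), either analytic gradient can also be compared against a finite-difference estimate with scipy.optimize.check_grad; a minimal sketch, assuming losses.py is on the path and arbitrary small sizes:

import numpy as np
from scipy.optimize import check_grad
from sklearn.datasets import make_regression

from losses import _loss_optimized

X, y = make_regression(n_samples=50, n_features=5, random_state=0)
# check_grad needs a flat parameter vector; the sizes here are arbitrary.
flatA = np.random.RandomState(0).randn(2, 5).ravel()

# check_grad takes separate cost and gradient callables.
err = check_grad(lambda a: _loss_optimized(a, X, y)[0],
                 lambda a: _loss_optimized(a, X, y)[1],
                 flatA)
print('finite-difference error:', err)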
losses.py:

import numpy as np
from scipy.spatial.distance import pdist, squareform
from scipy.special import logsumexp


def _loss_optimized(flatA, X, y):
    # Vectorized implementation: never materializes the pairwise
    # differences, working only with n x n and d' x n arrays.
    A = flatA.reshape((-1, X.shape[1]))
    dist = pdist(X, metric='mahalanobis', VI=A.T.dot(A))
    dist = squareform(dist ** 2)
    np.fill_diagonal(dist, np.inf)  # exclude j == i from the softmax
    softmax = np.exp(- dist - logsumexp(- dist, axis=1)[:, np.newaxis])
    yhat = softmax.dot(y)
    ydiff = yhat - y
    cost = (ydiff ** 2).sum()
    # also compute the gradient
    W = softmax * ydiff[:, np.newaxis] * (yhat[:, np.newaxis] - y)
    # The rows of W sum to zero (softmax rows sum to 1 and
    # softmax.dot(y) == yhat), so only the column sums of W appear below.
    X_emb_t = A.dot(X.T)
    grad = (4 * (X_emb_t * (W.sum(axis=0))
                 - X_emb_t.dot(W + W.T)).dot(X))
    return cost, grad.ravel()


def _loss_non_optimized(flatA, X, y):
    # Reference implementation: builds the full (n**2, d) matrix of
    # pairwise differences dX, which is simple but memory-hungry.
    dX = (X[None] - X[:, None]).reshape((-1, X.shape[1]))
    A = flatA.reshape((-1, X.shape[1]))
    dist = pdist(X, metric='mahalanobis', VI=A.T.dot(A))
    dist = squareform(dist ** 2)
    np.fill_diagonal(dist, np.inf)
    softmax = np.exp(- dist - logsumexp(- dist, axis=1)[:, np.newaxis])
    yhat = softmax.dot(y)
    ydiff = yhat - y
    cost = (ydiff ** 2).sum()
    # also compute the gradient
    W = 2 * softmax * ydiff[:, np.newaxis] * (yhat[:, np.newaxis] - y)
    # note: this is the part that the MATLAB implementation drops to C for
    M = (dX.T * W.ravel()).dot(dX)
    grad = 2 * A.dot(M)
    return cost, grad.ravel()
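The two functions differ only computationally: _loss_non_optimized materializes the full (n**2, d) matrix dX of pairwise differences, while _loss_optimized only ever handles n x n and d' x n arrays. A rough timing sketch (assumed sizes, not from the gist) to see the difference:

import time

import numpy as np
from sklearn.datasets import make_regression

from losses import _loss_non_optimized, _loss_optimized

X, y = make_regression(n_samples=500, n_features=20, random_state=0)
flatA = np.random.RandomState(0).randn(20, 20).ravel()

# Time one cost+gradient evaluation of each implementation.
for loss in (_loss_non_optimized, _loss_optimized):
    start = time.perf_counter()
    loss(flatA, X, y)
    print(loss.__name__, 'took %.4f s' % (time.perf_counter() - start))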
Output:

n_features=5
gradient differences:
3.5345143236136213e-07
loss differences:
0.0
gradient differences:
1.4108694814314438e-07
loss differences:
0.0
gradient differences:
4.480441650559029e-07
loss differences:
0.0
gradient differences:
3.808513898809513e-07
loss differences:
0.0
gradient differences:
5.433925727007268e-07
loss differences:
0.0
n_features=100
gradient differences:
3.959822089906634e-05
loss differences:
0.0
gradient differences:
5.3507102171022624e-09
loss differences:
0.0
gradient differences:
0.00012580672385627842
loss differences:
0.0
gradient differences:
4.469019658445259e-05
loss differences:
0.0
gradient differences:
5.81210044182099e-06
loss differences:
0.0
(236393.51759114384, array([-219161.66316797, -170897.25052117, 59911.3633348 ,
186149.63871032, 5270.58498449, 53025.7184015 ,
64104.16815048, -80517.43248341, -40525.74730135,
40761.10429498]))
(236393.51759114384, array([-219161.66316797, -170897.25052117, 59911.3633348 ,
186149.63871032, 5270.58498449, 53025.7184015 ,
64104.16815048, -80517.43248341, -40525.74730135,
40761.10429498]))
(1501224.4227846859, array([-209351.22776078, -260401.58473759, 6864.50714577,
-227578.33921967, 172301.175685 , -140104.31132398,
80865.16959862, 103477.93303428, 123817.35534745,
428969.55979609]))
(1501224.4227846859, array([-209351.22776078, -260401.58473759, 6864.50714577,
-227578.33921967, 172301.175685 , -140104.31132398,
80865.16959862, 103477.93303428, 123817.35534746,
428969.55979609]))
(795249.879923564, array([ 46840.66296907, -75064.23955227, -31242.95390164,
7520.43747259, -27340.04496518, 671234.20586862,
456469.43030066, -241225.58944759, -17310.93917806,
113111.51540722]))
(795249.879923564, array([ 46840.66296907, -75064.23955227, -31242.95390164,
7520.43747259, -27340.04496518, 671234.20586862,
456469.43030066, -241225.58944759, -17310.93917806,
113111.51540722]))
(924194.0322471606, array([ 205955.58356486, -34547.5281954 , -65518.95190382,
162702.50449867, 91425.93225032, 801867.80294049,
158816.43702189, -397855.89401635, 780588.87373477,
326040.71320266]))
(924194.0322471606, array([ 205955.58356486, -34547.5281954 , -65518.95190383,
162702.50449868, 91425.93225032, 801867.80294049,
158816.43702189, -397855.89401635, 780588.87373476,
326040.71320266]))
(814925.4406213816, array([ -5983.30201068, -35876.07751133, -37340.34433632, 96135.50074118,
9191.34576283, 142508.76621226, -62013.01269183, 377700.75180631,
691628.55589295, 114744.10319056]))
(814925.4406213816, array([ -5983.30201068, -35876.07751133, -37340.34433632, 96135.50074118,
9191.34576283, 142508.76621226, -62013.01269183, 377700.75180631,
691628.55589295, 114744.10319056]))
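The outputs agree exactly on the cost and to within floating-point rounding on the gradient, which is what the vectorized rewrite is meant to guarantee. Since both functions return a (cost, gradient) pair, either one can be plugged directly into scipy.optimize.minimize with jac=True; a minimal usage sketch (assumed sizes, not part of the gist):

import numpy as np
from scipy.optimize import minimize
from sklearn.datasets import make_regression

from losses import _loss_optimized

X, y = make_regression(n_samples=100, n_features=5, random_state=0)
flatA0 = np.random.RandomState(0).randn(5, 5).ravel()

# jac=True tells minimize that the objective returns (cost, gradient).
res = minimize(_loss_optimized, flatA0, args=(X, y), jac=True,
               method='L-BFGS-B', options={'maxiter': 10})
A_learned = res.x.reshape((-1, X.shape[1]))
print('final cost:', res.fun)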