Skip to content

Instantly share code, notes, and snippets.

@wdevazelhes
Created May 16, 2018 16:45
Show Gist options
  • Save wdevazelhes/3c349e13976613d15ebc46178c942474 to your computer and use it in GitHub Desktop.
Save wdevazelhes/3c349e13976613d15ebc46178c942474 to your computer and use it in GitHub Desktop.
import numpy as np
from scipy.sparse.csgraph import laplacian
from sklearn.utils import check_random_state
from scipy.sparse import coo_matrix
from numpy.testing import assert_array_almost_equal
# Module-level random state obtained from scikit-learn's check_random_state
# with seed 0.
RNG = check_random_state(0)
def test_loss_sdml():
    """Check that the SDML Laplacian term can be computed from formed pairs.

    Builds the expression ``X.T.dot(L.dot(X))`` two ways and asserts that
    they agree: once through the graph Laplacian ``L`` of the symmetrized,
    signed adjacency matrix of the pair indices, and once directly from the
    differences of the already formed pairs weighted by their labels.
    """
    # Local, freshly seeded generator: the data drawn here does not depend on
    # how much of a shared module-level generator other tests have consumed.
    rng = np.random.RandomState(0)

    n_samples = 10
    n_dims = 5
    X = rng.randn(n_samples, n_dims)
    c = np.array([[1, 2], [3, 6], [2, 3], [4, 2], [5, 3]])
    n_pairs = len(c)
    pairs = X[c]  # indexing with c yields one (2, n_dims) slab per pair
    y = rng.choice([-1, 1], n_pairs)

    # original computation of the expression
    adj = coo_matrix((y, (c[:, 0], c[:, 1])), shape=(n_samples,) * 2)
    adj_sym = adj + adj.T
    L = laplacian(adj_sym, normed=False)
    expr_1 = X.T.dot(L.dot(X))

    # equivalent way to do it with already formed pairs
    diff = pairs[:, 0] - pairs[:, 1]
    # diff.T has the pairs on its last axis, so y scales each pair's column
    expr_2 = (diff.T * y).dot(diff)

    assert_array_almost_equal(expr_1, expr_2)
def test_loss_sdml_with_duplicates():
    """Check the pair-based SDML expression when a pair appears twice.

    Same comparison as the Laplacian-based computation of
    ``X.T.dot(L.dot(X))`` versus the weighted sum over pair differences,
    but the index array contains the pair ``[5, 3]`` two times.
    """
    # Local, freshly seeded generator: the data drawn here does not depend on
    # how much of a shared module-level generator other tests have consumed.
    rng = np.random.RandomState(0)

    n_samples = 10
    n_dims = 5
    X = rng.randn(n_samples, n_dims)
    c = np.array([[1, 2], [3, 6], [2, 3], [4, 2], [5, 3], [5, 3]])
    pairs = X[c]
    # ndarray (rather than a bare list) so the labels behave the same in the
    # sparse constructor and in the broadcasted product below
    y = np.array([-1, 1, 1, 1, 1, 1])

    # original computation of the expression
    adj = coo_matrix((y, (c[:, 0], c[:, 1])), shape=(n_samples,) * 2)
    adj_sym = adj + adj.T
    L = laplacian(adj_sym, normed=False)
    expr_1 = X.T.dot(L.dot(X))

    # equivalent way to do it with pairs already formed
    diff = pairs[:, 0] - pairs[:, 1]
    expr_2 = (diff.T * y).dot(diff)

    assert_array_almost_equal(expr_1, expr_2)
if __name__ == '__main__':
    # When executed as a script, run every check in declaration order.
    for check in (test_loss_sdml, test_loss_sdml_with_duplicates):
        check()
@wdevazelhes
Copy link
Author

This snippet tests the computation of the expression used in the SDML algorithm (see https://github.com/metric-learn/metric-learn/blob/master/metric_learn/sdml.py#L54) in a way that uses only already-formed pairs. Note that it also works if some pairs are duplicated: when the adjacency matrix is computed and there are duplicated pairs, the weight of the connection is the sum over the duplicates (e.g. 3 if a pair is duplicated 3 times).

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment