Last active May 8, 2019 16:42
ankitshekhawat/baedb4840698b5c34dca6418c05ce74f
Accuracy metric for online triplet loss
import tensorflow as tf


def recall_top_k(y_true, y_pred):
    # y_true: (batch, batch) matrix with 1 where items i and j match (assumed: one match per row)
    # y_pred: (batch, embedding_dim) embeddings
    # get batch size dynamically (y_pred.shape[0] can be None when the batch dimension is unknown)
    batch_size = tf.shape(y_pred)[0]
    # get pairwise distances (_pairwise_distances is not defined in this gist)
    dists = _pairwise_distances(y_pred)
    # get indexes of the two closest items in the batch predictions;
    # top_k returns the largest values, so negating the distances puts the closest items first.
    # The item itself (distance 0) is always the first hit, so sort and keep the second column.
    _, pred_indexes = tf.nn.top_k(-dists, k=2, sorted=True)
    # get the indexes of the truth matrix, multiplying by the inverted identity matrix to mask out the item itself
    not_self = 1.0 - tf.eye(batch_size)
    _, true_indexes = tf.nn.top_k(not_self * tf.cast(y_true, tf.float32), k=1)
    # a hit when the nearest neighbor is the expected match; the mean gives an accuracy between 0 and 1
    hits = tf.cast(tf.equal(pred_indexes[:, 1], true_indexes[:, 0]), tf.float32)
    return tf.math.reduce_mean(hits)
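The gist compares nearest-neighbor indexes against indexes taken from `y_true`, which assumes each row of `y_true` marks exactly one matching item. With several positives per row, `top_k` returns only the lowest-index one, so a correct neighbor can be counted as a miss. If `y_true` instead carries integer class labels of shape `(batch,)` or `(batch, 1)`, a recall@1 can check whether the nearest neighbor shares the anchor's label. This is a swapped-in formulation, not the gist's method; the name `recall_at_1_labels` is hypothetical, and the distance computation is inlined so the block runs on its own:

```python
import tensorflow as tf


def recall_at_1_labels(y_true, y_pred):
    """Hypothetical label-based recall@1 for one batch of embeddings.

    y_true: integer class labels, shape (batch,) or (batch, 1)
    y_pred: embeddings, shape (batch, dim)
    """
    labels = tf.reshape(tf.cast(y_true, tf.int32), [-1])
    # squared Euclidean distances rank neighbors the same way as plain distances
    sq_norms = tf.reduce_sum(tf.square(y_pred), axis=1)
    dot = tf.matmul(y_pred, y_pred, transpose_b=True)
    sq_dists = sq_norms[:, None] - 2.0 * dot + sq_norms[None, :]
    # set each item's distance to itself to +inf so it is never its own nearest neighbor
    n = tf.shape(y_pred)[0]
    sq_dists = tf.linalg.set_diag(sq_dists, tf.fill([n], tf.cast(float("inf"), sq_dists.dtype)))
    # index of the nearest other item in the batch for every anchor
    nearest = tf.argmin(sq_dists, axis=1, output_type=tf.int32)
    # hit when the nearest neighbor carries the anchor's label
    hits = tf.equal(tf.gather(labels, nearest), labels)
    return tf.reduce_mean(tf.cast(hits, tf.float32))
```

Using `argmin` with an infinite diagonal avoids relying on the item itself sorting first, as `top_k` with `k=2` does. An anchor whose label has no other member in the batch always counts as a miss, so batches should contain at least two items per class.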