morrisalp/a551d6899a372833216a739788b9947f (created July 24, 2020 07:46)
Top-k categorical accuracy for NumPy arrays (e.g. scikit-learn predict_proba outputs)
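import numpy as np

def top_k_categorical_accuracy(y_true, y_pred_proba, k=1):
    # y_true: (n,) integer class indices; y_pred_proba: (n, n_classes) scores, e.g. predict_proba output
    top_k = np.argsort(y_pred_proba, axis=1)[:, -k:]  # column indices of the k highest scores per row
    return np.equal(top_k, np.asarray(y_true)[:, None]).any(axis=1).mean()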
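A small worked example may help show what the function computes. The sketch below repeats the function and applies it to a made-up 4 x 3 array shaped like predict_proba output. One point worth stating as an assumption: the function compares y_true against column positions, and scikit-learn orders predict_proba columns by clf.classes_, so string labels must first be mapped to positions. Here a hypothetical classes array stands in for clf.classes_, and np.searchsorted performs the mapping, which relies on that array being sorted.

```python
import numpy as np

def top_k_categorical_accuracy(y_true, y_pred_proba, k=1):
    """Fraction of rows whose true class index is among the k highest-scoring columns."""
    top_k = np.argsort(y_pred_proba, axis=1)[:, -k:]
    return np.equal(top_k, np.asarray(y_true)[:, None]).any(axis=1).mean()

# Made-up predict_proba-style output: 4 samples, 3 classes (rows sum to 1).
proba = np.array([
    [0.7, 0.2, 0.1],  # top-1: column 0; top-2: {0, 1}
    [0.1, 0.3, 0.6],  # top-1: column 2; top-2: {1, 2}
    [0.5, 0.4, 0.1],  # top-1: column 0; top-2: {0, 1}
    [0.2, 0.5, 0.3],  # top-1: column 1; top-2: {1, 2}
])

# Hypothetical stand-in for clf.classes_ (sorted); maps string labels to column positions.
classes = np.array(["cat", "dog", "fish"])
labels = np.array(["cat", "dog", "dog", "cat"])
y_true = np.searchsorted(classes, labels)  # [0, 1, 1, 0]

print(top_k_categorical_accuracy(y_true, proba, k=1))  # 0.25
print(top_k_categorical_accuracy(y_true, proba, k=2))  # 0.75
print(top_k_categorical_accuracy(y_true, proba, k=3))  # 1.0
```

With k equal to the number of classes the score is always 1.0. If a row contains tied scores, argsort's ordering decides which of the tied columns land inside the top k, so tied classes may fall on either side of the cutoff.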