

Gist by @Brideau, created December 12, 2021 23:49.
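An implementation of Precision@k compatible with scikit-learn. It uses scikit-learn's validation utilities, so it accepts the same array-like inputs as scikit-learn's metrics.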
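For reference, the quantity the function computes is the usual Precision@k. Rank the n items by score in descending order and count how many of the top k belong to the positive class. In the formula, s_i is the score and y_i the true label of item i, and π is a permutation with s_{π(1)} ≥ s_{π(2)} ≥ … ≥ s_{π(n)}:

```latex
\mathrm{Precision@}k = \frac{1}{k} \sum_{i=1}^{k} \mathbf{1}\left[\, y_{\pi(i)} = \text{pos\_label} \,\right]
```

The denominator is always k. If k exceeds n, the result is therefore still divided by k.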
```python
import numpy as np
from sklearn.utils import column_or_1d
from sklearn.utils.multiclass import type_of_target


def precision_at_k(y_true, y_score, k, pos_label=1):
    """Fraction of the k highest-scoring items whose true label is pos_label."""
    # Only binary targets are supported
    if type_of_target(y_true) != "binary":
        raise ValueError("y_true must be a binary column.")

    # Makes this compatible with various array types (lists, NumPy arrays, pandas Series)
    y_true_arr = column_or_1d(y_true)
    y_score_arr = column_or_1d(y_score)

    # Boolean mask: True where the item belongs to the positive class
    y_true_arr = y_true_arr == pos_label

    # Indices that sort the scores from highest to lowest
    desc_sort_order = np.argsort(y_score_arr)[::-1]
    y_true_sorted = y_true_arr[desc_sort_order]

    # Count the positives among the top k items and normalize by k
    true_positives = y_true_sorted[:k].sum()
    return true_positives / k
```
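To make the arithmetic concrete, here is a minimal NumPy-only sketch of the same core steps as `precision_at_k`, run on made-up toy data. The helper name `top_k_precision` is hypothetical, and it deliberately skips the scikit-learn input validation. It also makes one change, which is an assumption and not the gist's behaviour: it sorts the negated scores with a stable sort, so tied scores keep their original order. `np.argsort(...)[::-1]` uses NumPy's default non-stable sort and then reverses the result, so the order among ties is not guaranteed. The sketch assumes numeric (e.g. float) scores.

```python
import numpy as np

# Toy data, made up for illustration: 6 items, 3 of them positive.
y_true = np.array([1, 0, 1, 0, 0, 1])
y_score = np.array([0.9, 0.8, 0.7, 0.4, 0.3, 0.2])


def top_k_precision(y_true, y_score, k, pos_label=1):
    """Hypothetical NumPy-only helper mirroring precision_at_k's core steps.

    No scikit-learn validation is done. Sorting the negated scores with a
    stable sort gives a descending order in which tied scores keep their
    original order (earlier items rank first).
    """
    hits = np.asarray(y_true) == pos_label
    order = np.argsort(-np.asarray(y_score, dtype=float), kind="stable")
    return hits[order][:k].sum() / k


print(top_k_precision(y_true, y_score, k=3))  # top 3 labels are [1, 0, 1] -> 2/3
print(top_k_precision(y_true, y_score, k=6))  # all 6 items, 3 positives -> 0.5
```

With k=3 the three highest scores (0.9, 0.8, 0.7) belong to items labelled 1, 0 and 1, so Precision@3 is 2/3. With k=6 every item is included, and the result equals the overall positive rate, 3/6.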