Last active
January 20, 2017 15:02
-
-
Save iamaaditya/eebb61c6d7c3da995fc6da380db530b6 to your computer and use it in GitHub Desktop.
Code snippet that prints various ML-related metrics, given the ground-truth labels (y_labels) and the predicted probability of each label (e.g., softmax output)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Code snippet that prints various ML-related metrics, given the ground-truth labels (y_labels)
# and the predicted probability of each label (e.g., softmax output).
# Aaditya Prakash
import builtins
import datetime

import numpy as np
from sklearn.metrics import f1_score, roc_auc_score, precision_score, recall_score, accuracy_score, average_precision_score, precision_recall_curve, hamming_loss
def _top_k_accuracy(y_true, scores, k):
    """Return the fraction of samples whose k highest-scoring labels include a true label.

    Both arrays are expected to be 2-D with shape (n_samples, n_labels), with
    ``y_true`` a 0/1 indicator matrix. Returns NaN when that does not hold
    (e.g. 1-D binary input), because top-k is undefined for a single column.
    """
    if y_true.ndim != 2 or scores.ndim != 2 or scores.shape[0] == 0:
        return float('nan')
    k = min(k, scores.shape[1])
    # Column indices of the k largest scores in each row.
    top_k_idx = np.argsort(-scores, axis=1)[:, :k]
    rows = np.arange(scores.shape[0])[:, None]
    hits = y_true[rows, top_k_idx] > 0
    return float(hits.any(axis=1).mean())


def print(y_labels, probs, step=None, threshold=0.5):
    """Print a one-line summary of classification metrics and return them.

    NOTE: this function shadows the built-in ``print`` for the whole module.
    The name is kept for backward compatibility. The real built-in is reached
    explicitly through ``builtins.print`` below; calling ``print`` here would
    recurse into this function.

    Args:
        y_labels: Ground-truth labels. Presumably a 0/1 indicator matrix of
            shape (n_samples, n_labels); the micro/macro averaging and the
            Hamming loss suggest multi-label data. TODO confirm against callers.
        probs: Per-label scores/probabilities, same shape as ``y_labels``.
            The caller's array is NOT modified.
        step: Optional step/iteration number included in the printed line.
        threshold: Scores ``>= threshold`` are treated as positive predictions.

    Returns:
        dict mapping each metric name to the value that was printed.
    """
    y_true = np.asarray(y_labels)
    scores = np.asarray(probs, dtype=float)

    # Ranking metrics need the continuous scores, not thresholded 0/1 values.
    macro_auc = roc_auc_score(y_true, scores, average='macro')
    micro_auc = roc_auc_score(y_true, scores, average='micro')
    average_precision = average_precision_score(y_true, scores, average='weighted')

    # Binarize into a new array instead of overwriting the caller's probs in place.
    preds = (scores >= threshold).astype(int)
    macro_f1 = f1_score(y_true, preds, average='macro')
    micro_f1 = f1_score(y_true, preds, average='micro')
    precision = precision_score(y_true, preds, average='micro')
    recall = recall_score(y_true, preds, average='micro')
    hamming_loss_v = hamming_loss(y_true, preds)
    # For multi-label indicator input sklearn's accuracy_score is exact-match (subset) accuracy.
    accuracy_1 = accuracy_score(y_true, preds)
    accuracy_5 = _top_k_accuracy(y_true, scores, 5)
    accuracy_10 = _top_k_accuracy(y_true, scores, 10)

    metrics = {
        'macro_f1': macro_f1,
        'micro_f1': micro_f1,
        'micro_auc': micro_auc,
        'macro_auc': macro_auc,
        'precision': precision,
        'recall': recall,
        'accuracy_1': accuracy_1,
        'accuracy_5': accuracy_5,
        'accuracy_10': accuracy_10,
        'average_precision': average_precision,
        'hamming_loss_v': hamming_loss_v,
    }

    time_str = datetime.datetime.now().isoformat()
    builtins.print("{}: step {}, macro_f1_score {:g}, micro_f1_score {:g}, micro_auc {:g}, macro_auc {:g}, precision {:g}, recall {:g}, accuracy_1 {:g}, accuracy_5 {:g}, accuracy_10 {:g}, average_precision {:g}, hamming_loss_v {:g}".format(
        time_str, step, macro_f1, micro_f1, micro_auc, macro_auc, precision, recall, accuracy_1, accuracy_5, accuracy_10, average_precision, hamming_loss_v))
    return metrics
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment