Handy classification evaluation function
# Evaluation function
import numpy as np
import pandas as pd
from sklearn.metrics import accuracy_score
from sklearn.metrics import confusion_matrix
from sklearn.metrics import precision_recall_fscore_support as pr
from sklearn.metrics import roc_curve, auc
from IPython.display import display
import matplotlib.pyplot as plt
# Jupyter notebook magic; remove if running as a plain script
%matplotlib inline


def check_model_fit(clf, X_test, y_test, do_plot_roc_curve=True, do_print_thresh=False):
    # Overall test-set accuracy, as a percentage
    y_pred = clf.predict(X_test)
    acc = accuracy_score(y_test, y_pred, normalize=True) * 100

    # Confusion matrix, displayed with labeled rows and columns
    cmat = confusion_matrix(y_test, y_pred)
    cols = pd.MultiIndex.from_tuples([('predictions', 0), ('predictions', 1)])
    indx = pd.MultiIndex.from_tuples([('actual', 0), ('actual', 1)])
    display(pd.DataFrame(cmat, columns=cols, index=indx))
    tn, fp, fn, tp = cmat.ravel()

    # Test-set accuracy grouped by the target variable (per-class recall, in %)
    scores = (cmat.diagonal() / (cmat.sum(axis=1) * 1.0)) * 100
    # specificity:
    irrelevant_acc = scores[0]
    # sensitivity:
    relevant_acc = scores[1]
    balanced_acc = (relevant_acc + irrelevant_acc) / 2.0
    # Youden's J = sensitivity + specificity - 1 (here on a 0-100 scale)
    j = relevant_acc + irrelevant_acc - 100.0
    pr_scores = pr(y_test, y_pred, average='binary')

    # ROC curve (only for classifiers that expose class probabilities)
    if hasattr(clf, 'predict_proba'):
        y_pred_scores = clf.predict_proba(X_test)
        fpr, tpr, thresholds = roc_curve(y_test, y_pred_scores[:, 1])
        roc_auc = auc(fpr, tpr)
        roc_auc_str = '{:.2f}'.format(roc_auc)
        if do_plot_roc_curve:
            plt.figure()
            lw = 2
            plt.plot(fpr, tpr, color='darkorange',
                     lw=lw, label='ROC curve (area = %0.2f)' % roc_auc)
            if do_print_thresh:
                # Annotate roughly 10 evenly spaced thresholds along the curve
                indices_for_thres = np.arange(0, len(thresholds), max(1, len(thresholds) // 10))
                plt.scatter(fpr[indices_for_thres], tpr[indices_for_thres], color='darkorange')
                for x in indices_for_thres:
                    plt.annotate('{0:.0e}'.format(thresholds[x]), xy=(fpr[x], tpr[x]))
            plt.plot([0, 1], [0, 1], color='navy', lw=lw, linestyle='--')
            plt.xlim([0.0, 1.0])
            plt.ylim([0.0, 1.05])
            plt.xlabel('False Relevant Rate')
            plt.ylabel('True Relevant Rate')
            plt.title('ROC curve')
            plt.legend(loc="lower right")
            plt.show()
    else:
        roc_auc_str = 'n/a'

    exp_metrics = {'precision': '{:.3f}'.format(pr_scores[0]),
                   'recall': '{:.3f}'.format(pr_scores[1]),
                   'f1': '{:.3f}'.format(pr_scores[2]),
                   'accuracy': '{:.2f}'.format(acc),
                   'relevant_acc': '{:.2f}'.format(relevant_acc),
                   'irrelevant_acc': '{:.2f}'.format(irrelevant_acc),
                   'balanced_acc': '{:.2f}'.format(balanced_acc),
                   "Youden's J": '{:.2f}'.format(j),
                   'roc-area': roc_auc_str,
                   }
    print(' \n'.join(k + ': ' + exp_metrics[k] for k in exp_metrics))
    return y_pred, exp_metrics
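
A minimal usage sketch. The dataset and classifier below are illustrative choices, not part of the original gist; any fitted scikit-learn classifier with a predict (and ideally predict_proba) method should work:

# Example: evaluate a logistic-regression classifier on a synthetic binary task
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

X, y = make_classification(n_samples=1000, n_features=20, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=0)

clf = LogisticRegression(max_iter=1000).fit(X_train, y_train)
# Prints the metric summary, displays the confusion matrix, and plots the ROC curve
y_pred, metrics = check_model_fit(clf, X_test, y_test, do_plot_roc_curve=True)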