@alisatl
Created March 30, 2019 01:18
Handy classification evaluation function
# Evaluation function
import numpy as np
import pandas as pd
from sklearn.metrics import accuracy_score
from sklearn.metrics import confusion_matrix
from sklearn.metrics import precision_recall_fscore_support as pr
from sklearn.metrics import roc_curve, auc
from IPython.display import display
import matplotlib.pyplot as plt
%matplotlib inline
def check_model_fit(clf, X_test, y_test, do_plot_roc_curve=True, do_print_thresh=False):
    # Overall test-set accuracy
    y_pred = clf.predict(X_test)
    acc = accuracy_score(y_test, y_pred, normalize=True) * 100

    # Confusion matrix, displayed as a labelled DataFrame
    cmat = confusion_matrix(y_test, y_pred)
    cols = pd.MultiIndex.from_tuples([('predictions', 0), ('predictions', 1)])
    indx = pd.MultiIndex.from_tuples([('actual', 0), ('actual', 1)])
    display(pd.DataFrame(cmat, columns=cols, index=indx))
    tn, fp, fn, tp = cmat.ravel()

    # Test-set accuracy per class of the target variable
    scores = (cmat.diagonal() / (cmat.sum(axis=1) * 1.0)) * 100
    # specificity:
    irrelevant_acc = scores[0]
    # sensitivity:
    relevant_acc = scores[1]
    balanced_acc = (relevant_acc + irrelevant_acc) / 2.0
    j = relevant_acc + irrelevant_acc - 100.0

    # precision, recall, f1 for the positive class
    pr_scores = pr(y_test, y_pred, average='binary')

    # ROC curve (only if the classifier exposes predicted probabilities)
    if hasattr(clf, 'predict_proba'):
        y_pred_scores = clf.predict_proba(X_test)
        fpr, tpr, thresholds = roc_curve(y_test, y_pred_scores[:, 1])
        roc_auc = auc(fpr, tpr)
        roc_auc_str = '{:.2f}'.format(roc_auc)
        if do_plot_roc_curve:
            plt.figure()
            lw = 2
            plt.plot(fpr, tpr, color='darkorange',
                     lw=lw, label='ROC curve (area = %0.2f)' % roc_auc)
            if do_print_thresh:
                # annotate roughly 10 evenly spaced thresholds along the curve
                step = max(len(thresholds) // 10, 1)
                indices_for_thres = np.arange(0, len(thresholds), step)
                plt.scatter(fpr[indices_for_thres], tpr[indices_for_thres], color='darkorange')
                for x in indices_for_thres:
                    plt.annotate('{0:.0e}'.format(thresholds[x]), xy=(fpr[x], tpr[x]))
            plt.plot([0, 1], [0, 1], color='navy', lw=lw, linestyle='--')
            plt.xlim([0.0, 1.0])
            plt.ylim([0.0, 1.05])
            plt.xlabel('False Positive Rate')
            plt.ylabel('True Positive Rate')
            plt.title('ROC curve')
            plt.legend(loc="lower right")
            plt.show()
    else:
        roc_auc_str = 'n/a'

    # keys: precision, recall, f1, accuracy, relevant_acc, irrelevant_acc, balanced_acc, Youden's J, roc-area
    exp_metrics = {'precision': '{:.3f}'.format(pr_scores[0]),
                   'recall': '{:.3f}'.format(pr_scores[1]),
                   'f1': '{:.3f}'.format(pr_scores[2]),
                   'accuracy': '{:.2f}'.format(acc),
                   'relevant_acc': '{:.2f}'.format(relevant_acc),
                   'irrelevant_acc': '{:.2f}'.format(irrelevant_acc),
                   'balanced_acc': '{:.2f}'.format(balanced_acc),
                   "Youden's J": '{:.2f}'.format(j),
                   'roc-area': roc_auc_str,
                   }
    print('\n'.join(k + ': ' + exp_metrics[k] for k in exp_metrics))
    return y_pred, exp_metrics
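
A minimal usage sketch (the breast-cancer dataset and LogisticRegression below are illustrative stand-ins for any fitted binary classifier and held-out test split):

# Example usage of check_model_fit
from sklearn.datasets import load_breast_cancer
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=0)

clf = LogisticRegression(max_iter=5000).fit(X_train, y_train)
y_pred, metrics = check_model_fit(clf, X_test, y_test, do_plot_roc_curve=True, do_print_thresh=True)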