@jonathanoheix
Created December 18, 2018 09:53
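# PR curve
# signature comes from the standard library; the sklearn.utils.fixes re-export
# is not available in recent scikit-learn releases
from inspect import signature

import matplotlib.pyplot as plt
from sklearn.metrics import average_precision_score, precision_recall_curve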
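# y_test: true binary labels; y_pred: scores or probabilities for the positive class
# (hard 0/1 predictions would collapse the curve to a single operating point)
average_precision = average_precision_score(y_test, y_pred)
precision, recall, _ = precision_recall_curve(y_test, y_pred)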
# In matplotlib < 1.5, plt.fill_between does not have a 'step' argument
step_kwargs = ({'step': 'post'}
               if 'step' in signature(plt.fill_between).parameters
               else {})
plt.figure(1, figsize=(15, 10))
plt.step(recall, precision, color='b', alpha=0.2,
         where='post')
plt.fill_between(recall, precision, alpha=0.2, color='b', **step_kwargs)
plt.xlabel('Recall')
plt.ylabel('Precision')
plt.ylim([0.0, 1.05])
plt.xlim([0.0, 1.0])
plt.title('2-class Precision-Recall curve: AP={0:0.2f}'.format(average_precision))
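The AP value shown in the title is the step-wise sum of precision weighted by each increase in recall, which is why the curve is drawn with where='post'. A minimal sketch of that relationship, assuming small made-up arrays (y_true_demo and y_score_demo are hypothetical and are not the gist's data):

```python
import numpy as np
from sklearn.metrics import average_precision_score, precision_recall_curve

# Hypothetical toy data: true binary labels and scores for the positive class.
y_true_demo = np.array([0, 0, 1, 1, 0, 1, 1, 0])
y_score_demo = np.array([0.1, 0.4, 0.35, 0.8, 0.2, 0.7, 0.55, 0.6])

precision_demo, recall_demo, _ = precision_recall_curve(y_true_demo, y_score_demo)

# recall is returned in decreasing order, so np.diff(recall) is non-positive;
# the leading minus sign turns each step into a positive recall increment.
ap_manual = -np.sum(np.diff(recall_demo) * precision_demo[:-1])
ap_sklearn = average_precision_score(y_true_demo, y_score_demo)

print("manual AP: {0:0.4f}, sklearn AP: {1:0.4f}".format(ap_manual, ap_sklearn))
```

The two values should agree up to floating-point error, since the sum is taken over the same steps that the plot shades.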