Created August 11, 2014 02:43
-
-
Save code-of-kpp/c1d1c9394335d86255b8 to your computer and use it in GitHub Desktop.
Python pyplot receiver operating characteristic (ROC) curve with colorbar
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numbers

import matplotlib.collections
import numpy
import six
from matplotlib import pyplot
# Based on the example from
# http://nbviewer.ipython.org/github/dpsanders/matplotlib-examples/blob/master/colorline.ipynb
def make_segments(x, y):
    '''
    Pair consecutive points into the segment array LineCollection expects.

    Point k is (x[k], y[k]); segment k joins point k to point k + 1, so n
    points produce n - 1 segments.

    Returns a numpy array shaped (number of segments) x 2 (endpoints)
    x 2 (x and y coordinates).
    '''
    # One (x, y) row per point, in input order.
    pts = numpy.array([x, y]).T.reshape(-1, 2)
    # Stack each start point with the point that follows it.
    return numpy.stack((pts[:-1], pts[1:]), axis=1)
# Sentinel for colorline's ``norm`` default. It keeps an explicit
# ``norm=None`` (let matplotlib pick its own scaling) distinguishable from
# "argument omitted".
_DEFAULT_NORM = object()


def colorline(x, y, z=None, axes=None,
              cmap=pyplot.get_cmap('coolwarm'),
              norm=_DEFAULT_NORM, linewidth=3, alpha=1.0,
              **kwargs):
    '''
    Plot a line through (x, y) whose segments are colored by z.

    Parameters
    ----------
    x, y : array-like
        Point coordinates; consecutive points are joined by segments
        (see make_segments).
    z : array-like or real number, optional
        Values mapped through ``norm`` and ``cmap`` to segment colors.
        Defaults to len(x) values equally spaced on [0, 1]. A single real
        number is wrapped into a one-element array.
    axes : matplotlib Axes, optional
        Axes to draw on; defaults to ``pyplot.gca()``.
    cmap : Colormap, optional
        Defaults to the 'coolwarm' colormap.
    norm : Normalize or None, optional
        Scales z for the colormap. When omitted, a fresh
        ``pyplot.Normalize(0.0, 1.0)`` is created on every call.
    linewidth, alpha : float
        Forwarded to LineCollection.
    **kwargs
        Any other LineCollection keyword arguments.

    Returns
    -------
    matplotlib.collections.LineCollection
        The collection that was added to ``axes``.
    '''
    if norm is _DEFAULT_NORM:
        # Build a new instance per call. A single Normalize shared by every
        # call (as a default-argument value would be) lets one line's
        # ``lc.set_clim(...)`` rescale every other default-norm line.
        norm = pyplot.Normalize(0.0, 1.0)
    # Default colors equally spaced on [0, 1].
    if z is None:
        z = numpy.linspace(0.0, 1.0, len(x))
    # Special case if a single number: wrap it so it is array-like.
    if isinstance(z, numbers.Real):
        z = numpy.array([z])
    z = numpy.asarray(z)
    # NOTE(review): by default z has len(x) values, but there are only
    # len(x) - 1 segments. Presumably the last value goes unused; confirm.
    segments = make_segments(x, y)
    lc = matplotlib.collections.LineCollection(
        segments, array=z, cmap=cmap, norm=norm,
        linewidth=linewidth, alpha=alpha, **kwargs
    )
    if axes is None:
        axes = pyplot.gca()
    axes.add_collection(lc)
    # add_collection does not update the data limits shown, so rescale the
    # view to include the new line.
    axes.autoscale()
    return lc
def _annotate_thresholds(axes, fpr, tpr, thresholds, label_every,
                         label_kwargs):
    '''Write every label_every-th threshold (rounded to 2 places) at its ROC point.'''
    # Work on a copy so the default bbox is never written into the
    # caller's dict.
    label_kwargs = dict(label_kwargs or {})
    label_kwargs.setdefault('bbox', dict(
        boxstyle='round,pad=0.5', fc='yellow', alpha=0.5,
    ))
    for k, (x, y, threshold) in enumerate(zip(fpr, tpr, thresholds)):
        if k % label_every != 0:
            continue
        axes.annotate(str(numpy.round(threshold, 2)), (x, y), **label_kwargs)


def plot_roc(tpr, fpr, thresholds, subplots_kwargs=None,
             label_every=None, label_kwargs=None,
             fpr_label='False Positive Rate',
             tpr_label='True Positive Rate',
             luck_label='Luck',
             title='Receiver operating characteristic',
             **kwargs):
    '''
    Plot an ROC curve colored by decision threshold, with a colorbar.

    Note the argument order: tpr comes before fpr, whereas
    sklearn.metrics.roc_curve returns (fpr, tpr, thresholds).

    Parameters
    ----------
    tpr, fpr, thresholds : sequences of equal length
        Point k of the curve is (fpr[k], tpr[k]). It is colored by, and
        optionally annotated with, thresholds[k]. The colors come from
        colorline's default scaling; presumably thresholds outside [0, 1]
        saturate the colormap, so confirm if that matters.
    subplots_kwargs : dict, optional
        Forwarded to pyplot.subplots when creating the figure.
    label_every : int, optional
        If given, annotate every label_every-th point (k % label_every == 0)
        with its threshold rounded to 2 decimals.
    label_kwargs : dict, optional
        Forwarded to axes.annotate. A rounded, semi-transparent yellow bbox
        is used unless 'bbox' is supplied. The caller's dict is not modified.
    fpr_label, tpr_label, title : str
        x-axis label, y-axis label and axes title.
    luck_label : str or None
        Legend label of the dashed chance diagonal; None omits the diagonal.
    **kwargs
        Forwarded to axes.plot for the plain curve. The line width defaults
        to 1 unless 'lw' or 'linewidth' is supplied.

    Returns
    -------
    (figure, axes)
        The newly created figure and its single axes.
    '''
    figure, axes = pyplot.subplots(1, 1, **(subplots_kwargs or {}))
    # 'lw' and 'linewidth' are aliases, and matplotlib rejects calls that
    # pass both. Only supply the default when the caller set neither.
    if 'lw' not in kwargs and 'linewidth' not in kwargs:
        kwargs['lw'] = 1
    axes.plot(fpr, tpr, **kwargs)
    if label_every is not None:
        _annotate_thresholds(axes, fpr, tpr, thresholds, label_every,
                             label_kwargs)
    if luck_label is not None:
        axes.plot((0, 1), (0, 1), '--', color='Gray', label=luck_label)
    lc = colorline(fpr, tpr, thresholds, axes=axes)
    figure.colorbar(lc)
    axes.set_xlim([-0.05, 1.05])
    axes.set_ylim([-0.05, 1.05])
    axes.set_xlabel(fpr_label)
    axes.set_ylabel(tpr_label)
    axes.set_title(title)
    # With no labelled artists (luck_label=None and no 'label' kwarg),
    # legend() would only warn and draw an empty box, so skip it.
    if axes.get_legend_handles_labels()[1]:
        axes.legend(loc="lower right")
    return figure, axes
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.
I feel this code implies some misconception about what the `thresholds` are, at least as they are returned by `sklearn` v1.0.2. This gist plots the text annotations (showing thresholds) on the corners, while they should be printed on the lines connecting the corners. Corner annotations are also possible, but these should be intervals — namely, between the values of the two inbound lines. In a perfect ROC curve (assuming 1 is the score of one group and 0 the score of the other group), one could have `(+∞, 1)` in the bottom-left corner; `1` on the left line; `(1, 0)` in the top-left corner; `0` on the top line; and `(0, -∞)` in the top-right corner.