Skip to content

Instantly share code, notes, and snippets.

@vikramsoni2
Created February 23, 2022 02:07
Show Gist options
  • Save vikramsoni2/5a5361ae675854aa355d6634f0ff9251 to your computer and use it in GitHub Desktop.
Save vikramsoni2/5a5361ae675854aa355d6634f0ff9251 to your computer and use it in GitHub Desktop.
bokeh binary confusion matrix
import numpy as np
import pandas as pd
from bokeh.plotting import figure, ColumnDataSource, output_notebook, show
from bokeh.transform import dodge
from sklearn.metrics import confusion_matrix, classification_report, precision_score, recall_score, auc
# Module-level side effect: configure Bokeh so that subsequent show() calls
# render plots inline in the Jupyter notebook.
output_notebook()
def plot_confustion_matrix(y_true, y_pred, cutoff=0.5, normed=False,
                           classes=("Negative", "Positive"),
                           colors=('#fcb471', '#fce3cc', '#ccdaea', '#76a0c9')):
    """Render a 2x2 binary confusion matrix as a Bokeh figure and show it.

    The function name keeps its historical spelling so existing callers
    continue to work.

    Args:
        y_true: Ground-truth binary labels, passed to
            ``sklearn.metrics.confusion_matrix`` unchanged.
        y_pred: Predicted labels or scores. If it holds exactly two distinct
            values it is used as-is; otherwise it is binarized with
            ``y_pred >= cutoff``.
        cutoff: Threshold applied to ``y_pred`` when it is binarized.
        normed: When true, each cell shows its share of all samples as a
            percentage instead of the raw count.
        classes: Two display names, ordered (negative, positive).
        colors: Four fill colors, applied to the TN, FN, FP and TP cells in
            that order.

    Returns:
        None. The figure is displayed through ``bokeh.plotting.show``.

    Raises:
        ValueError: If ``classes`` does not have 2 entries, ``colors`` does
            not have 4 entries, or the computed confusion matrix is not 2x2
            (for example when only one class is present in the data).
    """
    # Copy into lists: tuple defaults avoid shared mutable state, while the
    # code below (and Bokeh's categorical ranges) works with plain lists.
    classes = list(classes)
    colors = list(colors)
    if len(classes) != 2:
        raise ValueError(
            f"classes must contain exactly 2 names, got {len(classes)}")
    if len(colors) != 4:
        raise ValueError(
            f"colors must contain exactly 4 entries, got {len(colors)}")

    # Coerce first so that plain Python lists support the vectorized
    # comparison against the cutoff.
    y_pred = np.asarray(y_pred)
    if len(np.unique(y_pred)) == 2:
        y_pred_bin = y_pred
    else:
        y_pred_bin = y_pred >= cutoff

    cm = confusion_matrix(y_true, y_pred_bin)
    if cm.shape != (2, 2):
        raise ValueError(
            "expected a 2x2 confusion matrix for binary classification, "
            f"got shape {cm.shape}")

    # sklearn lays the matrix out as rows=actual, columns=predicted, so the
    # flattened order is TN, FP, FN, TP.
    tn, fp, fn, tp = cm.ravel()
    negative, positive = classes
    cells = pd.DataFrame({
        "Predicted": [negative, negative, positive, positive],
        "Actual": [negative, positive, negative, positive],
        "value": [tn, fn, fp, tp],
        "colors": colors,
        "label": ['TN', 'FN', 'FP', 'TP'],
    })
    total = cells['value'].sum()
    percent = np.round((cells['value'] / total * 100), decimals=2)
    cells['ratio'] = percent.astype(str) + "%"
    data = ColumnDataSource(cells)

    # `width`/`height` replace `plot_width`/`plot_height`, which were removed
    # in Bokeh 3.0.
    p = figure(width=300, height=230,
               x_axis_location='above', y_axis_location='left',
               x_range=classes, y_range=list(reversed(classes)),
               toolbar_location=None, tools='')

    # Predictions run along the x axis, actual classes along the y axis.
    p.rect(x="Predicted", y="Actual", width=0.95, height=0.95, source=data,
           fill_alpha=0.6, fill_color='colors', line_color='gray')

    text_props = {"source": data, "text_align": "left", "text_baseline": "middle"}
    x = dodge("Predicted", -0.30, range=p.x_range)
    # Small cell tag (TN/FN/FP/TP) above the large count or percentage.
    p.text(x=x, y=dodge("Actual", 0.15, range=p.y_range), text="label",
           text_font_size="8pt", **text_props)
    p.text(x=x, y=dodge("Actual", -0.10, range=p.y_range),
           text="ratio" if normed else "value",
           text_font_size="18pt", **text_props)

    p.outline_line_color = None
    p.grid.grid_line_color = None
    p.axis.axis_line_color = None
    p.axis.major_tick_line_color = None
    p.axis.major_label_standoff = 0

    p.xaxis.axis_label = 'Prediction (Rates)' if normed else 'Prediction'
    p.xaxis.axis_label_text_font_size = "10pt"
    p.xaxis.axis_label_text_font_style = "bold"
    p.xaxis.major_label_text_font_size = "8pt"

    p.yaxis.axis_label = 'Actual'
    p.yaxis.axis_label_text_font_size = "10pt"
    p.yaxis.axis_label_text_font_style = "bold"
    p.yaxis.major_label_text_font_size = "8pt"
    p.yaxis.major_label_orientation = "vertical"

    show(p)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment