Last active
September 17, 2019 08:43
-
-
Save amaloney/98aa1567331def3064ff14fdf12e4767 to your computer and use it in GitHub Desktop.
Interactive bokeh plot similar to seaborn's jointplot
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import pandas as pd | |
import statsmodels.api as sm | |
from bokeh.application import Application | |
from bokeh.application.handlers import FunctionHandler | |
from bokeh.io import output_notebook, show | |
from bokeh.layouts import gridplot | |
from bokeh.models import ( | |
ColumnDataSource, CustomJS, HoverTool, LinearAxis, Select | |
) | |
from bokeh.plotting import figure | |
from bokeh.sampledata.iris import flowers | |
# NOTE: Only been tested to run inside a Jupyter notebook.
# Route Bokeh output to the notebook before any figure is shown.
output_notebook()
def joint_figure(df, figure_width=800, figure_height=600,
                 df_description=None):
    """Build an interactive joint plot with marginal density curves.

    A scatter plot of two columns of ``df`` is created together with a
    density curve for each of the two columns. Two ``Select`` widgets let
    the user swap which column is shown on each axis; the swap is done
    client-side through ``CustomJS`` callbacks.

    Parameters
    ----------
    df : pandas.DataFrame
        Data to plot. Every column is cast to ``float`` and at least two
        columns are required (the first two, in sorted order, are shown
        initially).
    figure_width, figure_height : int
        Overall size in pixels. The joint plot takes 80% of each dimension
        and the marginal plots take the remainder.
    df_description : pandas.Series, optional
        Extra column shown in the hover tooltip. Its ``name`` is used as the
        tooltip label. When ``None`` or empty, no hover tool is attached.
        NOTE(review): presumably has one entry per row of ``df`` -- confirm.

    Returns
    -------
    tuple
        ``(x_feature_selector, y_feature_selector, x_feature_plot,
        joint_plot, y_feature_plot)``.
    """
    # Size the figure.
    joint_plot_width = int(round(figure_width * 0.8, 0))
    joint_plot_height = int(round(figure_height * 0.8, 0))
    x_feature_plot_width = joint_plot_width
    x_feature_plot_height = figure_height - joint_plot_height
    y_feature_plot_width = figure_width - joint_plot_width
    y_feature_plot_height = joint_plot_height

    # Determine the data sources for each plot.
    features = sorted(df.columns)
    x_feature_selection = features[0]
    y_feature_selection = features[1]
    feature_data = {}
    for feature in features:
        data = df[feature].dropna().astype(float).values
        kde = sm.nonparametric.KDEUnivariate(data)
        kde.fit()
        counts, _ = np.histogram(data)
        # Rescale the density so its peak matches the tallest histogram bin.
        pdf = (kde.density / kde.density.max()) * max(counts)
        feature_data['{}_support'.format(feature)] = kde.support
        feature_data['{}_density'.format(feature)] = pdf
    x_feature_source = ColumnDataSource(feature_data)
    y_feature_source = ColumnDataSource(feature_data)
    joint_source = ColumnDataSource(df[features])

    # Create a hover tool. It is only attached to the joint plot when a
    # non-empty description is supplied; otherwise it is just a placeholder
    # handed to the callbacks.
    tools = ['pan,wheel_zoom,box_zoom,reset']
    hover = HoverTool(tooltips=[])
    if df_description is not None and not df_description.empty:
        description_name = df_description.name
        joint_source.data[description_name] = df_description.tolist()
        hover = HoverTool(tooltips=[
            (y_feature_selection, '@{}'.format(y_feature_selection)),
            (x_feature_selection, '@{}'.format(x_feature_selection)),
            (description_name, '@{}'.format(description_name)),
        ])
        tools = tools + [hover]

    # Create the x_feature plot and style it.
    x_feature_plot = figure(
        plot_width=x_feature_plot_width,
        # The top marginal uses its own (short) height, not the joint height.
        plot_height=x_feature_plot_height
    )
    x_feature_handle = x_feature_plot.line(
        x='{}_support'.format(x_feature_selection),
        y='{}_density'.format(x_feature_selection),
        source=x_feature_source
    )
    x_feature_plot.axis.visible = False
    x_feature_plot.grid.visible = False
    x_feature_plot.outline_line_color = None

    # Create the y_feature plot and style it.
    y_feature_plot = figure(
        plot_width=y_feature_plot_width,
        plot_height=y_feature_plot_height,
    )
    y_feature_handle = y_feature_plot.line(
        x='{}_density'.format(y_feature_selection),
        y='{}_support'.format(y_feature_selection),
        source=y_feature_source
    )
    y_feature_plot.axis.visible = False
    y_feature_plot.grid.visible = False
    y_feature_plot.outline_line_color = None

    # Create the joint plot and link it to the feature plots by sharing
    # their range objects.
    joint_plot = figure(
        x_axis_label=x_feature_selection,
        y_axis_label=y_feature_selection,
        plot_width=joint_plot_width,
        plot_height=joint_plot_height,
        x_range=x_feature_plot.x_range,
        y_range=y_feature_plot.y_range,
        tools=tools
    )
    joint_handle = joint_plot.circle(
        x=x_feature_selection,
        y=y_feature_selection,
        source=joint_source,
        size=10,
        line_color='White'
    )

    # Styles the joint plot so it has axes all around it.
    joint_plot.extra_x_ranges = {'top_joint_axis': joint_plot.x_range}
    joint_plot.add_layout(LinearAxis(x_range_name='top_joint_axis'), 'above')
    joint_plot.extra_y_ranges = {'right_joint_axis': joint_plot.y_range}
    joint_plot.add_layout(LinearAxis(y_range_name='right_joint_axis'), 'right')

    # JavaScript shared by both callbacks: reset the ranges and refresh the
    # hover tooltip. It expects x_feature_selection and y_feature_selection
    # to have been defined by the axis-specific code that precedes it.
    shared_callback_code = """
    // FIXME This doesn't work very well.
    // Reset the plot if it has been panned.
    // Padding is applied away from the data so negative values stay in view.
    var delta = 0.3;
    var joint_data = joint_source.data;
    var joint_x_data = joint_data[x_feature_selection];
    var joint_y_data = joint_data[y_feature_selection];
    var joint_x_min = Math.min(...joint_x_data);
    var joint_x_max = Math.max(...joint_x_data);
    var joint_y_min = Math.min(...joint_y_data);
    var joint_y_max = Math.max(...joint_y_data);
    joint_plot_xrange.start = joint_x_min - delta*Math.abs(joint_x_min);
    joint_plot_xrange.end = joint_x_max + delta*Math.abs(joint_x_max);
    joint_plot_yrange.start = joint_y_min - delta*Math.abs(joint_y_min);
    joint_plot_yrange.end = joint_y_max + delta*Math.abs(joint_y_max);
    // Update the hover tooltip.
    var desc = hover.attributes.tooltips[hover.attributes.tooltips.length - 1];
    hover.attributes.tooltips[0] = [y_feature_selection, '@'+y_feature_selection];
    hover.attributes.tooltips[1] = [x_feature_selection, '@'+x_feature_selection];
    hover.attributes.tooltips[2] = desc;
    """

    # x and y axis callback code separated for clarity.
    x_feature_callback_code = """
    // Define the columns of the joint plot.
    var x_feature_selection = cb_obj.value;
    var y_feature_selection = joint_handle.glyph.y.field;
    // Update the scatter data with the new selection.
    joint_handle.glyph.x.field = x_feature_selection;
    joint_source.change.emit();
    // Update the plot axis labels.
    joint_plot.attributes.below[0].attributes.axis_label = x_feature_selection;
    joint_plot.change.emit();
    // Update the marginal distribution.
    x_feature_handle.glyph.x.field = x_feature_selection + '_support';
    x_feature_handle.glyph.y.field = x_feature_selection + '_density';
    x_feature_source.change.emit();
    """ + shared_callback_code
    y_feature_callback_code = """
    // Define the columns of the joint plot.
    var x_feature_selection = joint_handle.glyph.x.field;
    var y_feature_selection = cb_obj.value;
    // Update the scatter data with the new selection.
    joint_handle.glyph.y.field = y_feature_selection;
    joint_source.change.emit();
    // Update the plot axis labels.
    joint_plot.attributes.left[0].attributes.axis_label = y_feature_selection;
    joint_plot.change.emit();
    // Update the marginal distribution.
    y_feature_handle.glyph.x.field = y_feature_selection + '_density';
    y_feature_handle.glyph.y.field = y_feature_selection + '_support';
    y_feature_source.change.emit();
    """ + shared_callback_code

    # Link the callback code to changes in the Select widget.
    x_feature_callback = CustomJS(
        args=dict(
            joint_handle=joint_handle,
            joint_plot=joint_plot,
            joint_source=joint_source,
            joint_plot_xrange=joint_plot.x_range,
            joint_plot_yrange=joint_plot.y_range,
            x_feature_handle=x_feature_handle,
            x_feature_source=x_feature_source,
            hover=hover
        ),
        code=x_feature_callback_code
    )
    y_feature_callback = CustomJS(
        args=dict(
            joint_handle=joint_handle,
            joint_plot=joint_plot,
            joint_source=joint_source,
            joint_plot_xrange=joint_plot.x_range,
            joint_plot_yrange=joint_plot.y_range,
            y_feature_handle=y_feature_handle,
            y_feature_source=y_feature_source,
            hover=hover
        ),
        code=y_feature_callback_code
    )

    # Create the selector widgets.
    # NOTE(review): the `callback=` keyword appears to be deprecated in
    # later Bokeh releases in favour of `js_on_change` -- confirm against
    # the installed version before upgrading.
    x_feature_selector = Select(
        title='x-feature',
        options=features,
        value=features[0],
        callback=x_feature_callback
    )
    y_feature_selector = Select(
        title='y-feature',
        options=features,
        value=features[1],
        callback=y_feature_callback
    )

    return (
        x_feature_selector,
        y_feature_selector,
        x_feature_plot,
        joint_plot,
        y_feature_plot
    )
# Create the plot and show it.
# joint_figure casts every column to float, so the 'species' label column is
# left out of the plotted features and passed separately as the description.
features = [column for column in flowers.columns if column != 'species']
df = flowers[features].copy()
def modify_doc(doc):
    """Build the joint-plot layout and add it to the Bokeh document ``doc``.

    The selectors go on the first row, the x marginal on the second, and the
    joint plot with the y marginal on the third.
    """
    (x_feature_selector, y_feature_selector,
     x_feature_plot, joint_plot, y_feature_plot) = joint_figure(
        df, df_description=flowers['species']
    )
    # Pass the rows as ONE list of lists: extra positional arguments would
    # otherwise be read as other gridplot parameters rather than as rows.
    layout = gridplot([
        [y_feature_selector, x_feature_selector],
        [x_feature_plot],
        [joint_plot, y_feature_plot],
    ])
    doc.add_root(layout)
# Wrap modify_doc in a Bokeh application so the plot is served with a live
# document inside the notebook.
handler = FunctionHandler(modify_doc)
app = Application(handler)
# NOTE(review): notebook_url is hard-coded; presumably it must match the
# address of the running Jupyter server -- confirm for other setups.
show(app, notebook_url='localhost:8888')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
An example of the interactivity using Bokeh's Select widget and tooltips in the Jupyter notebook is shown below.