Skip to content

Instantly share code, notes, and snippets.

@lukasz-migas
Forked from amaloney/jointplot.py
Created September 17, 2019 08:43
Show Gist options
  • Save lukasz-migas/ec9a34584f61878f2a89868b3f5daf0a to your computer and use it in GitHub Desktop.
Save lukasz-migas/ec9a34584f61878f2a89868b3f5daf0a to your computer and use it in GitHub Desktop.
Interactive Bokeh plot similar to seaborn's jointplot.
import numpy as np
import pandas as pd
import statsmodels.api as sm
from bokeh.application import Application
from bokeh.application.handlers import FunctionHandler
from bokeh.io import output_notebook, show
from bokeh.layouts import gridplot
from bokeh.models import (
ColumnDataSource, CustomJS, HoverTool, LinearAxis, Select
)
from bokeh.plotting import figure
from bokeh.sampledata.iris import flowers
# NOTE Only been tested to run inside a Jupyter notebook.
# Configure Bokeh to display output inline in the notebook; the application
# built at the bottom of this file is shown with show(app, notebook_url=...).
output_notebook()
# JavaScript run when one of the feature Select widgets changes. The
# placeholders ($axis, $label_side, $support_dim, $density_dim) are filled in
# by _selector_callback_code so the x and y callbacks share one definition.
_SELECTOR_CALLBACK_TEMPLATE = """
// The feature just chosen in the $axis-feature Select widget.
var feature = cb_obj.value;

// Point the scatter glyph and the axis label at the new column. Whole
// property values are assigned (instead of mutating them in place) so that
// BokehJS emits its own change events.
joint_handle.glyph.$axis = {field: feature};
joint_plot.$label_side[0].axis_label = feature;
joint_source.change.emit();

// Point the marginal KDE line at the new feature's support/density columns.
marginal_handle.glyph.$support_dim = {field: feature + '_support'};
marginal_handle.glyph.$density_dim = {field: feature + '_density'};
marginal_source.change.emit();

var x_feature = joint_handle.glyph.x.field;
var y_feature = joint_handle.glyph.y.field;

// Reset both ranges around the plotted data (undoes any panning). Padding is
// a fraction of the data span, so negative values are handled too; NaN
// entries are skipped because comparisons against NaN are always false.
var delta = 0.3;
function padded_bounds(values) {
    var lo = Infinity;
    var hi = -Infinity;
    for (var i = 0; i < values.length; i++) {
        var v = values[i];
        if (v < lo) { lo = v; }
        if (v > hi) { hi = v; }
    }
    if (!isFinite(lo) || !isFinite(hi)) { return null; }
    var pad = delta * ((hi - lo) || Math.abs(hi) || 1);
    return [lo - pad, hi + pad];
}
var x_bounds = padded_bounds(joint_source.data[x_feature]);
if (x_bounds !== null) {
    joint_plot_xrange.start = x_bounds[0];
    joint_plot_xrange.end = x_bounds[1];
}
var y_bounds = padded_bounds(joint_source.data[y_feature]);
if (y_bounds !== null) {
    joint_plot_yrange.start = y_bounds[0];
    joint_plot_yrange.end = y_bounds[1];
}

// Refresh the hover tooltip, keeping any rows after the two feature rows
// (the optional description). An empty tooltip list is the unattached
// placeholder hover, so it is left untouched.
var tips = hover.tooltips;
if (Array.isArray(tips) && tips.length > 0) {
    hover.tooltips = [
        [y_feature, '@{' + y_feature + '}'],
        [x_feature, '@{' + x_feature + '}']
    ].concat(tips.slice(2));
}
"""


def _panel_sizes(figure_width, figure_height, joint_fraction=0.8):
    """Split the overall figure size between the joint and marginal plots.

    The joint plot gets ``joint_fraction`` of each dimension; the top (x)
    marginal plot shares its width and takes the leftover height, and the
    right (y) marginal plot shares its height and takes the leftover width.

    Returns:
        ``(joint_width, joint_height, x_feature_width, x_feature_height,
        y_feature_width, y_feature_height)`` as ints.
    """
    joint_width = int(round(figure_width * joint_fraction))
    joint_height = int(round(figure_height * joint_fraction))
    return (
        joint_width, joint_height,
        joint_width, figure_height - joint_height,
        figure_width - joint_width, joint_height,
    )


def _kde_columns(df, features):
    """Build the column dict that feeds the marginal KDE lines.

    For every feature, ``'<feature>_support'`` holds the KDE evaluation grid
    and ``'<feature>_density'`` the density rescaled so its peak equals the
    tallest bin of a default ``np.histogram`` of the same (NaN-dropped) data
    -- presumably so the curve reads on a count-like scale.
    """
    columns = {}
    for feature in features:
        data = df[feature].dropna().astype(float).values
        kde = sm.nonparametric.KDEUnivariate(data)
        kde.fit()
        counts, _ = np.histogram(data)
        columns['{}_support'.format(feature)] = kde.support
        columns['{}_density'.format(feature)] = (
            (kde.density / kde.density.max()) * counts.max()
        )
    return columns


def _selector_callback_code(axis):
    """Return the CustomJS source for the ``axis`` ('x' or 'y') Select widget.

    The generated code expects the CustomJS args ``joint_handle``,
    ``joint_plot``, ``joint_source``, ``joint_plot_xrange``,
    ``joint_plot_yrange``, ``marginal_handle``, ``marginal_source`` and
    ``hover``.

    Raises:
        ValueError: If ``axis`` is not 'x' or 'y'.
    """
    from string import Template

    if axis == 'x':
        # The x marginal plot draws support horizontally, density vertically.
        label_side, support_dim, density_dim = 'below', 'x', 'y'
    elif axis == 'y':
        # The y marginal plot is rotated: density horizontal, support vertical.
        label_side, support_dim, density_dim = 'left', 'y', 'x'
    else:
        raise ValueError("axis must be 'x' or 'y', got {!r}".format(axis))
    return Template(_SELECTOR_CALLBACK_TEMPLATE).substitute(
        axis=axis,
        label_side=label_side,
        support_dim=support_dim,
        density_dim=density_dim,
    )


def joint_figure(df, figure_width=800, figure_height=600,
                 df_description=None):
    """Build the widgets and figures for an interactive seaborn-style jointplot.

    Parameters
    ----------
    df : pandas.DataFrame
        One numeric feature per column. Columns are offered in sorted order
        and the first two sorted names are shown initially.
    figure_width, figure_height : int
        Overall size in pixels. The joint plot takes 80% of each dimension;
        the marginal plots take the remainder.
    df_description : pandas.Series, optional
        Per-row labels shown in the hover tooltip, attached by position (one
        entry per row of ``df``). When omitted or empty no hover tool is added.

    Returns
    -------
    tuple
        ``(x_feature_selector, y_feature_selector, x_feature_plot,
        joint_plot, y_feature_plot)``.

    Raises
    ------
    ValueError
        If ``df`` has fewer than two columns, or ``df_description`` does not
        have the same length as ``df``.
    """
    features = sorted(df.columns)
    if len(features) < 2:
        raise ValueError(
            'joint_figure needs at least two columns, got {}'.format(
                len(features)))
    x_feature_selection, y_feature_selection = features[0], features[1]

    (joint_plot_width, joint_plot_height,
     x_feature_plot_width, x_feature_plot_height,
     y_feature_plot_width, y_feature_plot_height) = _panel_sizes(
        figure_width, figure_height)

    # Determine the data sources for each plot.
    feature_data = _kde_columns(df, features)
    x_feature_source = ColumnDataSource(feature_data)
    y_feature_source = ColumnDataSource(feature_data)
    joint_source = ColumnDataSource(df[features])

    # Tools for the joint plot. The placeholder hover is never attached to a
    # plot, but the JS callbacks still receive it (they skip empty tooltips).
    tools = ['pan,wheel_zoom,box_zoom,reset']
    hover = HoverTool(tooltips=[])
    if df_description is not None and not df_description.empty:
        if len(df_description) != len(df):
            raise ValueError(
                'df_description has {} entries but df has {} rows'.format(
                    len(df_description), len(df)))
        description_name = df_description.name
        description_name = (
            'description' if description_name is None
            else str(description_name)
        )
        joint_source.data[description_name] = df_description.tolist()
        # '@{name}' (braced) also works for column names containing spaces.
        hover = HoverTool(tooltips=[
            (y_feature_selection, '@{' + y_feature_selection + '}'),
            (x_feature_selection, '@{' + x_feature_selection + '}'),
            (description_name, '@{' + description_name + '}'),
        ])
        tools = tools + [hover]

    # Top marginal plot: KDE of the x feature.
    x_feature_plot = figure(
        width=x_feature_plot_width,
        height=x_feature_plot_height,
    )
    x_feature_handle = x_feature_plot.line(
        x='{}_support'.format(x_feature_selection),
        y='{}_density'.format(x_feature_selection),
        source=x_feature_source,
    )
    x_feature_plot.axis.visible = False
    x_feature_plot.grid.visible = False
    x_feature_plot.outline_line_color = None

    # Right marginal plot: KDE of the y feature, drawn rotated.
    y_feature_plot = figure(
        width=y_feature_plot_width,
        height=y_feature_plot_height,
    )
    y_feature_handle = y_feature_plot.line(
        x='{}_density'.format(y_feature_selection),
        y='{}_support'.format(y_feature_selection),
        source=y_feature_source,
    )
    y_feature_plot.axis.visible = False
    y_feature_plot.grid.visible = False
    y_feature_plot.outline_line_color = None

    # Joint scatter plot; sharing ranges keeps the marginals aligned with it.
    joint_plot = figure(
        x_axis_label=x_feature_selection,
        y_axis_label=y_feature_selection,
        width=joint_plot_width,
        height=joint_plot_height,
        x_range=x_feature_plot.x_range,
        y_range=y_feature_plot.y_range,
        tools=tools,
    )
    joint_handle = joint_plot.circle(
        x=x_feature_selection,
        y=y_feature_selection,
        source=joint_source,
        size=10,
        line_color='White',
    )
    # Give the joint plot axes on all four sides.
    joint_plot.extra_x_ranges = {'top_joint_axis': joint_plot.x_range}
    joint_plot.add_layout(LinearAxis(x_range_name='top_joint_axis'), 'above')
    joint_plot.extra_y_ranges = {'right_joint_axis': joint_plot.y_range}
    joint_plot.add_layout(LinearAxis(y_range_name='right_joint_axis'), 'right')

    # Link the Select widgets to the plots through CustomJS callbacks.
    shared_args = dict(
        joint_handle=joint_handle,
        joint_plot=joint_plot,
        joint_source=joint_source,
        joint_plot_xrange=joint_plot.x_range,
        joint_plot_yrange=joint_plot.y_range,
        hover=hover,
    )
    x_feature_callback = CustomJS(
        args=dict(shared_args, marginal_handle=x_feature_handle,
                  marginal_source=x_feature_source),
        code=_selector_callback_code('x'),
    )
    y_feature_callback = CustomJS(
        args=dict(shared_args, marginal_handle=y_feature_handle,
                  marginal_source=y_feature_source),
        code=_selector_callback_code('y'),
    )

    # js_on_change works on every Bokeh release since 1.x; the old
    # Select(callback=...) keyword was removed in Bokeh 2.0.
    x_feature_selector = Select(
        title='x-feature',
        options=features,
        value=x_feature_selection,
    )
    x_feature_selector.js_on_change('value', x_feature_callback)
    y_feature_selector = Select(
        title='y-feature',
        options=features,
        value=y_feature_selection,
    )
    y_feature_selector.js_on_change('value', y_feature_callback)

    return (
        x_feature_selector,
        y_feature_selector,
        x_feature_plot,
        joint_plot,
        y_feature_plot,
    )
# Create the plot and show it.
# Every iris column except the categorical 'species' is a numeric feature;
# 'species' is shown in the hover tooltip instead.
features = [column for column in flowers.columns if column != 'species']
df = flowers[features].copy()


def modify_doc(doc):
    """Add the iris joint-plot layout to ``doc`` (the FunctionHandler hook)."""
    (x_feature_selector, y_feature_selector,
     x_feature_plot, joint_plot, y_feature_plot) = joint_figure(
        df, df_description=flowers['species'])
    # gridplot takes ONE list of rows as its positional argument; passing the
    # rows as separate positional arguments would not form a grid.
    layout = gridplot([
        [y_feature_selector, x_feature_selector],
        [x_feature_plot, None],
        [joint_plot, y_feature_plot],
    ])
    doc.add_root(layout)


handler = FunctionHandler(modify_doc)
app = Application(handler)
show(app, notebook_url='localhost:8888')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment