Skip to content

Instantly share code, notes, and snippets.

@eddjberry
Last active November 23, 2022 07:32
Show Gist options
  • Save eddjberry/3c1818a780d3cb17390744d6e215ba4d to your computer and use it in GitHub Desktop.
Save eddjberry/3c1818a780d3cb17390744d6e215ba4d to your computer and use it in GitHub Desktop.
Plot a grid of shap.dependence_plots
# Dependencies ----------------------
import math
import shap
import matplotlib.pyplot as plt
# shap_dependence_plot_grid ---------
def shap_dependence_plot_grid(cols,
                              shap_values,
                              X,
                              interaction_index=None,
                              alpha=0.75,
                              xmin='percentile(1)',
                              xmax='percentile(99)',
                              ncols=3):
    '''
    Plot a grid of shap.dependence_plot panels, one panel per entry of cols.

    Parameters
    ----------
    cols: Either a list/array of column names or indices (any iterable is
        accepted; it is materialised into a list once)
    shap_values: Your shap values
    X: The feature data
    interaction_index: Column for interactions or 'auto'
        (see help(shap.dependence_plot) for details); the default None is
        passed straight through to shap.dependence_plot
    alpha: Alpha blending for the points
    xmin: Either None, a value or a string like 'percentile(1)'
    xmax: Either None, a value or a string like 'percentile(99)'
    ncols: Number of columns in the grid (default 3). Each grid column is
        4 inches wide, so the default gives the original 12-inch width.

    Returns
    ----------
    The matplotlib Figure holding a grid with ncols columns and as many
    rows as necessary. Grid cells left over after the last plot are hidden.

    Raises
    ----------
    ValueError: if cols is empty or ncols is smaller than 1

    References
    ----------
    Partially inspired by
    https://stackoverflow.com/a/65710424 (CC BY SA 4.0)
    '''
    # materialise once so len() and iteration work for any iterable
    cols = list(cols)
    if not cols:
        raise ValueError('cols must contain at least one column')
    if ncols < 1:
        raise ValueError(f'ncols must be a positive integer, got {ncols!r}')

    # get the required number of rows
    nrows = math.ceil(len(cols) / ncols)

    # create the grid of subplots; squeeze=False guarantees a 2d array of
    # Axes for every nrows/ncols combination (including 1x1), so .ravel()
    # below always yields a flat array of Axes
    fig, axes = plt.subplots(
        nrows=nrows,
        ncols=ncols,
        figsize=(4 * ncols, 0.75 + 3 * nrows),
        squeeze=False)
    axes = axes.ravel()

    # draw one shap.dependence_plot per column into its own grid cell;
    # show=False so shap does not display/consume the figure itself
    for ax, col in zip(axes, cols):
        shap.dependence_plot(col,
                             shap_values,
                             X,
                             interaction_index=interaction_index,
                             alpha=alpha,
                             xmin=xmin,
                             xmax=xmax,
                             ax=ax,
                             show=False)

    # hide the empty cells left when len(cols) is not a multiple of ncols,
    # otherwise they render as blank framed axes
    for ax in axes[len(cols):]:
        ax.set_visible(False)

    # adjust the layout
    fig.tight_layout(pad=3)
    return fig
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment