Skip to content

Instantly share code, notes, and snippets.

@dipanjanS
Last active August 15, 2019 07:41
Show Gist options
  • Save dipanjanS/ab4bd1eddea9e400a79c96f455338947 to your computer and use it in GitHub Desktop.
Save dipanjanS/ab4bd1eddea9e400a79c96f455338947 to your computer and use it in GitHub Desktop.
import json

import keras
import keras.backend as K
import matplotlib.pyplot as plt
import numpy as np
import shap
from keras.applications.vgg16 import VGG16
from keras.applications.vgg16 import preprocess_input, decode_predictions
from matplotlib.colors import LinearSegmentedColormap
# Module-level side effect: set up SHAP's JavaScript support for notebook
# output. NOTE(review): presumably only required for SHAP's interactive plots,
# not for the matplotlib-based plot defined below -- confirm against shap docs.
shap.initjs()
# utility function to visualize SHAP values in larger image formats
# this modifies the `shap.image_plot(...)` function
def visualize_model_decisions(shap_values, x, labels=None, figsize=(20, 30)):
    """Plot input images next to their per-output SHAP value heat maps.

    A larger-format variant of ``shap.image_plot``. One row is drawn per image
    in ``x``: the first column shows the image itself, and every following
    column shows the SHAP values of one model output drawn on top of a faint
    grayscale copy of that image. All heat maps share one symmetric color
    scale and a single horizontal colorbar.

    Args:
        shap_values: Array of SHAP values, or a list of such arrays (one per
            model output). The first axis of each array indexes the sample.
            Per-sample values that are not already 2-D are summed over their
            last axis before plotting.
        x: Array of images; ``x.shape[0]`` rows are plotted. Images with a
            trailing axis of size 1 are squeezed to 2-D, and images whose
            maximum exceeds 1 are divided by 255 for display.
        labels: Optional panel titles. For a list of ``shap_values`` this must
            be 2-D with one row per sample and one column per output; for a
            single array it must be a 1-D vector with one entry per sample.
        figsize: Figure size forwarded to ``plt.subplots``; its first entry
            also sets the colorbar aspect ratio.

    Returns:
        None. The figure is created through ``plt.subplots`` and left to the
        active matplotlib backend.

    Raises:
        AssertionError: If ``labels`` does not match the shape of
            ``shap_values`` as described above.
    """
    # Blue -> transparent -> red. Alpha fades to zero in the middle of the
    # map so the underlying image shows through where SHAP values are near 0.
    blue_ramp = [(30. / 255, 136. / 255, 229. / 255, alpha)
                 for alpha in np.linspace(1, 0, 100)]
    red_ramp = [(255. / 255, 13. / 255, 87. / 255, alpha)
                for alpha in np.linspace(0, 1, 100)]
    red_transparent_blue = LinearSegmentedColormap.from_list(
        "red_transparent_blue", blue_ramp + red_ramp)

    multi_output = isinstance(shap_values, list)
    if not multi_output:
        shap_values = [shap_values]
    shap_values = [np.asarray(values) for values in shap_values]

    # Validate labels. Assertions (rather than raised ValueErrors) are kept so
    # callers that relied on AssertionError keep working.
    if labels is not None:
        labels = np.asarray(labels)
        assert labels.shape[0] == shap_values[0].shape[0], "Labels must have same row count as shap_values arrays!"
        if multi_output:
            assert labels.ndim == 2 and labels.shape[1] == len(shap_values), "Labels must have a column for each output in shap_values!"
        else:
            assert len(labels.shape) == 1, "Labels must be a vector for single output shap_values."
            # Promote the vector to a single column so that the
            # labels[row, column] lookup below works for both cases.
            labels = labels.reshape(-1, 1)

    # Collapse any trailing (channel) axis once, up front: each entry becomes
    # (samples, height, width).
    collapsed = [values if values.ndim == 3 else values.sum(-1)
                 for values in shap_values]

    # Symmetric color limit shared by every panel. It depends on all SHAP
    # values, not on the current row, so it is computed a single time here.
    abs_vals = np.abs(np.stack(collapsed, 0)).flatten()
    max_val = np.nanpercentile(abs_vals, 99.9)

    # squeeze=False always yields a 2-D axes array, even for a single row.
    n_rows = x.shape[0]
    fig, axes = plt.subplots(nrows=n_rows, ncols=len(collapsed) + 1,
                             figsize=figsize, squeeze=False)

    for row in range(n_rows):
        image = x[row]
        # Drop a trailing singleton channel so imshow sees a 2-D image.
        if image.ndim == 3 and image.shape[2] == 1:
            image = image.reshape(image.shape[:2])
        # Out-of-place division: an in-place `/=` raises a casting error for
        # integer images (e.g. uint8), and this also leaves `x` untouched.
        if image.max() > 1:
            image = image / 255.
        axes[row, 0].imshow(image)
        axes[row, 0].axis('off')

        # Grayscale copy used as a faint backdrop under each heat map.
        if image.ndim == 3 and image.shape[2] == 3:
            gray = (0.2989 * image[:, :, 0]
                    + 0.5870 * image[:, :, 1]
                    + 0.1140 * image[:, :, 2])  # rgb to gray
        else:
            gray = image

        for col, values in enumerate(collapsed, start=1):
            ax = axes[row, col]
            if labels is not None:
                ax.set_title(labels[row, col - 1])
            sv = values[row]
            # extent is (left, right, bottom, top): the horizontal span is the
            # column count (shape[1]) and the vertical span the row count
            # (shape[0]), so non-square images line up with the heat map.
            ax.imshow(gray, cmap='gray', alpha=0.15,
                      extent=(-1, sv.shape[1], sv.shape[0], -1))
            im = ax.imshow(sv, cmap=red_transparent_blue,
                           vmin=-max_val, vmax=max_val)
            ax.axis('off')

    # One colorbar for the whole grid; every heat map shares the same limits,
    # so the last image handle is representative.
    cb = fig.colorbar(im, ax=np.ravel(axes).tolist(), label="SHAP value",
                      orientation="horizontal", aspect=figsize[0] / 0.2)
    cb.outline.set_visible(False)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment