Last active
August 15, 2019 07:41
-
-
Save dipanjanS/ab4bd1eddea9e400a79c96f455338947 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import json

import keras
import keras.backend as K
import matplotlib.pyplot as plt
import numpy as np
import shap
from keras.applications.vgg16 import VGG16
from keras.applications.vgg16 import preprocess_input, decode_predictions
from matplotlib.colors import LinearSegmentedColormap
# Load SHAP's JavaScript visualisation support so SHAP's own interactive
# plots can render in a notebook; the matplotlib-based helper defined in this
# file does not depend on it.
shap.initjs()
def _red_transparent_blue():
    """Build the diverging colormap used for the SHAP heatmaps.

    The low end is opaque blue (RGB 30, 136, 229), the midpoint is fully
    transparent, and the high end is opaque red (RGB 255, 13, 87).
    """
    blue = (30. / 255, 136. / 255, 229. / 255)
    red = (255. / 255, 13. / 255, 87. / 255)
    colors = [blue + (alpha,) for alpha in np.linspace(1, 0, 100)]
    colors += [red + (alpha,) for alpha in np.linspace(0, 1, 100)]
    return LinearSegmentedColormap.from_list("red_transparent_blue", colors)


def _prepare_display_image(x_row):
    """Return ``(image, gray)`` for one input image, ready for ``imshow``.

    ``image`` is the picture scaled for display; ``gray`` is a single-channel
    version used as a faint underlay beneath the SHAP heatmaps.
    """
    image = x_row
    # Drop a trailing singleton channel so imshow treats it as a 2-D image.
    if image.ndim == 3 and image.shape[2] == 1:
        image = image.reshape(image.shape[:2])
    # Heuristic: values above 1 are taken to be on a 0-255 scale. Division is
    # out-of-place so integer arrays (e.g. uint8) are promoted to float rather
    # than failing, and the caller's array is never modified.
    if image.max() > 1:
        image = image / 255.
    if image.ndim == 3 and image.shape[2] == 3:
        # Weighted RGB -> grayscale (the same weights as MATLAB's rgb2gray).
        gray = (0.2989 * image[:, :, 0]
                + 0.5870 * image[:, :, 1]
                + 0.1140 * image[:, :, 2])
    else:
        gray = image
    return image, gray


def visualize_model_decisions(shap_values, x, labels=None, figsize=(20, 30)):
    """Plot SHAP image explanations in a larger layout than ``shap.image_plot``.

    Adapted from ``shap.image_plot(...)``. Each image in ``x`` gets one row.
    The first column shows the image itself, and each further column shows
    the SHAP map for one model output, drawn over a faint grayscale copy of
    the image. All heatmaps share one symmetric colour range, and a single
    horizontal colorbar is added for the whole grid.

    Parameters
    ----------
    shap_values : numpy.ndarray or list of numpy.ndarray
        SHAP values for a single output (one array) or for several outputs
        (a list with one array per output). Each array is indexed by image
        first. Per-image values are either 2-D or have a trailing axis that
        is summed before plotting.
    x : numpy.ndarray
        Images, indexed by image first. A trailing channel axis of size 1 or
        3 is supported. Images whose maximum exceeds 1 are divided by 255 for
        display.
    labels : array-like, optional
        Per-panel titles. Shape ``(n_images, n_outputs)`` when
        ``shap_values`` is a list, or ``(n_images,)`` for a single output.
    figsize : tuple of float, optional
        Figure size passed to ``matplotlib.pyplot.subplots``. Its width also
        sets the colorbar's aspect ratio.

    Raises
    ------
    AssertionError
        If ``labels`` does not match the number of images or outputs.
    """
    multi_output = isinstance(shap_values, list)
    if not multi_output:
        shap_values = [shap_values]

    # Validate labels, then normalise them to an (n_images, n_outputs) grid so
    # the single-output vector form can also be indexed as labels[row, i].
    if labels is not None:
        labels = np.asarray(labels)
        assert labels.shape[0] == shap_values[0].shape[0], "Labels must have same row count as shap_values arrays!"
        if multi_output:
            assert labels.shape[1] == len(shap_values), "Labels must have a column for each output in shap_values!"
        else:
            assert len(labels.shape) == 1, "Labels must be a vector for single output shap_values."
            labels = labels.reshape(-1, 1)

    cmap = _red_transparent_blue()

    # squeeze=False always yields a 2-D (rows, cols) array of axes.
    fig, axes = plt.subplots(nrows=x.shape[0], ncols=len(shap_values) + 1,
                             figsize=figsize, squeeze=False)

    # One shared, symmetric colour limit for every heatmap: the 99.9th
    # percentile of |SHAP| across all images and outputs (trailing axis summed
    # when present). It does not depend on the row, so compute it once.
    if shap_values[0].ndim == 3:
        abs_vals = np.stack([np.abs(vals) for vals in shap_values], 0).flatten()
    else:
        abs_vals = np.stack([np.abs(vals.sum(-1)) for vals in shap_values], 0).flatten()
    max_val = np.nanpercentile(abs_vals, 99.9)

    for row in range(x.shape[0]):
        image, gray = _prepare_display_image(x[row])
        axes[row, 0].imshow(image)
        axes[row, 0].axis('off')

        for i, output_values in enumerate(shap_values):
            ax = axes[row, i + 1]
            if labels is not None:
                ax.set_title(labels[row, i])
            sv = output_values[row]
            if sv.ndim != 2:
                sv = sv.sum(-1)  # collapse the trailing axis into one 2-D map
            # Faint grayscale underlay. extent is (left, right, bottom, top),
            # so width comes from the column count and height from the rows.
            ax.imshow(gray, cmap=plt.get_cmap('gray'), alpha=0.15,
                      extent=(-1, sv.shape[1], sv.shape[0], -1))
            im = ax.imshow(sv, cmap=cmap, vmin=-max_val, vmax=max_val)
            ax.axis('off')

    # A single colorbar for the whole grid; every heatmap shares max_val, so
    # the last image handle represents them all.
    cb = fig.colorbar(im, ax=np.ravel(axes).tolist(), label="SHAP value",
                      orientation="horizontal", aspect=figsize[0] / 0.2)
    cb.outline.set_visible(False)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment