Skip to content

Instantly share code, notes, and snippets.

@SkalskiP
Created February 17, 2020 22:02
Show Gist options
  • Save SkalskiP/bb0ed2c6ddbf4a5772111cfea7e79106 to your computer and use it in GitHub Desktop.
Save SkalskiP/bb0ed2c6ddbf4a5772111cfea7e79106 to your computer and use it in GitHub Desktop.
def plot_shap_top_explanations(
    model: Model,
    image: np.ndarray,
    class_names_mapping: Dict[int, str],
    top_preds_count: int = 3,
    fig_name: Optional[str] = None
) -> None:
    """Plot SHAP superpixel attributions for the model's top predicted classes.

    The image is split into superpixels with SLIC. A ``shap.KernelExplainer``
    then estimates each superpixel's contribution to the model's outputs.
    One subplot per top-ranked class overlays those contributions on the
    image. Subplots are laid out in three columns; unused axes are hidden.

    Args:
        model: Classifier exposing ``predict``. It is called on
            ``np.expand_dims(image, axis=0)`` and on masked, preprocessed
            copies of the image.
        image: Image to explain. It is divided by 255 for display, so it is
            presumably an RGB array with values in [0, 255] — confirm.
        class_names_mapping: Maps a class index to a human-readable name.
        top_preds_count: Number of highest-scoring classes to plot.
        fig_name: If given, the figure is also saved to this path before
            being shown.

    Raises:
        ValueError: If ``top_preds_count`` is not positive. This is checked up
            front so the caller does not pay for the SHAP run before failing.
    """
    if top_preds_count <= 0:
        raise ValueError(
            "top_preds_count must be a positive integer, got {}".format(top_preds_count)
        )

    image_columns = 3
    image_rows = math.ceil(top_preds_count / image_columns)

    # The explainer's feature count must match the number of superpixels, so
    # both are driven by this single constant.
    # NOTE(review): slic may return a different number of segments than
    # requested; mask_image / fill_segmentation presumably index by segment
    # label — confirm they tolerate a mismatch.
    n_segments = 100
    segments_slic = slic(image, n_segments=n_segments, compactness=30, sigma=3)

    def _predict_masked(z):
        # z is a batch of on/off vectors, one entry per superpixel. Segments
        # switched off are filled with 255 before the model sees the image.
        return model.predict(preprocess_input(mask_image(z, segments_slic, image, 255)))

    # Background = every superpixel masked; explained sample = every
    # superpixel visible.
    explainer = shap.KernelExplainer(_predict_masked, np.zeros((1, n_segments)))
    shap_values = explainer.shap_values(np.ones((1, n_segments)), nsamples=1000)

    preds = model.predict(np.expand_dims(image, axis=0))
    scores = preds[0]
    # argsort is ascending; reverse it to rank classes from most to least likely.
    top_preds_indexes = np.argsort(scores)[::-1][:top_preds_count]
    top_preds_values = scores[top_preds_indexes]
    top_preds_names = [class_names_mapping[int(index)] for index in top_preds_indexes]

    # Use one symmetric colour range across all subplots so attributions are
    # comparable between classes. The last column is excluded from the range,
    # as in the original code (this mirrors the SHAP image example).
    max_val = np.max([np.max(np.abs(values[:, :-1])) for values in shap_values])
    color_map = get_colormap()

    # Scope the dark style to this figure instead of changing global
    # matplotlib state for every later plot.
    with plt.style.context('dark_background'):
        fig, axes = plt.subplots(
            image_rows, image_columns, figsize=(image_columns * 5, image_rows * 5)
        )
        for ax in axes.flat:
            ax.set_axis_off()

        for rank, (index, value, name, ax) in enumerate(
            zip(top_preds_indexes, top_preds_values, top_preds_names, axes.flat), start=1
        ):
            m = fill_segmentation(shap_values[index][0], segments_slic)
            subplot_title = "{}. class: {} pred: {:.3f}".format(rank, name, value)
            ax.imshow(image / 255)
            ax.imshow(m, cmap=color_map, vmin=-max_val, vmax=max_val)
            ax.set_title(subplot_title, pad=20)

        if fig_name:
            fig.savefig(fig_name)
        plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment