Skip to content

Instantly share code, notes, and snippets.

@SkalskiP
Last active February 17, 2020 22:04
Show Gist options
  • Save SkalskiP/8dd92215894ff57e4f50fde0cf22c8ed to your computer and use it in GitHub Desktop.
Save SkalskiP/8dd92215894ff57e4f50fde0cf22c8ed to your computer and use it in GitHub Desktop.
def plot_lime_top_explanations(
    model: Model,
    image: np.ndarray,
    class_names_mapping: Dict[int, str],
    top_preds_count: int = 3,
    fig_name: Optional[str] = None,
    *,
    num_samples: int = 1000,
    num_features: int = 5,
) -> None:
    """Plot LIME explanations for the model's highest-scoring classes on one image.

    Runs ``explainer.explain_instance`` on ``image``, then draws one subplot per
    top-scoring class (three subplots per row). Each subplot overlays the
    boundaries of the superpixels LIME selected for that class. The
    dark-background style is applied only to this figure, not globally.

    Args:
        model: Object exposing ``predict``. It is called on a one-image batch
            (``np.expand_dims(image, axis=0)``) and is also passed to LIME as
            ``classifier_fn``. Assumed to return a ``(batch, num_classes)``
            score array — TODO confirm.
        image: Single image to explain, without a batch axis. The image returned
            by LIME is divided by 255 before display, so 0-255 pixel values are
            assumed — TODO confirm.
        class_names_mapping: Maps a class index to the name shown in that
            subplot's title.
        top_preds_count: Number of highest-scoring classes to plot.
        fig_name: If given, the figure is also saved to this path.
        num_samples: Number of perturbed samples LIME generates.
        num_features: Number of superpixels highlighted in each explanation.

    Raises:
        ValueError: If ``top_preds_count`` is less than 1. This is checked
            before running LIME, because an empty grid cannot be plotted.
    """
    if top_preds_count < 1:
        raise ValueError(f"top_preds_count must be >= 1, got {top_preds_count}")

    image_columns = 3
    image_rows = math.ceil(top_preds_count / image_columns)

    # NOTE(review): ``explainer`` is a module-level object defined outside this
    # function (presumably a lime_image.LimeImageExplainer) — confirm.
    explanation = explainer.explain_instance(
        image,
        classifier_fn=model.predict,
        top_labels=100,
        hide_color=0,
        num_samples=num_samples,
    )

    # Scores for the single image, then class indexes sorted highest first.
    scores = model.predict(np.expand_dims(image, axis=0))[0]
    top_preds_indexes = np.argsort(scores)[::-1][:top_preds_count]
    top_preds_values = scores[top_preds_indexes]
    top_preds_names = [class_names_mapping[idx] for idx in top_preds_indexes]

    # Scope the style to this figure instead of changing it process-wide.
    with plt.style.context('dark_background'):
        fig, axes = plt.subplots(
            image_rows,
            image_columns,
            figsize=(image_columns * 5, image_rows * 5),
        )
        # Hide every axis up front so unused trailing grid cells stay blank.
        for ax in axes.flat:
            ax.set_axis_off()

        for i, (index, value, name, ax) in enumerate(
            zip(top_preds_indexes, top_preds_values, top_preds_names, axes.flat)
        ):
            # Request the mask for exactly the class named in the title.
            # LIME's own top_labels ordering could differ from the ranking above.
            temp, mask = explanation.get_image_and_mask(
                index,
                positive_only=False,
                num_features=num_features,
                hide_rest=False,
            )
            ax.imshow(mark_boundaries(temp / 255, mask))
            ax.set_title(f"{i + 1}. class: {name} pred: {value:.3f}", pad=20)

        if fig_name:
            fig.savefig(fig_name)
        plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment