Skip to content

Instantly share code, notes, and snippets.

@dipanjanS
Created August 15, 2019 10:51
Show Gist options
  • Save dipanjanS/8f2fb60bba1473b0b341dbf2302bf443 to your computer and use it in GitHub Desktop.
Save dipanjanS/8f2fb60bba1473b0b341dbf2302bf443 to your computer and use it in GitHub Desktop.
from tf_explain.core.grad_cam import GradCAM
# Grad-CAM comparison for two ImageNet cat classes on the same input image.
# Relies on names defined earlier in the notebook: `imgnet_map`, `img`, `img_inp`, `model`, `plt`.
explainer = GradCAM()

# Get ImageNet IDs for the cat breeds from the label map rather than hard-coding them.
# In the original notebook this lookup displayed ('281', '285'): the IDs are stored as
# strings, so convert them to int before using them as class indices.
tabby_id = int(imgnet_map['tabby'])
egyptian_cat_id = int(imgnet_map['Egyptian_cat'])

# Visualize GradCAM outputs in Block 1 (layer 'block1_conv2').
# NOTE(review): keyword arguments are used because the positional order of
# layer_name / class_index in GradCAM.explain differs between tf_explain releases.
grid1 = explainer.explain(([img], None), model,
                          layer_name='block1_conv2', class_index=tabby_id)
grid2 = explainer.explain(([img], None), model,
                          layer_name='block1_conv2', class_index=egyptian_cat_id)

# Scale the image once for display.
# NOTE(review): assumes img_inp holds 0-255 pixel values. Confirm against the loading code.
base_image = img_inp / 255.

fig = plt.figure(figsize=(18, 8))

# Panel 1: heatmap for 'tabby' overlaid on the image.
ax1 = fig.add_subplot(1, 3, 1)
ax1.imshow(base_image)
ax1.imshow(grid1, alpha=0.6)

# Panel 2: heatmap for 'Egyptian_cat' overlaid on the image.
ax2 = fig.add_subplot(1, 3, 2)
ax2.imshow(base_image)
ax2.imshow(grid2, alpha=0.6)

# Panel 3: the plain image, for reference.
ax3 = fig.add_subplot(1, 3, 3)
ax3.imshow(base_image)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment