Skip to content

Instantly share code, notes, and snippets.

@SubhadityaMukherjee
Created August 19, 2022 13:58
Show Gist options
  • Save SubhadityaMukherjee/ad1fd6850996c6d0b455c95e3d0c1927 to your computer and use it in GitHub Desktop.
Save SubhadityaMukherjee/ad1fd6850996c6d0b455c95e3d0c1927 to your computer and use it in GitHub Desktop.
Plot Grad-CAM heat maps for each layer of the model body
# Grad-CAM for every child of the model body (`learn.model[0]`): one subplot
# per layer, each showing the decoded input image with that layer's class
# activation map overlaid.
#
# Relies on names defined elsewhere in the notebook: `dls`, `learn`, `math`,
# `plt`, the `Hook` / `HookBwd` context managers (which expose the captured
# tensor as `.stored`), and the fastai helpers `PILImage`, `TensorImage`,
# `first`.
img = PILImage.create(
    "/media/hdd/Datasets/Fish_Dataset/Fish_Dataset/Shrimp/Shrimp/00012.png"
)
(x,) = first(dls.test_dl([img]))
x_dec = TensorImage(dls.train.decode((x,))[0][0])
# Size of the decoded image, so the heat map is stretched over exactly the
# area the image occupies instead of an assumed 224x224.
img_h, img_w = x_dec.shape[-2:]

cls = 1  # index of the output logit whose gradient is explained
# Switch to eval mode once, and feed the input on whatever device the model
# lives on (previously hard-coded to `.cuda()`, which fails on CPU models).
learn.model.eval()
x_in = x.to(next(learn.model.parameters()).device)

image_count = len(learn.model[0])  # number of body layers == number of subplots
col = 4
row = math.ceil(image_count / col)
plt.figure(figsize=(col * 4, row * 4))
for layer in range(image_count):
    module = learn.model[0][layer]
    try:
        with HookBwd(module) as hookg:
            with Hook(module) as hook:
                output = learn.model(x_in)
                act = hook.stored
            output[0, cls].backward()
            grad = hookg.stored
        # Grad-CAM weights: gradients averaged over dims 1 and 2 of the first
        # batch item (assumes the layer output is (batch, C, H, W) — confirm
        # for non-conv layers).
        w = grad[0].mean(dim=[1, 2], keepdim=True)
        cam_map = (w * act[0]).sum(0)
    except Exception as err:
        # Best effort: some layers may not yield a usable map (hook not fired,
        # unexpected output shape, autograd error). Report it and show the
        # plain image rather than re-plotting the previous layer's map or
        # crashing later on an undefined `cam_map`.
        print(f"Layer {layer}: no Grad-CAM ({type(err).__name__}: {err})")
        cam_map = None
    plt.subplot(row, col, layer + 1)
    x_dec.show(ctx=plt)
    if cam_map is not None:
        plt.imshow(
            cam_map.detach().cpu(),
            alpha=0.8,
            extent=(0, img_w, img_h, 0),
            interpolation="bilinear",
            cmap="magma",
        )
    plt.title(f"Layer : {layer}")
    plt.axis("off")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment