Skip to content

Instantly share code, notes, and snippets.

@ecolss
Last active October 14, 2023 19:10
Show Gist options
  • Save ecolss/6c061fcafee7d14ebda633b33c416660 to your computer and use it in GitHub Desktop.
Save ecolss/6c061fcafee7d14ebda633b33c416660 to your computer and use it in GitHub Desktop.
Grad-CAM: example code
import cv2
from functools import partial
import json
import numpy as np
import os
from PIL import Image
import pylab as pl
import torch as th
from torch import nn
import torchvision as thv
# Preprocessing for the ImageNet-pretrained VGG below: resize to 224x224,
# convert the PIL image to a tensor, then normalize each channel with the
# mean/std values commonly used for ImageNet-pretrained torchvision models.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]
transform = thv.transforms.Compose(
    [
        thv.transforms.Resize((224, 224)),
        thv.transforms.ToTensor(),
        thv.transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)
class Hooks:
    """Collect module activations and output gradients via module hooks.

    Captured tensors are stored in ``self.ctx`` under the keys
    ``"<name>_actvs"`` and ``"<name>_grads"``, where ``name`` is the label
    passed to :meth:`register`.
    """

    def __init__(self):
        self.ctx = {}

    def get_actvs(self, module, input, output, *, name=""):
        """Forward hook: store a detached copy of ``module``'s output."""
        self.ctx[f"{name}_actvs"] = output.detach().clone()

    def get_grads(self, module, input, output, *, name=""):
        """Full backward hook: store the gradient w.r.t. the module output.

        When called as a backward hook, ``input`` / ``output`` are the tuples
        of gradients w.r.t. the module's inputs / outputs; only the first
        output gradient is kept.
        """
        self.ctx[f"{name}_grads"] = output[0].detach().clone()

    def register(self, module, name):
        """Attach both hooks to ``module``.

        Returns:
            ``(forward_handle, backward_handle)``; call ``.remove()`` on each
            to detach the hooks again.
        """
        # Bind to *this* instance so captured tensors land in ``self.ctx``.
        fwd = module.register_forward_hook(partial(self.get_actvs, name=name))
        # This can't work if the model has module with inplace operations (e.g. ReLU)
        bwd = module.register_full_backward_hook(
            partial(self.get_grads, name=name)
        )
        return fwd, bwd
class Reshape(nn.Module):
    """Reshape each sample of a batch while keeping the leading batch dim.

    ``to_shape`` is a tuple giving the per-sample target shape; a ``-1``
    entry lets torch infer that dimension (``Reshape((-1,))`` flattens).
    """

    def __init__(self, to_shape):
        super().__init__()
        self.to_shape = to_shape

    def forward(self, x):
        leading = (x.shape[0],)
        return x.reshape(leading + self.to_shape)
class CAMVGG(nn.Module):
    """VGG-19 split at its last ReLU so Grad-CAM can read its feature maps.

    ``forward`` stores the output of ``self.features`` in ``self.ctx["actv"]``.
    After ``backward`` is called on the prediction, the gradient w.r.t. those
    features is stored in ``self.ctx["grad"]``.
    """

    def __init__(self):
        super().__init__()
        # NOTE(review): ``pretrained=True`` is deprecated in newer torchvision
        # in favour of ``weights=...``; kept to avoid assuming a minimum version.
        vgg = thv.models.vgg19(pretrained=True)
        self.features = vgg.features[:36]  # end at the last ReLU
        self.classif = nn.Sequential(
            # Max pooling after the last ReLU, the same setting as in pretrained vgg
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False),
            vgg.avgpool,
            Reshape((-1,)),
            vgg.classifier,
        )
        self.ctx = {}

    def get_grad(self, grad):
        """Tensor hook: store a detached copy of the incoming gradient."""
        self.ctx["grad"] = grad.detach().clone()

    def forward(self, x):
        """Run the model, recording feature activations (and hooking grads)."""
        self.ctx.clear()
        h = self.features(x)
        self.ctx["actv"] = h.detach().clone()
        # Since vgg has inplace operations at some modules (e.g. ReLU), the
        # module.register_* backward API can't be used, but a tensor hook can.
        # Tensor hooks can only be registered on tensors that require grad,
        # so skip it under torch.no_grad()/inference_mode or frozen weights.
        if h.requires_grad:
            h.register_hook(self.get_grad)
        return self.classif(h)
def grad_cam_importance(grad, actv):
    """Compute a Grad-CAM importance map from gradients and activations.

    Args:
        grad: gradient of the target score w.r.t. the feature maps; assumed
            ``(batch, channels, height, width)``.
        actv: the feature maps themselves, same shape as ``grad``.

    Returns:
        A NumPy array of shape ``(batch, height, width)``. Each map is the
        channel-mean of the activations, each channel weighted by its
        spatially averaged gradient. No ReLU is applied here. Averaging
        instead of summing over channels only rescales the map by
        ``1 / channels``.
    """
    # Per-channel weights: global-average-pooled gradients, kept 4-D so they
    # broadcast against ``actv``.
    alpha = grad.mean(dim=(2, 3), keepdim=True)
    imp = (alpha * actv).mean(1)
    # detach/cpu so conversion works for tensors that track grad or are on GPU.
    return imp.detach().cpu().numpy()
def _show_grad_cam(raw_im, im_tensor, imp):
    """Plot the raw image, the model input, the heat map and their overlay."""
    im = im_tensor.permute(1, 2, 0).numpy()  # CHW -> HWC for imshow
    lo, hi = im.min(), im.max()
    # Rescale the normalized input to [0, 1] so imshow displays it sensibly.
    im = (im - lo) / (hi - lo) if hi > lo else np.zeros_like(im)
    imp = np.maximum(imp, 0)  # keep only positive evidence (Grad-CAM ReLU)
    peak = imp.max()
    if peak > 0:  # an all-nonpositive map would otherwise become NaN
        imp = imp / peak
    imp = imp[..., np.newaxis]
    print(im.shape, imp.shape)
    _, ax = pl.subplots(2, 2)
    ax[0, 0].imshow(raw_im)
    ax[0, 1].imshow(im)
    ax[1, 0].imshow(imp)
    # Overlay the heat map on the input image.
    ax[1, 1].imshow(im)
    ax[1, 1].imshow(imp, alpha=0.6)
    pl.show()


def run_grad_cam(f, grad_cam_model, id2label, topk=1):
    """Run Grad-CAM on one image file and plot the result.

    Args:
        f: path to the image; ``~`` is expanded.
        grad_cam_model: model whose forward pass fills ``ctx["actv"]`` and
            whose backward pass fills ``ctx["grad"]`` (e.g. ``CAMVGG``).
        id2label: mapping from class index (as a string) to label text.
        topk: which ranked prediction to explain (1 = highest score).
    """
    # convert("RGB") so grayscale/RGBA/palette images still yield the 3
    # channels ``transform`` normalizes; ``with`` closes the file handle.
    with Image.open(os.path.expanduser(f)) as src:
        raw_ims = [src.convert("RGB")]
    ims = th.concat([transform(im).unsqueeze(0) for im in raw_ims], 0)
    pred = grad_cam_model(ims)
    idx = pred.argsort(-1)  # ascending: the top-k class index is idx[:, -topk]
    # Backprop the selected class score. Summing keeps the target a scalar for
    # any batch size; assuming the model treats batch items independently,
    # each image still receives only its own score's gradient.
    pred.take_along_dim(idx[:, -topk].view(-1, 1), 1).sum().backward()
    ctx = grad_cam_model.ctx
    imps = grad_cam_importance(ctx["grad"], ctx["actv"])
    # cv2.resize takes dsize as (width, height); ims is (N, C, H, W).
    dsize = (ims.shape[3], ims.shape[2])
    imps = np.array(
        [cv2.resize(el, dsize, interpolation=cv2.INTER_CUBIC) for el in imps]
    )
    # For plotting
    print(idx[:, -10:])
    print(f"use top{topk} pred: {id2label[str(idx[0, -topk].item())]}")
    _show_grad_cam(raw_ims[0], ims[0], imps[0])
if __name__ == "__main__":
    # Class-index -> label mapping; JSON object keys are always strings.
    # Explicit encoding avoids platform-dependent decoding of the labels.
    with open("imagenet_1k.json", encoding="utf-8") as fp:
        imagenet_1k_labels = json.load(fp)
    # Seed the torch and NumPy RNGs for reproducible runs.
    th.manual_seed(17)
    np.random.seed(17)
    # Keep the name ``vgg``: other code in this file may read it as a global.
    vgg = CAMVGG()
    vgg.eval()
    # Explain predictions ranked 1..19, one figure per rank.
    for k in range(1, 20):
        run_grad_cam("~/Pictures/flock.jpg", vgg, imagenet_1k_labels, topk=k)
@ecolss
Copy link
Author

ecolss commented Oct 14, 2023

cam_flock

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment