Last active
October 14, 2023 19:10
-
-
Save ecolss/6c061fcafee7d14ebda633b33c416660 to your computer and use it in GitHub Desktop.
Grad CAM - example code
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import cv2 | |
from functools import partial | |
import json | |
import numpy as np | |
import os | |
from PIL import Image | |
import pylab as pl | |
import torch as th | |
from torch import nn | |
import torchvision as thv | |
# Per-channel statistics used to normalize the input tensor (these are the
# values commonly used for ImageNet-pretrained torchvision models).
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Preprocessing pipeline applied to every PIL image before it is fed to the
# network: resize to 224x224, convert to a tensor, normalize each channel.
transform = thv.transforms.Compose(
    [
        thv.transforms.Resize((224, 224)),
        thv.transforms.ToTensor(),
        thv.transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ]
)
class Hooks:
    """Capture activations and output gradients of modules via module hooks.

    Captured tensors are stored in ``self.ctx`` under the keys
    ``"<name>_actvs"`` (forward output) and ``"<name>_grads"`` (gradient
    with respect to the module output).
    """

    def __init__(self):
        # Maps "<name>_actvs" / "<name>_grads" to detached tensor copies.
        self.ctx = {}

    def get_actvs(self, module, input, output, *, name=""):
        """Forward hook: store a detached copy of the module output."""
        self.ctx[f"{name}_actvs"] = output.detach().clone()

    def get_grads(self, module, input, output, *, name=""):
        """Backward hook: store a detached copy of the first output gradient.

        When used as a full backward hook, ``input`` and ``output`` receive
        the tuples of gradients w.r.t. the module inputs and outputs.
        """
        self.ctx[f"{name}_grads"] = output[0].detach().clone()

    def register(self, module, name):
        """Attach both hooks to ``module``, storing results under ``name``.

        Returns:
            A ``(forward_handle, backward_handle)`` tuple; call ``.remove()``
            on a handle to detach the corresponding hook.
        """
        # Bind to this instance (the original referenced an undefined global).
        forward_handle = module.register_forward_hook(
            partial(self.get_actvs, name=name)
        )
        # NOTE: full backward hooks do not work when the model contains
        # modules with in-place operations (e.g. ReLU(inplace=True)).
        backward_handle = module.register_full_backward_hook(
            partial(self.get_grads, name=name)
        )
        return forward_handle, backward_handle
class Reshape(nn.Module):
    """Module that reshapes its input while keeping the leading dimension.

    ``to_shape`` is a tuple describing the shape of everything after the
    first (batch) dimension, e.g. ``(-1,)`` to flatten each sample.
    """

    def __init__(self, to_shape):
        super().__init__()
        self.to_shape = to_shape

    def forward(self, x):
        """Return ``x`` reshaped to ``(x.shape[0],) + to_shape``."""
        batch_size = x.shape[0]
        target_shape = (batch_size,) + self.to_shape
        return x.reshape(target_shape)
class CAMVGG(nn.Module):
    """Pretrained VGG19 split so Grad-CAM data can be captured in between.

    ``features`` runs the convolutional stack up to its last ReLU and
    ``classif`` runs the remaining layers. After a forward pass ``ctx["actv"]``
    holds the feature map produced by ``features``; after a subsequent
    backward pass ``ctx["grad"]`` holds the gradient w.r.t. that feature map.
    """

    def __init__(self):
        super().__init__()
        # NOTE(review): `pretrained=True` is deprecated in recent torchvision
        # releases in favour of `weights=...`; kept for compatibility with
        # older versions.
        vgg = thv.models.vgg19(pretrained=True)
        self.features = vgg.features[:36]  # end at the last ReLU
        self.classif = nn.Sequential(
            # Max pooling after the last ReLU, the same setting as in pretrained vgg
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False),
            vgg.avgpool,
            Reshape((-1,)),
            vgg.classifier,
        )
        # Holds "actv" (and, after backward, "grad") for the latest forward pass.
        self.ctx = {}

    def get_grad(self, grad):
        """Tensor hook: store a detached copy of the feature-map gradient."""
        self.ctx["grad"] = grad.detach().clone()

    def forward(self, x):
        """Return class scores for ``x`` and record the feature map in ``ctx``."""
        self.ctx.clear()
        h = self.features(x)
        self.ctx["actv"] = h.detach().clone()
        # VGG has in-place operations in some modules (e.g. ReLU), so the
        # module.register_* hook API cannot be used here; a hook registered
        # directly on the tensor works.
        # Registering a hook on a tensor that does not require grad raises a
        # RuntimeError (e.g. under th.no_grad()), so only register when a
        # backward pass is actually possible.
        if h.requires_grad:
            h.register_hook(self.get_grad)
        return self.classif(h)
def grad_cam_importance(grad, actv):
    """Compute per-sample importance maps from gradients and activations.

    Args:
        grad: Gradient tensor of shape (bsz, c, h, w).
        actv: Activation tensor of the same shape as ``grad``.

    Returns:
        A NumPy array of shape (bsz, h, w): the channel-wise mean of the
        activations weighted by the spatially averaged gradients. The map is
        not rectified here; negative values are left for the caller to clamp.
    """
    # One weight per (sample, channel): the gradient averaged over space.
    alpha = grad.mean(dim=(2, 3), keepdim=True)
    imp = (alpha * actv).mean(1)
    # detach()/cpu() make the conversion work for tensors that still require
    # grad or live on another device; for plain CPU tensors they are no-ops.
    return imp.detach().cpu().numpy()
def run_grad_cam(f, grad_cam_model, id2label, topk=1):
    """Run Grad-CAM on one image file and plot the resulting heat map.

    Args:
        f: Path to the image file; ``~`` is expanded.
        grad_cam_model: Callable model returning class scores. After a forward
            and backward pass its ``ctx`` dict must hold ``"actv"`` and
            ``"grad"``.
        id2label: Mapping from the class index as a string to a label.
        topk: Rank of the prediction to explain (1 = highest score).

    Shows a 2x2 figure: raw image, preprocessed image, importance map, and
    the importance map overlaid on the preprocessed image.
    """
    # Close the file handle promptly and force 3 channels so the 3-channel
    # normalization in `transform` also works for grayscale/RGBA files.
    with Image.open(os.path.expanduser(f)) as opened:
        raw_ims = [opened.convert("RGB")]
    ims = th.concat([transform(im).unsqueeze(0) for im in raw_ims], 0)

    pred = grad_cam_model(ims)
    idx = pred.argsort(-1)
    # Back-propagate from the score of the topk-th ranked class; sum() keeps
    # the backward target a scalar regardless of batch size.
    pred.take_along_dim(idx[:, -topk].view(-1, 1), 1).sum().backward()

    # Read the captured tensors from the model that was passed in, not from a
    # module-level global.
    imps = grad_cam_importance(
        grad_cam_model.ctx["grad"], grad_cam_model.ctx["actv"]
    )
    # cv2.resize expects dsize as (width, height); tensors are (N, C, H, W).
    height, width = ims.shape[2], ims.shape[3]
    imps = np.array([
        cv2.resize(el, (width, height), interpolation=cv2.INTER_CUBIC)
        for el in imps
    ])

    # For plotting
    print(idx[:, -10:])
    print(f"use top{topk} pred: {id2label[str(idx[0,-topk].item())]}")
    raw_im = raw_ims[0]
    # Channels-last layout, rescaled to [0, 1] for display.
    im = ims[0].permute(1, 2, 0).numpy()
    im = (im - im.min()) / (im.max() - im.min())

    imp = imps[0]
    imp = np.maximum(imp, 0)
    # Avoid dividing by zero (NaNs) when nothing is positive.
    peak = imp.max()
    if peak > 0:
        imp = imp / peak
    imp = imp.reshape(imp.shape + (1,))
    print(im.shape, imp.shape)

    fig, ax = pl.subplots(2, 2)
    ax[0, 0].imshow(raw_im)
    ax[0, 1].imshow(im)
    ax[1, 0].imshow(imp)
    # Overlay images
    ax[1, 1].imshow(im)
    ax[1, 1].imshow(imp, alpha=0.6)
    pl.show()
if __name__ == "__main__":
    # Class-index -> label mapping; JSON is UTF-8, so do not rely on the
    # platform-default encoding when reading it.
    with open("imagenet_1k.json", encoding="utf-8") as fp:
        imagenet_1k_labels = json.load(fp)

    th.manual_seed(17)
    np.random.seed(17)

    vgg = CAMVGG()
    vgg.eval()

    # Visualize the 1st through 19th ranked predictions in turn.
    for k in range(1, 20):
        run_grad_cam("~/Pictures/flock.jpg", vgg, imagenet_1k_labels, topk=k)
Author
ecolss
commented
Oct 14, 2023
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment