Last active
June 12, 2021 09:41
-
-
Save BeBeBerr/5af065430dece675f2b585f260108998 to your computer and use it in GitHub Desktop.
Grad-CAM with MobileNet v2
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import numpy as np | |
from torchvision import datasets, transforms, models | |
from PIL import Image | |
import torch.nn.functional as F | |
import matplotlib.pyplot as plt | |
import cv2 | |
class GradCAM:
    """Gradient-weighted Class Activation Mapping (Grad-CAM) for one layer.

    Hooks a single module of ``model`` to capture two things: its forward
    output (the feature map) and the gradient of a class score with respect
    to that output. It then combines them into a class-discriminative heatmap.

    Args:
        model: the network to explain.
        layer_index: position of the target module in
            ``list(model.named_modules())``. The default ``-6`` is the last
            Conv2d layer of torchvision's MobileNet v2.
    """

    def __init__(self, model, layer_index=-6):
        self.model = model
        self.layer_index = layer_index
        self.feature_map = None
        self.feature_map_grad = None
        self._hook_handles = []
        self.register_hooks()

    def _forward_hook(self, module, input, output):
        # Keep the layer's activations attached to the autograd graph; the
        # heatmap is built from them after backward().
        self.feature_map = output

    def _backward_hook(self, module, grad_input, grad_output):
        # grad_output is a tuple of tensors, one per module output; the target
        # layer has a single output, so take the first entry.
        self.feature_map_grad = grad_output[0]

    def register_hooks(self):
        """Attach the forward and backward hooks to the selected layer."""
        _, layer = list(self.model.named_modules())[self.layer_index]
        self._hook_handles.append(layer.register_forward_hook(self._forward_hook))
        # register_backward_hook is deprecated and can report wrong gradients
        # for modules built from several ops; prefer the "full" variant when
        # the installed PyTorch provides it.
        register_backward = getattr(layer, 'register_full_backward_hook', None)
        if register_backward is None:
            register_backward = layer.register_backward_hook
        self._hook_handles.append(register_backward(self._backward_hook))

    def remove_hooks(self):
        """Detach every hook installed by ``register_hooks``."""
        for handle in self._hook_handles:
            handle.remove()
        self._hook_handles.clear()

    def __call__(self, prediction, class_index, retain_graph=False):
        """Compute the Grad-CAM heatmap for one class.

        Args:
            prediction: model output from a forward pass run after the hooks
                were registered. The score is taken from row 0 of the batch.
            class_index: index of the class to explain.
            retain_graph: keep the autograd graph so further calls can reuse
                the same ``prediction``, e.g. to explain several classes.

        Returns:
            ReLU of the gradient-weighted channel sum of the layer's output.
            For an (N, C, H, W) feature map this has shape (N, H, W). The
            tensor is still attached to the graph; call ``.detach()`` before
            converting it to NumPy.

        Raises:
            RuntimeError: if the hooks did not capture a feature map and its
                gradient, e.g. ``prediction`` was not produced through the
                hooked layer.
        """
        self.model.zero_grad()
        # Clear the previous gradient so a hook that did not fire is detected
        # instead of silently reusing stale data.
        self.feature_map_grad = None
        score = prediction[0, class_index]
        score.backward(retain_graph=retain_graph)
        if self.feature_map is None or self.feature_map_grad is None:
            raise RuntimeError(
                'GradCAM hooks captured no feature map/gradient; run the '
                'forward pass after constructing GradCAM.')
        # Channel weights: gradients averaged over the two spatial dimensions.
        alpha = self.feature_map_grad.mean(dim=(-1, -2), keepdim=True)
        heatmap = (self.feature_map * alpha).sum(1)
        return F.relu(heatmap)
def main(checkpoint_path='checkpoints/baseline.pth.tar',
         image_path='image_0042.jpg'):
    """Run Grad-CAM on one image with a 102-class MobileNetV2 and save plots.

    Writes three files to the working directory:

    - ``heatmap_small.jpg``: the raw heatmap at the hooked layer's resolution.
    - ``heatmap.jpg``: the heatmap upsampled to 224x224.
    - ``blend.jpg``: the heatmap overlaid on the same 224x224 centre crop the
      network actually saw.

    Args:
        checkpoint_path: checkpoint file whose ``'state_dict'`` entry holds
            the MobileNetV2 weights.
        image_path: image to explain.
    """
    model = models.MobileNetV2(num_classes=102)
    # The script runs on the CPU. map_location lets a checkpoint saved from a
    # GPU load on a CPU-only machine.
    # NOTE(review): torch.load unpickles the file, so only load trusted
    # checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(checkpoint['state_dict'])
    model.eval()

    # The network input and the overlay share this geometry, so the heatmap
    # lines up with the pixels it describes. Previously the overlay squashed
    # the whole uncropped image to 224x224 and was misaligned.
    crop = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
    ])
    to_input = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # convert('RGB'): without it, grayscale or RGBA files would break
    # Normalize's 3-channel statistics and Image.blend's matching-mode
    # requirement.
    with Image.open(image_path) as image_file:
        image_cropped = crop(image_file.convert('RGB'))
    image = torch.unsqueeze(to_input(image_cropped), 0)

    grad_cam = GradCAM(model)
    output = model(image)
    index = output.argmax()  # batch of one, so the flat argmax is the class
    heatmap = grad_cam(output, index)[0].detach().numpy()
    plt.imsave('heatmap_small.jpg', heatmap, cmap='rainbow')

    heatmap = cv2.resize(heatmap, (224, 224), interpolation=cv2.INTER_CUBIC)
    plt.imsave('heatmap.jpg', heatmap, cmap='rainbow')

    # Reload the colour-mapped heatmap as an RGB image with the crop's size
    # so Image.blend accepts the pair.
    with Image.open('heatmap.jpg') as heatmap_file:
        heatmap_image = heatmap_file.convert('RGB').resize(image_cropped.size)
    blend = Image.blend(image_cropped, heatmap_image, 0.5)
    blend.save('blend.jpg')
if __name__ == '__main__':
    # Execute the Grad-CAM demo only when this file is run as a script;
    # importing the module just defines GradCAM and main.
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment