Skip to content

Instantly share code, notes, and snippets.

@MartinWeiss12
Last active November 18, 2024 15:11
Show Gist options
  • Save MartinWeiss12/4f13705d8d758b229eea076eb297f7a4 to your computer and use it in GitHub Desktop.
Save MartinWeiss12/4f13705d8d758b229eea076eb297f7a4 to your computer and use it in GitHub Desktop.
Grad-CAM
def _grad_cam_load_image(img_path, image_width, image_height, device,
                         display_size, mean, std):
    """Read an image; return (BGR display copy, normalised 1xCxHxW float tensor)."""
    img = cv2.imread(img_path)
    if img is None:  # cv2.imread signals failure by returning None, not raising
        raise FileNotFoundError(f'Could not read image: {img_path}')
    img = cv2.resize(img, (image_width, image_height))  # cv2 sizes are (w, h)
    img_original = cv2.resize(img, (display_size, display_size))
    # NOTE(review): cv2 loads pixels in BGR order and no RGB conversion is done
    # before the model sees them; confirm the model was trained on BGR input.
    x = (img / 255.0 - np.asarray(mean)) / np.asarray(std)
    x = np.transpose(x, (2, 0, 1))  # HWC -> CHW
    tensor = torch.from_numpy(x).unsqueeze(0).float().to(device)
    return img_original, tensor


def _grad_cam_last_conv(model):
    """Return the last nn.Conv2d in model.modules() order; raise if none exists."""
    last_conv_layer = None
    for module in model.modules():
        if isinstance(module, nn.Conv2d):
            last_conv_layer = module
    if last_conv_layer is None:
        raise ValueError('Grad-CAM requires a model with at least one nn.Conv2d layer')
    return last_conv_layer


def _grad_cam_compute(model, img, layer):
    """Return the raw (un-normalised) CAM for the top-scoring class as a 2-D array."""
    features = []
    gradients = []

    def forward_hook(module, inputs, output):
        features.append(output.detach())

    def backward_hook(module, grad_in, grad_out):
        gradients.append(grad_out[0].detach())

    handle_forward = layer.register_forward_hook(forward_hook)
    handle_backward = layer.register_full_backward_hook(backward_hook)
    try:
        logits = model(img)
        # Softmax is monotonic, so the argmax of the logits is the predicted class.
        predicted_class = int(logits[0].argmax())
        model.zero_grad()
        logits[0, predicted_class].backward()
    finally:
        # Always detach the hooks so a failed pass doesn't leave them on the model.
        handle_forward.remove()
        handle_backward.remove()
    if not features or not gradients:
        raise RuntimeError('Grad-CAM hooks did not capture activations/gradients')
    feature_maps = features[0][0]  # activations of the single input image
    # Channel weights: spatially averaged gradients of the class score.
    weights = gradients[0][0].mean(dim=(1, 2)).view(-1, 1, 1)
    cam = F.relu((weights * feature_maps).sum(dim=0))
    return cam.cpu().numpy()


def _grad_cam_postprocess(cam, image_width, image_height, threshold,
                          display_size, blur_kernel, blur_sigma):
    """Upsample, min-max normalise to [0, 1], threshold, blur and resize the CAM."""
    cam = cv2.resize(cam, (image_width, image_height))
    cam_range = cam.max() - cam.min()
    if cam_range > 0:
        cam = (cam - cam.min()) / cam_range
    else:
        # Flat map (e.g. all zeros after ReLU): avoid 0/0 NaNs, show no activation.
        cam = np.zeros_like(cam)
    cam[cam < threshold] = 0
    cam = cv2.GaussianBlur(cam, blur_kernel, blur_sigma)
    return cv2.resize(cam, (display_size, display_size))


def _grad_cam_plot(img_original, cam_smooth, image_id, lesion_type):
    """Show the original image, the heatmap overlay and a 0-1 colorbar side by side."""
    rgb = cv2.cvtColor(img_original, cv2.COLOR_BGR2RGB)
    plt.figure(figsize=(15, 5), dpi=300)
    plt.subplot(1, 3, 1)
    plt.imshow(rgb)
    plt.title(f'Original\n{lesion_type} - 00{image_id}')
    plt.axis('off')
    plt.subplot(1, 3, 2)
    plt.imshow(rgb)
    # Pin the colour scale to [0, 1] so the overlay matches the colorbar even
    # when thresholding/blurring leaves the map's maximum below 1.
    plt.imshow(cam_smooth, cmap=plt.cm.jet, alpha=0.35, vmin=0, vmax=1)
    plt.title(f'Grad-CAM\n{lesion_type} - 00{image_id}')
    plt.axis('off')
    # Third panel: a hidden axes that only hosts a standalone colorbar.
    plt.subplot(1, 3, 3)
    plt.imshow(np.zeros((1, 1)), cmap=plt.cm.jet)
    plt.gca().set_visible(False)
    norm = colors.Normalize(vmin=0, vmax=1)
    sm = cm.ScalarMappable(cmap=plt.cm.jet, norm=norm)
    sm.set_array([])
    cbar = plt.colorbar(sm, ax=plt.gca(), orientation='vertical', fraction=1)
    cbar.set_ticks([0.0, 1.0])
    cbar.set_ticklabels(['0.0', '1.0'])
    cbar.ax.yaxis.set_ticks_position('left')
    cbar.ax.yaxis.set_label_position('left')
    plt.tight_layout()
    plt.show()


def grad_cam(model, threshold, image_id, img_path, lesion_type, image_width,
             image_height, device, *, display_size=450,
             mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
             blur_kernel=(5, 5), blur_sigma=5.0):
    """Compute a Grad-CAM heatmap for one image and plot it beside the original.

    The image is fed through ``model``, and the class with the highest logit is
    taken as the prediction. Its score is back-propagated to the last
    ``nn.Conv2d`` found in ``model.modules()``. The spatially averaged gradients
    weight that layer's feature maps, and the ReLU of the weighted sum is the
    class activation map.

    Args:
        model: PyTorch classifier containing at least one ``nn.Conv2d``; it is
            switched to eval mode.
        threshold: CAM values below this (after normalisation to [0, 1]) are
            zeroed before smoothing.
        image_id: Identifier shown in the subplot titles (prefixed with "00").
        img_path: Image file path, read with ``cv2.imread``.
        lesion_type: Label shown in the subplot titles.
        image_width: Width the image is resized to before entering the model.
        image_height: Height the image is resized to before entering the model.
        device: Torch device the input tensor is moved to.
        display_size: Side length in pixels of the displayed image and heatmap.
        mean: Per-channel mean subtracted after scaling pixels to [0, 1].
        std: Per-channel std divided out after mean subtraction.
        blur_kernel: Gaussian blur kernel size applied to the heatmap.
        blur_sigma: Gaussian blur sigma applied to the heatmap.

    Raises:
        FileNotFoundError: If ``img_path`` cannot be read as an image.
        ValueError: If ``model`` contains no ``nn.Conv2d`` layer.
        RuntimeError: If the hooks did not capture activations or gradients.
    """
    img_original, img = _grad_cam_load_image(
        img_path, image_width, image_height, device, display_size, mean, std)
    model.eval()
    layer = _grad_cam_last_conv(model)
    cam = _grad_cam_compute(model, img, layer)
    cam_smooth = _grad_cam_postprocess(
        cam, image_width, image_height, threshold,
        display_size, blur_kernel, blur_sigma)
    _grad_cam_plot(img_original, cam_smooth, image_id, lesion_type)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment