@Karol-G
Last active May 30, 2020 07:48
Grad-Cam with Pythia
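The snippets below appear to form one pipeline: a driver script that runs Grad-CAM over VQA annotations, the Pythia YAML configuration it relies on, a VQA_Dataset loader (pythia_dataset), a hooked Grad-CAM / guided-backprop module (pythia_grad_cam), and the PythiaVQA wrapper (pythia_model) that combines the Pythia VQA model with Detectron and ResNet-152 feature extractors.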
import sys
sys.path.append('pythia')
sys.path.append('vqa-maskrcnn-benchmark')
import cv2
import torch
import gc
import pythia_grad_cam as pgc
from torch.utils.data import DataLoader
import time
from pythia_dataset import VQA_Dataset
from gradcam_utils import *
import numpy as np
torch.backends.cudnn.enabled = False
DEVICE = "cuda:0"
#DEVICE = "cpu"
TARGET_IMAGE_SIZE = [448, 448]
CHANNEL_MEAN = [0.485, 0.456, 0.406]
CHANNEL_STD = [0.229, 0.224, 0.225]
IMAGE_SHAPE = torch.Size(TARGET_IMAGE_SIZE)
use_set = 0  # 0: VQA v2 train, 1: VQA v2 val, 2: VQA v1 train (see the branches below)
use_pythia = True  # if False, fall back to the BAN model wrapper
if use_set == 0:
# Train Set
dataType = "train"
dataSubType = dataType + '2014'
annFile = '/visinf/projects_students/shared_vqa/vqa2/raw/annotations/mscoco_train2014_annotations.json'
quesFile = '/visinf/projects_students/shared_vqa/vqa2/raw/annotations/OpenEnded_mscoco_train2014_questions.json'
imgDir = '/visinf/projects_students/shared_vqa/mscoco/train2014/'
evalIds = "evaluated_train_ids.npy"
elif use_set == 1:
# Validation Set
dataType = "val"
dataSubType = dataType + '2014'
annFile = '/visinf/projects_students/shared_vqa/vqa2/raw/annotations/mscoco_val2014_annotations.json'
quesFile = '/visinf/projects_students/shared_vqa/vqa2/raw/annotations/OpenEnded_mscoco_val2014_questions.json'
imgDir = '/visinf/projects_students/shared_vqa/mscoco/val2014/'
evalIds = "evaluated_val_ids.npy"
elif use_set == 2:
# Train Set
dataType = "train"
dataSubType = dataType + '2014'
annFile = '/visinf/projects_students/shared_vqa/vqa1/mscoco_train2014_annotations.json'
quesFile = '/visinf/projects_students/shared_vqa/vqa1/OpenEnded_mscoco_train2014_questions.json'
imgDir = '/visinf/projects_students/shared_vqa/mscoco/train2014/'
evalIds = "evaluated_vqa1_train_ids.npy"
if use_pythia:
from pythia_model import PythiaVQA as Model
model_dir = "pythia_vqa1"
else:
from ban_model import BanVQA as Model
model_dir = "ban"
def print_progress(start, j, dataset_len):
progress = ((j + 1) / dataset_len) * 100
elapsed = time.time() - start
time_per_annotation = elapsed / (j + 1)
finished_in = time_per_annotation * (dataset_len - (j + 1))
day = finished_in // (24 * 3600)
finished_in = finished_in % (24 * 3600)
hour = finished_in // 3600
finished_in %= 3600
minutes = finished_in // 60
finished_in %= 60
seconds = finished_in
print("Iteration: {} | Progress: {}% | Finished in: {}d {}h {}m {}s | Time Per Annotation: {}s".format(j, round(
progress, 6), round(day), round(hour), round(minutes), round(seconds), round(time_per_annotation, 2)))
def predict():
splits = 4
subset = 3
checkpoint = 0
vqa_dataset = VQA_Dataset(evalIds, annFile, quesFile, imgDir, dataSubType, TARGET_IMAGE_SIZE, CHANNEL_MEAN, CHANNEL_STD, splits=splits, subset=subset, checkpoint=checkpoint)
    dataset_len = len(vqa_dataset)
print("subset: {}".format(subset))
print("checkpoint: {}".format(checkpoint))
print("vqa_dataset size: {}".format(dataset_len))
with torch.enable_grad():
layer = 'resnet152_model.7'
#layer = 'detection_model.backbone.fpn.fpn_layer4'
#layer = 'detection_model.backbone.body.layer4.2.conv3'
vqa_model = Model(DEVICE)
vqa_model.eval()
# print_layer_names(vqa_model, full=True)
# import sys; sys.exit()
vqa_model_GCAM = pgc.GradCAM(model=vqa_model, candidate_layers=[layer])
# vqa_model_GCAM = pgc.GuidedBackPropagation(model=vqa_model) #TODO: GuidedBackPropagation REVERT
#vqa_model_GCAM = vqa_model
# vqa_model_GCAM = pgc.GradCAM(model=vqa_model)
# vqa_model_GCAM.eval()
#pythia_vqa_GBP = pgc.GuidedBackPropagation(model=vqa_model)
data_loader = DataLoader(vqa_dataset, batch_size=1, shuffle=False)
start = time.time()
results = []
result_index = 0
answer_dir = model_dir + "/" + model_dir + "_answers_" + dataType + "/" + model_dir + "_pred_subset_"
for j, batch in enumerate(data_loader):
print_progress(start, checkpoint+j, dataset_len)
annId = batch['annId'].item()
question = batch['question'][0]
raw_image = batch['raw_image'].squeeze()
raw_image = raw_image.cpu().numpy()
raw_image = cv2.resize(raw_image, tuple(TARGET_IMAGE_SIZE))
#actual, indices = vqa_model.forward(batch)
actual, indices = vqa_model_GCAM.forward(batch, IMAGE_SHAPE)
# actual, indices = vqa_model_GCAM.forward(batch) #TODO: GuidedBackPropagation REVERT
top_indices = indices[0]
top_scores = actual[0]
probs = []
answers = []
for idx, score in enumerate(top_scores):
probs.append(score.item())
answers.append(
vqa_model.answer_processor.idx2word(top_indices[idx].item())
)
#self.pythia_model_GBP.backward(ids=indices[:, [0]])
#attention_map_GBP = self.resnet152_model_GBP.generate()
results.append([annId, answers[0], probs[0]])
vqa_model_GCAM.backward(ids=indices[:, [0]])
attention_map_GradCAM = vqa_model_GCAM.generate(target_layer=layer)
# vqa_model_GCAM.backward() #TODO: GuidedBackPropagation REVERT
# attention_map_GradCAM = vqa_model_GCAM.generate() #TODO: GuidedBackPropagation REVERT
attention_map_GradCAM = attention_map_GradCAM.squeeze().cpu().numpy()
# img = image[j].squeeze().detach().cpu().numpy().transpose(1, 2, 0)
save_attention_map(filename="/visinf/projects_students/shared_vqa/" + model_dir + "/attention_maps/GBP/" + dataType + "2014/" + str(annId) + ".npy", attention_map=attention_map_GradCAM)
# save_attention_map_plain(filename="attention_map.txt", attention_map=attention_map)
# save_gcam(filename="/visinf/projects_students/shared_vqa/" + model_dir + "/attention_overlay/" + dataType + "2014/" + str(j) + "_" + str(annId) + "_" + str(question) + "_" + str(answers[0]) + ".png", gcam=attention_map_GradCAM, raw_image=raw_image)
if len(results) >= 1000:
np.save("/visinf/projects_students/shared_vqa/" + answer_dir + str(
subset) + "_part_" + str(result_index) + ".npy", np.asarray(results))
results = []
result_index += 1
gc.collect()
torch.cuda.empty_cache()
np.save(
"/visinf/projects_students/shared_vqa/" + answer_dir + str(subset) + "_part_" + str(
result_index) + ".npy",
np.asarray(results))
# predictions = np.asarray(predictions)
# np.save("/visinf/projects_students/shared_vqa/pythia/attention_maps/val2014.npy", predictions)
def print_layer_names(model, full=False):
if not full:
print(list(model.named_modules())[0])
else:
print(*list(model.named_modules()), sep='\n')
if __name__ == "__main__":
predict()
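# --------------------------------------------------------------------------
# What follows appears to be the Pythia configuration YAML (presumably the
# model_data/pythia.yaml that PythiaVQA._init_processors() loads further below).
# --------------------------------------------------------------------------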
datasets: vqa2
log_foldername: vqa_vqa2_pythia_1234
model: pythia
model_attributes:
pythia:
classifier:
params:
img_hidden_dim: 5000
text_hidden_dim: 300
type: logit
image_feature_dim: 2048
image_feature_embeddings:
- modal_combine:
params:
dropout: 0
hidden_dim: 5000
type: non_linear_element_multiply
normalization: softmax
transform:
params:
out_dim: 1
type: linear
image_feature_encodings:
- params:
bias_file: detectron/fc6/fc7_b.pkl
model_data_dir: ../data/
weights_file: detectron/fc6/fc7_w.pkl
type: finetune_faster_rcnn_fpn_fc7
- params:
model_data_dir: ../data/
type: default
image_text_modal_combine:
params:
dropout: 0
hidden_dim: 5000
type: non_linear_element_multiply
losses:
- type: logit_bce
metrics:
- type: vqa_accuracy
model: pythia
model_data_dir: ../data/
text_embeddings:
- params:
conv1_out: 512
conv2_out: 2
dropout: 0
embedding_dim: 300
hidden_dim: 1024
kernel_size: 1
num_layers: 1
padding: 0
type: attention
pythia_image_only:
classifier:
params:
img_hidden_dim: 5000
text_hidden_dim: 300
type: logit
image_feature_dim: 2048
image_feature_embeddings:
- modal_combine:
params:
dropout: 0
hidden_dim: 5000
type: non_linear_element_multiply
normalization: softmax
transform:
params:
out_dim: 1
type: linear
image_feature_encodings:
- params:
bias_file: detectron/fc6/fc7_b.pkl
model_data_dir: ../data/
weights_file: detectron/fc6/fc7_w.pkl
type: finetune_faster_rcnn_fpn_fc7
- params:
model_data_dir: ../data/
type: default
image_text_modal_combine:
params:
dropout: 0
hidden_dim: 5000
type: non_linear_element_multiply
losses:
- type: logit_bce
metrics:
- type: vqa_accuracy
model_data_dir: ../data/
text_embeddings:
- params:
conv1_out: 512
conv2_out: 2
dropout: 0
embedding_dim: 300
hidden_dim: 1024
kernel_size: 1
num_layers: 1
padding: 0
type: attention
pythia_question_only:
classifier:
params:
img_hidden_dim: 5000
text_hidden_dim: 300
type: logit
image_feature_dim: 2048
image_feature_embeddings:
- modal_combine:
params:
dropout: 0
hidden_dim: 5000
type: non_linear_element_multiply
normalization: softmax
transform:
params:
out_dim: 1
type: linear
image_feature_encodings:
- params:
bias_file: detectron/fc6/fc7_b.pkl
model_data_dir: ../data/
weights_file: detectron/fc6/fc7_w.pkl
type: finetune_faster_rcnn_fpn_fc7
- params:
model_data_dir: ../data/
type: default
image_text_modal_combine:
params:
dropout: 0
hidden_dim: 5000
type: non_linear_element_multiply
losses:
- type: logit_bce
metrics:
- type: vqa_accuracy
model_data_dir: ../data/
text_embeddings:
- params:
conv1_out: 512
conv2_out: 2
dropout: 0
embedding_dim: 300
hidden_dim: 1024
kernel_size: 1
num_layers: 1
padding: 0
type: attention
optimizer_attributes:
params:
eps: 1e-08
lr: 0.01
weight_decay: 0
type: Adamax
task_attributes:
vqa:
dataset_attributes:
vqa2:
data_root_dir: ../data
fast_read: False
features_max_len: 100
image_depth_first: False
image_features:
test:
- coco/detectron_fix_100/fc6/test2015,coco/resnet152/test2015
train:
- coco/detectron_fix_100/fc6/train_val_2014,coco/resnet152/train_val_2014
- coco/detectron_fix_100/fc6/train_val_2014,coco/resnet152/train_val_2014
val:
- coco/detectron_fix_100/fc6/train_val_2014,coco/resnet152/train_val_2014
imdb_files:
test:
- imdb/vqa/imdb_test2015.npy
train:
- imdb/vqa/imdb_train2014.npy
- imdb/vqa/imdb_val2014.npy
val:
- imdb/vqa/imdb_minival2014.npy
processors:
answer_processor:
params:
num_answers: 10
preprocessor:
params: None
type: simple_word
vocab_file: vocabs/answers_vqa.txt
type: vqa_answer
bbox_processor:
params:
max_length: 50
type: bbox
context_processor:
params:
max_length: 50
model_file: .vector_cache/wiki.en.bin
type: fasttext
ocr_token_processor:
params: None
type: simple_word
text_processor:
params:
max_length: 14
vocab:
embedding_name: glove.6B.300d
type: intersected
vocab_file: vocabs/vocabulary_100k.txt
preprocessor:
type: simple_sentence
params: {}
type: vocab
return_info: True
use_ocr: False
use_ocr_info: False
dataset_size_proportional_sampling: True
dataset_type: test
datasets: vqa2
tasks: vqa
training_parameters:
batch_size: 128
clip_gradients: True
clip_norm_mode: all
data_parallel: False
device: cuda
distributed: False
evalai_inference: False
experiment_name: run
load_pretrained: False
local_rank: None
log_dir: ./logs
log_interval: 100
logger_level: info
lr_ratio: 0.1
lr_scheduler: True
lr_steps:
- 15000
- 18000
- 20000
- 21000
max_epochs: None
max_grad_l2_norm: 0.25
max_iterations: 22000
metric_minimize: False
monitored_metric: vqa_accuracy
num_workers: 0
patience: 4000
pin_memory: False
pretrained_mapping: None
resume: False
resume_file: /private/home/asg/pythia/pythia/checkpoint/apr22/pythia_vqa_train_val/vqa_vqa2_pythia_1234/pythia_train_val.pth
run_type: inference
save_dir: ./save
seed: 1234
should_early_stop: False
should_not_log: False
snapshot_interval: 1000
task_size_proportional_sampling: True
use_warmup: True
verbose_dump: False
warmup_factor: 0.2
warmup_iterations: 1000
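# --------------------------------------------------------------------------
# The next file appears to be pythia_dataset.py, providing the VQA_Dataset
# class imported by the driver script above.
# --------------------------------------------------------------------------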
import sys
sys.path.append('pythia')
sys.path.append('vqa-maskrcnn-benchmark')
import cv2
import torch
import numpy as np
import torchvision.transforms as transforms
from PIL import Image
from vqaTools.vqa import VQA
class VQA_Dataset():
def __init__(self, evalIds, annFile, quesFile, imgDir, dataSubType, target_image_size, channel_mean, channel_std, splits, subset, checkpoint):
# initialize VQA api for QA annotations
self.vqa = VQA(annFile, quesFile)
self.annIds = self.vqa.getQuesIds()
self.annIds = self._get_evaluated_ids(evalIds)
self.splits = splits
self.subset = subset
self.checkpoint = checkpoint
self.subset_size = int(len(self.annIds) / self.splits)
self.imgDir = imgDir
self.dataSubType = dataSubType
self.target_image_size = target_image_size
self.channel_mean = channel_mean
self.channel_std = channel_std
def __len__(self):
if self.subset < self.splits - 1:
return self.subset_size - self.checkpoint
else:
left_over = len(self.annIds) - (self.subset_size * self.splits)
return self.subset_size + left_over - self.checkpoint
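    # _image_transform mirrors the maskrcnn-benchmark preprocessing: convert the
    # RGB PIL image to a BGR float32 array, subtract the Detectron pixel means,
    # and rescale so the short side is 800 px with the long side capped at 1333.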
def _image_transform(self, img):
im = np.array(img).astype(np.float32)
im = im[:, :, ::-1]
im -= np.array([102.9801, 115.9465, 122.7717])
im_shape = im.shape
im_size_min = np.min(im_shape[0:2])
im_size_max = np.max(im_shape[0:2])
im_scale = float(800) / float(im_size_min)
# Prevent the biggest axis from being more than max_size
if np.round(im_scale * im_size_max) > 1333:
im_scale = float(1333) / float(im_size_max)
im = cv2.resize(
im,
None,
None,
fx=im_scale,
fy=im_scale,
interpolation=cv2.INTER_LINEAR
)
img = torch.from_numpy(im).permute(2, 0, 1)
return img, im_scale
def __getitem__(self, index):
index = self.subset * self.subset_size + index + self.checkpoint
annId = int(self.annIds[index])
#annId = index
ann = self.vqa.loadQA(annId)[0]
imgId = ann['image_id']
imgFilename = 'COCO_' + self.dataSubType + '_' + str(imgId).zfill(12) + '.jpg'
question = self.vqa.getQuestion(ann)
image_path = self.imgDir + imgFilename
img = Image.open(image_path)
raw_image = cv2.imread(image_path)
resnet_img = img.convert("RGB")
data_transforms = transforms.Compose([
transforms.Resize(self.target_image_size),
transforms.ToTensor(),
transforms.Normalize(self.channel_mean, self.channel_std),
])
resnet_img = data_transforms(resnet_img)
if len(np.shape(img)) == 2:
img = img.convert("RGB")
detectron_img, detectron_scale = self._image_transform(img)
return {"annId": annId, "question": question, "resnet_img": resnet_img, "detectron_img": detectron_img, "detectron_scale": detectron_scale, "raw_image": raw_image}
def _get_evaluated_ids(self, path):
ids = np.load(path)
ids = np.squeeze(ids[:, :1].astype(np.int32))
return ids
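# --------------------------------------------------------------------------
# The module below appears to be pythia_grad_cam.py (imported above as "pgc"),
# an adaptation of Kazuto Nakashima's grad-cam wrappers to dict-style batches.
# --------------------------------------------------------------------------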
#!/usr/bin/env python
# coding: utf-8
#
# Author: Kazuto Nakashima
# URL: http://kazuto1011.github.io
# Created: 2017-05-26
from collections import OrderedDict
from collections.abc import Sequence  # Sequence lives in collections.abc on Python 3.3+
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
from tqdm import tqdm
import sys
from maskrcnn_benchmark.structures.bounding_box import BoxList
import copy
class _BaseWrapper(object):
"""
Please modify forward() and backward() according to your task.
"""
def __init__(self, model):
super(_BaseWrapper, self).__init__()
self.device = next(model.parameters()).device
self.model = model
self.handlers = [] # a set of hook function handlers
def _encode_one_hot(self, ids):
one_hot = torch.zeros_like(self.logits).to(self.device)
one_hot.scatter_(1, ids, 1.0)
return one_hot
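    # Example: with self.logits of shape (1, num_answers) and ids = tensor([[k]]),
    # _encode_one_hot returns a (1, num_answers) tensor that is 1.0 in column k and
    # 0.0 everywhere else, which backward() then feeds as the output gradient.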
def forward(self, image):
"""
Simple classification
"""
self.model.zero_grad()
self.logits = self.model.forward(image)
#self.logits = self.model(image)["scores"]
#return self.logits
self.probs = F.softmax(self.logits, dim=1)
return self.probs.sort(dim=1, descending=True)
def backward(self, ids):
"""
Class-specific backpropagation
Either way works:
1. self.logits.backward(gradient=one_hot, retain_graph=True)
2. (self.logits * one_hot).sum().backward(retain_graph=True)
"""
#print("backward")
one_hot = self._encode_one_hot(ids)
self.logits.backward(gradient=one_hot, retain_graph=True)
def generate(self):
raise NotImplementedError
def remove_hook(self):
"""
Remove all the forward/backward hook functions
"""
for handle in self.handlers:
handle.remove()
class BackPropagation(_BaseWrapper):
def forward(self, image):
image["resnet_img"].requires_grad_(True)
image["detectron_img"].requires_grad_(True)
image["detectron_scale"].requires_grad_(True)
self.image = image#.requires_grad_()
return super(BackPropagation, self).forward(self.image)
def generate(self):
gradient = self.image.grad.clone()
self.image.grad.zero_()
return gradient
class GuidedBackPropagation(BackPropagation):
"""
"Striving for Simplicity: the All Convolutional Net"
https://arxiv.org/pdf/1412.6806.pdf
Look at Figure 1 on page 8.
"""
def __init__(self, model):
super(GuidedBackPropagation, self).__init__(model)
def backward_hook(module, grad_in, grad_out):
# Cut off negative gradients
if isinstance(module, nn.ReLU):
return (torch.clamp(grad_in[0], min=0.0),)
for module in self.model.named_modules():
self.handlers.append(module[1].register_backward_hook(backward_hook))
#pass
class Deconvnet(BackPropagation):
"""
"Striving for Simplicity: the All Convolutional Net"
https://arxiv.org/pdf/1412.6806.pdf
Look at Figure 1 on page 8.
"""
def __init__(self, model):
super(Deconvnet, self).__init__(model)
def backward_hook(module, grad_in, grad_out):
# Cut off negative gradients and ignore ReLU
if isinstance(module, nn.ReLU):
return (torch.clamp(grad_out[0], min=0.0),)
for module in self.model.named_modules():
self.handlers.append(module[1].register_backward_hook(backward_hook))
def detach_output(output, depth):
    """Recursively detach tensors nested in tuples/lists; dict and BoxList
    outputs are returned as-is so the detection pipeline keeps working."""
    if not isinstance(output, (torch.Tensor, dict, BoxList)):
        tuple_list = []
        for item in output:
            if isinstance(item, torch.Tensor):
                tuple_list.append(item.detach())
            else:
                tuple_list.append(detach_output(item, depth + 1))
        return tuple_list
    elif isinstance(output, (dict, BoxList)):
        return output
    return output.detach()
class GradCAM(_BaseWrapper):
"""
"Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization"
https://arxiv.org/pdf/1610.02391.pdf
Look at Figure 2 on page 4
"""
def __init__(self, model, candidate_layers=None):
super(GradCAM, self).__init__(model)
self.fmap_pool = OrderedDict()
self.grad_pool = OrderedDict()
self.candidate_layers = candidate_layers # list
def forward_hook(key):
def forward_hook_(module, input, output):
# print("forward_hook_:")
# Save featuremaps
print("key forward: ", key)
#output.register_hook(backward_hook(key))
output = detach_output(output, 0)
                # Cache the feature map unless the module returned a dict or a BoxList
                if not isinstance(output, (dict, BoxList)):
                    self.fmap_pool[key] = output
return forward_hook_
def backward_hook(key):
#print("key2: ", key)
def backward_hook_(module, grad_in, grad_out):
                # Save the gradients corresponding to the feature maps
print("key backward: ", key)
self.grad_pool[key] = grad_out[0].detach()
return backward_hook_
#print(model.resnet152_model._modules.get("0"))
#model.resnet152_model.layer4[0].conv2.register_backward_hook(backward_hook)
        # If no candidate layers are specified, hooks are registered on every layer.
for name, module in self.model.named_modules():
if self.candidate_layers is None or name in self.candidate_layers:
#print("name: ", name)
self.handlers.append(module.register_forward_hook(forward_hook(name)))
#print("name: ", name)
#module.retain_grad()
#module.require_grad = True
#print(name)
self.handlers.append(module.register_backward_hook(backward_hook(name)))
#print("self.handlers: ", self.handlers)
# self.handlers.append(model.resnet152_model._modules.get("0").register_backward_hook(backward_hook('resnet152_model.0')))
# my_layer = model.resnet152_model._modules.get("7")
# my_layer_name = 'detection_model.backbone.fpn.fpn_layer4'
# self.handlers.append(my_layer.register_backward_hook(backward_hook(my_layer_name)))
def _find(self, pool, target_layer):
#print(pool.keys())
#print(list(self.model.named_modules()))
if target_layer in pool.keys():
return pool[target_layer]
else:
raise ValueError("Invalid layer name: {}".format(target_layer))
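    # Grad-CAM weights: global-average-pool the gradients over the spatial map,
    # i.e. alpha_k^c = (1/Z) * sum_ij dy^c/dA_ij^k, which is exactly what
    # adaptive_avg_pool2d(grads, 1) computes per channel.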
def _compute_grad_weights(self, grads):
return F.adaptive_avg_pool2d(grads, 1)
def forward(self, sample_list, image_shape):
self.image_shape = image_shape
return super(GradCAM, self).forward(sample_list)
# def select_highest_layer(self):
# fmap_list, weight_list = [], []
# module_names = []
# for name, _ in self.model.named_modules():
# module_names.append(name)
# module_names.reverse()
#
# for i in range(self.logits.shape[0]):
# counter = 0
# for layer in module_names:
# try:
# print("Testing layer: {}".format(layer))
# fmaps = self._find(self.fmap_pool, layer)
# print("1")
# np.shape(fmaps) # Throws error without this line, I have no idea why...
# print("2")
# fmaps = fmaps[i]
# print("3")
# grads = self._find(self.grad_pool, layer)[i]
# print("4")
# import array_check
# array_check.check(fmaps)
# array_check.check(grads)
# # print("counter: {}".format(counter))
# # print("fmaps shape: {}".format(np.shape(fmaps)))
# # print("grads shape: {}".format(np.shape(grads)))
# nonzeros = np.count_nonzero(grads.detach().cpu().numpy())
# # if True: #counter < 100:
# # print("counter: {}".format(counter))
# # #print("fmaps: {}".format(fmaps))
# # print("nonzeros: {}".format(nonzeros))
# # print("fmaps shape: {}".format(np.shape(fmaps)))
# # print("grads shape: {}".format(np.shape(grads)))
# self._compute_grad_weights(grads)
# if nonzeros == 0 or not isinstance(fmaps, torch.Tensor) or not isinstance(grads, torch.Tensor):
# counter += 1
# print("Skipped layer: {}".format(layer))
# continue
# print("Dismissed the last {} module layers (Note: This number can be inflated if the model contains many nested module layers)".format(counter))
# print("Selected module layer: {}".format(layer))
# fmap_list.append(self._find(self.fmap_pool, layer)[i])
# grads = self._find(self.grad_pool, layer)[i]
# weight_list.append(self._compute_grad_weights(grads))
# break
# except ValueError:
# counter += 1
# except RuntimeError:
# counter += 1
#
# return fmap_list, weight_list
#
# def generate_helper(self, fmaps, weights):
# gcam = torch.mul(fmaps, weights).sum(dim=1, keepdim=True)
# gcam = F.relu(gcam)
#
# gcam = F.interpolate(
# gcam, self.image_shape, mode="bilinear", align_corners=False
# )
#
# B, C, H, W = gcam.shape
# gcam = gcam.view(B, -1)
# gcam -= gcam.min(dim=1, keepdim=True)[0]
# gcam /= gcam.max(dim=1, keepdim=True)[0]
# gcam = gcam.view(B, C, H, W)
#
# return gcam
#
# def generate(self, target_layer, dim=2):
# if target_layer == "auto":
# fmaps, weights = self.select_highest_layer()
# gcam = []
# for i in range(self.logits.shape[0]):
# gcam.append(self.generate_helper(fmaps[i].unsqueeze(0), weights[i].unsqueeze(0)))
# else:
# fmaps = self._find(self.fmap_pool, target_layer)
# grads = self._find(self.grad_pool, target_layer)
# weights = self._compute_grad_weights(grads)
# gcam_tensor = self.generate_helper(fmaps, weights)
# gcam = []
# for i in range(self.logits.shape[0]):
# tmp = gcam_tensor[i].unsqueeze(0)
# gcam.append(tmp)
# return gcam
def generate(self, target_layer):
fmaps = self._find(self.fmap_pool, target_layer)
grads = self._find(self.grad_pool, target_layer)
# fmaps = fmaps.unsqueeze(2).unsqueeze(3)
# grads = grads.unsqueeze(2).unsqueeze(3)
# if len(fmaps.shape) == 2:
# fmaps = fmaps.unsqueeze(2)
# grads = grads.unsqueeze(2)
# if len(fmaps.shape) == 3:
# fmaps = fmaps.unsqueeze(3)
# grads = grads.unsqueeze(3)
# print("fmaps.shape: ", fmaps.shape)
# print("grads.shape: ", grads.shape)
weights = self._compute_grad_weights(grads)
# print("weights.shape: ", weights.shape)
# print("fmaps.shape: ", fmaps.shape)
# print("grads.shape: ", grads.shape)
# print("weights.shape: ", weights.shape)
gcam = torch.mul(fmaps, weights).sum(dim=1, keepdim=True)
# print("gcam.shape: ", gcam.shape)
gcam = F.relu(gcam)
# print("gcam.shape: ", gcam.shape)
gcam = F.interpolate(
gcam, self.image_shape, mode="bilinear", align_corners=False
)
# print("gcam.shape: ", gcam.shape)
B, C, H, W = gcam.shape
gcam = gcam.view(B, -1)
gcam -= gcam.min(dim=1, keepdim=True)[0]
gcam /= gcam.max(dim=1, keepdim=True)[0]
gcam = gcam.view(B, C, H, W)
return gcam
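# Hedged usage sketch (not part of the original gist): how this GradCAM wrapper
# is driven end to end, shown on a plain torchvision classifier instead of the
# full PythiaVQA stack. The layer name "layer4", the 224x224 input size and the
# ResNet-18 backbone are assumptions for illustration only.
def _gradcam_usage_sketch():
    import torchvision.models as models
    model = models.resnet18(pretrained=True).eval()
    gcam = GradCAM(model=model, candidate_layers=["layer4"])
    image = torch.randn(1, 3, 224, 224)                        # stand-in for a preprocessed image
    probs, ids = gcam.forward(image, torch.Size([224, 224]))   # sorted probabilities / class ids
    gcam.backward(ids=ids[:, [0]])                             # backprop w.r.t. the top-1 class
    heatmap = gcam.generate(target_layer="layer4")             # (1, 1, 224, 224) map in [0, 1]
    gcam.remove_hook()
    return heatmap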
def occlusion_sensitivity(
model, batch, ids, mean=None, patch=35, stride=1, n_batches=128
):
"""
"Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization"
https://arxiv.org/pdf/1610.02391.pdf
Look at Figure A5 on page 17
Originally proposed in:
"Visualizing and Understanding Convolutional Networks"
https://arxiv.org/abs/1311.2901
"""
torch.set_grad_enabled(False)
model.eval()
mean = mean if mean else 0
patch_H, patch_W = patch if isinstance(patch, Sequence) else (patch, patch)
pad_H, pad_W = patch_H // 2, patch_W // 2
# Padded image
batch["raw_image"] = F.pad(batch["raw_image"], (pad_W, pad_W, pad_H, pad_H), value=mean)
B, _, H, W = batch["raw_image"].shape
new_H = (H - patch_H) // stride + 1
new_W = (W - patch_W) // stride + 1
# Prepare sampling grids
anchors = []
grid_h = 0
while grid_h <= H - patch_H:
grid_w = 0
        while grid_w <= W - patch_W:
            # Record the anchor before stepping, so column 0 is included and the
            # last anchor stays inside the padded image.
            anchors.append((grid_h, grid_w))
            grid_w += stride
        grid_h += stride
# Baseline score without occlusion
baseline = model(batch).detach().gather(1, ids)
# Compute per-pixel logits
scoremaps = []
for i in tqdm(range(0, len(anchors), n_batches), leave=False):
#print("Test 1")
# batches = []
# batch_ids = []
for grid_h, grid_w in tqdm(anchors[i : i + n_batches]):
#print("Test 2")
batch_ = _batch_clone(batch) #batch.clone()
batch_["raw_image"][..., grid_h : grid_h + patch_H, grid_w : grid_w + patch_W] = mean
score = model(batch_).detach().gather(1, ids)
scoremaps.append(score)
# batches.append(batch_)
# batch_ids.append(ids)
# batches = _batch_cat(batches) #torch.cat(batches, dim=0)
# batch_ids = torch.cat(batch_ids, dim=0)
# scores = model(batches).detach().gather(1, batch_ids)
# scoremaps += list(torch.split(scores, B))
diffmaps = torch.cat(scoremaps, dim=1) - baseline
diffmaps = diffmaps.view(B, new_H, new_W)
return diffmaps
def _batch_clone(batch):
clone = {}
for key in batch.keys():
if isinstance(batch[key], torch.Tensor):
clone[key] = batch[key].clone()
else:
clone[key] = copy.deepcopy(batch[key])
return clone
def _batch_cat(batch_list):
cat_batch = {}
for key in batch_list[0].keys():
cat_batch[key] = [batch[key] for batch in batch_list]
if isinstance(batch_list[0][key], torch.Tensor):
cat_batch[key] = torch.cat(cat_batch[key], dim=0)
return cat_batch
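# --------------------------------------------------------------------------
# The final file appears to be pythia_model.py, defining the PythiaVQA wrapper
# imported by the driver script as "from pythia_model import PythiaVQA".
# --------------------------------------------------------------------------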
import sys
sys.path.append('pythia')
sys.path.append('vqa-maskrcnn-benchmark')
import yaml
import torch
import torch.nn.functional as F
import torchvision.models as models
from maskrcnn_benchmark.config import cfg
from maskrcnn_benchmark.layers import nms
from maskrcnn_benchmark.modeling.detector import build_detection_model
from maskrcnn_benchmark.structures.image_list import to_image_list
from maskrcnn_benchmark.utils.model_serialization import load_state_dict
from pythia.utils.configuration import ConfigNode
from pythia.tasks.processors import VocabProcessor, VQAAnswerProcessor
from pythia.models.pythia import Pythia
from pythia.common.registry import registry
from pythia.common.sample import Sample, SampleList
torch.backends.cudnn.enabled = False
class PythiaVQA(torch.nn.Module):
def __init__(self, device):
super(PythiaVQA, self).__init__()
self.device = device
self._init_processors()
self.pythia_model = self._build_pythia_model()
self.detection_model = self._build_detection_model()
self.resnet152_model = self._build_resnet_model()
def _init_processors(self):
with open("model_data/pythia.yaml") as f:
            config = yaml.load(f, Loader=yaml.SafeLoader)
config = ConfigNode(config)
# Remove warning
config.training_parameters.evalai_inference = True
registry.register("config", config)
self.config = config
vqa_config = config.task_attributes.vqa.dataset_attributes.vqa2
text_processor_config = vqa_config.processors.text_processor
answer_processor_config = vqa_config.processors.answer_processor
text_processor_config.params.vocab.vocab_file = "model_data/vocabulary_100k.txt"
answer_processor_config.params.vocab_file = "model_data/answers_vqa.txt"
        # Add the preprocessor, as it will be needed when questions come from the user
self.text_processor = VocabProcessor(text_processor_config.params)
self.answer_processor = VQAAnswerProcessor(answer_processor_config.params)
registry.register("vqa2_text_processor", self.text_processor)
registry.register("vqa2_answer_processor", self.answer_processor)
registry.register("vqa2_num_final_outputs",
self.answer_processor.get_vocab_size())
def _build_pythia_model(self):
state_dict = torch.load('model_data/pythia.pth')
model_config = self.config.model_attributes.pythia
model_config.model_data_dir = "/visinf/home/vilab22/Documents/RemoteProjects/dlcv/pythia/" #/content/
model = Pythia(model_config)
model.build()
model.init_losses_and_metrics()
if list(state_dict.keys())[0].startswith('module') and \
not hasattr(model, 'module'):
state_dict = self._multi_gpu_state_to_single(state_dict)
model.load_state_dict(state_dict)
model.to(self.device)
model.eval()
return model
def _build_resnet_model(self):
resnet152 = models.resnet152(pretrained=True)
resnet152.eval()
modules = list(resnet152.children())[:-2]
resnet152_model = torch.nn.Sequential(*modules)
resnet152_model.to(self.device)
resnet152_model.eval()
return resnet152_model
def _multi_gpu_state_to_single(self, state_dict):
new_sd = {}
for k, v in state_dict.items():
if not k.startswith('module.'):
                raise TypeError("Not a multi-GPU state dict")
k1 = k[7:]
new_sd[k1] = v
return new_sd
def _build_detection_model(self):
cfg.merge_from_file('model_data/detectron_model.yaml')
cfg.freeze()
model = build_detection_model(cfg)
checkpoint = torch.load('model_data/detectron_model.pth',
map_location=torch.device("cpu"))
load_state_dict(model, checkpoint.pop("model"))
model.to(self.device)
model.eval()
return model
def _process_feature_extraction(self, output,
im_scales,
feat_name='fc6',
conf_thresh=0.2):
batch_size = len(output[0]["proposals"])
n_boxes_per_image = [len(_) for _ in output[0]["proposals"]]
score_list = output[0]["scores"].split(n_boxes_per_image)
score_list = [torch.nn.functional.softmax(x, -1) for x in score_list]
feats = output[0][feat_name].split(n_boxes_per_image)
cur_device = score_list[0].device
feat_list = []
for i in range(batch_size):
dets = output[0]["proposals"][i].bbox / im_scales[i]
scores = score_list[i]
max_conf = torch.zeros((scores.shape[0])).to(cur_device)
for cls_ind in range(1, scores.shape[1]):
cls_scores = scores[:, cls_ind]
keep = nms(dets, cls_scores, 0.5)
max_conf[keep] = torch.where(cls_scores[keep] > max_conf[keep],
cls_scores[keep],
max_conf[keep])
keep_boxes = torch.argsort(max_conf, descending=True)[:100]
feat_list.append(feats[i][keep_boxes])
return feat_list
def masked_unk_softmax(self, x, dim, mask_idx):
x1 = F.softmax(x, dim=dim)
x1[:, mask_idx] = 0
x1_sum = torch.sum(x1, dim=1, keepdim=True)
y = x1 / x1_sum
return y
def get_resnet_features(self, img):
if img.shape[0] == 1:
img = img.expand(3, -1, -1)
img = img.unsqueeze(0).to(self.device)
features = self.resnet152_model.forward(img).permute(0, 2, 3, 1)
features = features.view(196, 2048)
return features
def get_detectron_features(self, im, im_scale):
current_img_list = to_image_list(im, size_divisible=32)
current_img_list = current_img_list.to(self.device)
with torch.no_grad():
output = self.detection_model.forward(current_img_list)
feat_list = self._process_feature_extraction(output, im_scale,
'fc6', 0.2)
return feat_list[0]
def forward(self, batch):
question = batch['question'][0]
detectron_img = batch['detectron_img'].squeeze()
detectron_scale = [batch['detectron_scale'].item()]
resnet_img = batch['resnet_img'].squeeze()
detectron_features = self.get_detectron_features(detectron_img, detectron_scale)
# with torch.no_grad():
resnet_features = self.get_resnet_features(resnet_img)
sample = Sample()
processed_text = self.text_processor({"text": question})
sample.text = processed_text["text"]
sample.text_len = len(processed_text["tokens"])
sample.image_feature_0 = detectron_features
sample.image_info_0 = Sample({
"max_features": torch.tensor(100, dtype=torch.long)
})
sample.image_feature_1 = resnet_features
sample_list = SampleList([sample])
sample_list = sample_list.to(self.device)
scores = self.pythia_model.forward(sample_list)["scores"]
return scores
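# PythiaVQA.forward expects the dict batches produced by VQA_Dataset above
# ("question", "detectron_img", "detectron_scale", "resnet_img") and returns the
# raw answer scores; the driver script maps the top-scoring indices to answer
# strings via answer_processor.idx2word.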