Grad-CAM with Pythia: a driver script, dataset wrapper, Grad-CAM utilities, Pythia model wrapper, and the Pythia config used to generate Grad-CAM attention maps for Pythia VQA predictions on the VQA (MS COCO) train and validation sets.
# Driver script: runs Pythia VQA over the pre-evaluated VQA/COCO questions and saves
# one Grad-CAM attention map per annotation, plus the predicted answers and scores.
import sys
sys.path.append('pythia')
sys.path.append('vqa-maskrcnn-benchmark')

import cv2
import torch
import gc
import pythia_grad_cam as pgc
from torch.utils.data import DataLoader
import time
from pythia_dataset import VQA_Dataset
from gradcam_utils import *
import numpy as np

torch.backends.cudnn.enabled = False

DEVICE = "cuda:0"
#DEVICE = "cpu"
TARGET_IMAGE_SIZE = [448, 448]
CHANNEL_MEAN = [0.485, 0.456, 0.406]
CHANNEL_STD = [0.229, 0.224, 0.225]
IMAGE_SHAPE = torch.Size(TARGET_IMAGE_SIZE)

use_set = 0
use_pythia = True

if use_set == 0:
    # VQA v2 Train Set
    dataType = "train"
    dataSubType = dataType + '2014'
    annFile = '/visinf/projects_students/shared_vqa/vqa2/raw/annotations/mscoco_train2014_annotations.json'
    quesFile = '/visinf/projects_students/shared_vqa/vqa2/raw/annotations/OpenEnded_mscoco_train2014_questions.json'
    imgDir = '/visinf/projects_students/shared_vqa/mscoco/train2014/'
    evalIds = "evaluated_train_ids.npy"
elif use_set == 1:
    # VQA v2 Validation Set
    dataType = "val"
    dataSubType = dataType + '2014'
    annFile = '/visinf/projects_students/shared_vqa/vqa2/raw/annotations/mscoco_val2014_annotations.json'
    quesFile = '/visinf/projects_students/shared_vqa/vqa2/raw/annotations/OpenEnded_mscoco_val2014_questions.json'
    imgDir = '/visinf/projects_students/shared_vqa/mscoco/val2014/'
    evalIds = "evaluated_val_ids.npy"
elif use_set == 2:
    # VQA v1 Train Set
    dataType = "train"
    dataSubType = dataType + '2014'
    annFile = '/visinf/projects_students/shared_vqa/vqa1/mscoco_train2014_annotations.json'
    quesFile = '/visinf/projects_students/shared_vqa/vqa1/OpenEnded_mscoco_train2014_questions.json'
    imgDir = '/visinf/projects_students/shared_vqa/mscoco/train2014/'
    evalIds = "evaluated_vqa1_train_ids.npy"

if use_pythia:
    from pythia_model import PythiaVQA as Model
    model_dir = "pythia_vqa1"
else:
    from ban_model import BanVQA as Model
    model_dir = "ban"


def print_progress(start, j, dataset_len):
    progress = ((j + 1) / dataset_len) * 100
    elapsed = time.time() - start
    time_per_annotation = elapsed / (j + 1)
    finished_in = time_per_annotation * (dataset_len - (j + 1))
    day = finished_in // (24 * 3600)
    finished_in = finished_in % (24 * 3600)
    hour = finished_in // 3600
    finished_in %= 3600
    minutes = finished_in // 60
    finished_in %= 60
    seconds = finished_in
    print("Iteration: {} | Progress: {}% | Finished in: {}d {}h {}m {}s | Time Per Annotation: {}s".format(
        j, round(progress, 6), round(day), round(hour), round(minutes), round(seconds),
        round(time_per_annotation, 2)))


def predict():
    splits = 4
    subset = 3
    checkpoint = 0
    vqa_dataset = VQA_Dataset(evalIds, annFile, quesFile, imgDir, dataSubType, TARGET_IMAGE_SIZE,
                              CHANNEL_MEAN, CHANNEL_STD, splits=splits, subset=subset, checkpoint=checkpoint)
    dataset_len = len(vqa_dataset)
    print("subset: {}".format(subset))
    print("checkpoint: {}".format(checkpoint))
    print("vqa_dataset size: {}".format(dataset_len))

    with torch.enable_grad():
        layer = 'resnet152_model.7'
        #layer = 'detection_model.backbone.fpn.fpn_layer4'
        #layer = 'detection_model.backbone.body.layer4.2.conv3'
        vqa_model = Model(DEVICE)
        vqa_model.eval()
        # print_layer_names(vqa_model, full=True)
        # import sys; sys.exit()
        vqa_model_GCAM = pgc.GradCAM(model=vqa_model, candidate_layers=[layer])
        # vqa_model_GCAM = pgc.GuidedBackPropagation(model=vqa_model)  # TODO: GuidedBackPropagation REVERT
        #vqa_model_GCAM = vqa_model
        # vqa_model_GCAM = pgc.GradCAM(model=vqa_model)
        # vqa_model_GCAM.eval()
        #pythia_vqa_GBP = pgc.GuidedBackPropagation(model=vqa_model)
        data_loader = DataLoader(vqa_dataset, batch_size=1, shuffle=False)
        start = time.time()
        results = []
        result_index = 0
        answer_dir = model_dir + "/" + model_dir + "_answers_" + dataType + "/" + model_dir + "_pred_subset_"
        for j, batch in enumerate(data_loader):
            print_progress(start, checkpoint + j, dataset_len)
            annId = batch['annId'].item()
            question = batch['question'][0]
            raw_image = batch['raw_image'].squeeze()
            raw_image = raw_image.cpu().numpy()
            raw_image = cv2.resize(raw_image, tuple(TARGET_IMAGE_SIZE))
            #actual, indices = vqa_model.forward(batch)
            actual, indices = vqa_model_GCAM.forward(batch, IMAGE_SHAPE)
            # actual, indices = vqa_model_GCAM.forward(batch)  # TODO: GuidedBackPropagation REVERT
            top_indices = indices[0]
            top_scores = actual[0]
            probs = []
            answers = []
            for idx, score in enumerate(top_scores):
                probs.append(score.item())
                answers.append(
                    vqa_model.answer_processor.idx2word(top_indices[idx].item())
                )
            #self.pythia_model_GBP.backward(ids=indices[:, [0]])
            #attention_map_GBP = self.resnet152_model_GBP.generate()
            results.append([annId, answers[0], probs[0]])
            vqa_model_GCAM.backward(ids=indices[:, [0]])
            attention_map_GradCAM = vqa_model_GCAM.generate(target_layer=layer)
            # vqa_model_GCAM.backward()  # TODO: GuidedBackPropagation REVERT
            # attention_map_GradCAM = vqa_model_GCAM.generate()  # TODO: GuidedBackPropagation REVERT
            attention_map_GradCAM = attention_map_GradCAM.squeeze().cpu().numpy()
            # img = image[j].squeeze().detach().cpu().numpy().transpose(1, 2, 0)
            save_attention_map(filename="/visinf/projects_students/shared_vqa/" + model_dir + "/attention_maps/GBP/" + dataType + "2014/" + str(annId) + ".npy",
                               attention_map=attention_map_GradCAM)
            # save_attention_map_plain(filename="attention_map.txt", attention_map=attention_map)
            # save_gcam(filename="/visinf/projects_students/shared_vqa/" + model_dir + "/attention_overlay/" + dataType + "2014/" + str(j) + "_" + str(annId) + "_" + str(question) + "_" + str(answers[0]) + ".png", gcam=attention_map_GradCAM, raw_image=raw_image)
            if len(results) >= 1000:
                np.save("/visinf/projects_students/shared_vqa/" + answer_dir + str(subset) + "_part_" + str(result_index) + ".npy",
                        np.asarray(results))
                results = []
                result_index += 1
            gc.collect()
            torch.cuda.empty_cache()
        np.save("/visinf/projects_students/shared_vqa/" + answer_dir + str(subset) + "_part_" + str(result_index) + ".npy",
                np.asarray(results))
    # predictions = np.asarray(predictions)
    # np.save("/visinf/projects_students/shared_vqa/pythia/attention_maps/val2014.npy", predictions)


def print_layer_names(model, full=False):
    if not full:
        print(list(model.named_modules())[0])
    else:
        print(*list(model.named_modules()), sep='\n')


if __name__ == "__main__":
    predict()
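The driver stores each Grad-CAM map as a bare .npy array via save_attention_map from gradcam_utils, which is not included in this gist. Assuming that helper simply writes the array with np.save (the .npy extension suggests it does), a saved map can later be overlaid on its image roughly as sketched below; the helper function, paths, colormap, and blend factor are illustrative assumptions, not code from the gist:

# Hypothetical post-processing sketch (not part of the gist): overlay a saved
# Grad-CAM map on the corresponding image. Colormap and blend factor are arbitrary.
import cv2
import numpy as np

def overlay_gradcam(image_path, map_path, out_path, alpha=0.5):
    raw = cv2.imread(image_path)                            # original image, BGR
    gcam = np.load(map_path)                                # map written by save_attention_map()
    gcam = cv2.resize(gcam, (raw.shape[1], raw.shape[0]))   # GradCAM.generate() already scaled it to [0, 1]
    heat = cv2.applyColorMap(np.uint8(255 * gcam), cv2.COLORMAP_JET)
    cv2.imwrite(out_path, cv2.addWeighted(heat, alpha, raw, 1 - alpha, 0))

The commented-out save_gcam call in the loop above presumably produced a similar overlay directly at prediction time.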
# pythia.yaml - Pythia configuration, loaded as model_data/pythia.yaml by PythiaVQA._init_processors()
datasets: vqa2
log_foldername: vqa_vqa2_pythia_1234
model: pythia
model_attributes:
  pythia:
    classifier:
      params:
        img_hidden_dim: 5000
        text_hidden_dim: 300
      type: logit
    image_feature_dim: 2048
    image_feature_embeddings:
    - modal_combine:
        params:
          dropout: 0
          hidden_dim: 5000
        type: non_linear_element_multiply
      normalization: softmax
      transform:
        params:
          out_dim: 1
        type: linear
    image_feature_encodings:
    - params:
        bias_file: detectron/fc6/fc7_b.pkl
        model_data_dir: ../data/
        weights_file: detectron/fc6/fc7_w.pkl
      type: finetune_faster_rcnn_fpn_fc7
    - params:
        model_data_dir: ../data/
      type: default
    image_text_modal_combine:
      params:
        dropout: 0
        hidden_dim: 5000
      type: non_linear_element_multiply
    losses:
    - type: logit_bce
    metrics:
    - type: vqa_accuracy
    model: pythia
    model_data_dir: ../data/
    text_embeddings:
    - params:
        conv1_out: 512
        conv2_out: 2
        dropout: 0
        embedding_dim: 300
        hidden_dim: 1024
        kernel_size: 1
        num_layers: 1
        padding: 0
      type: attention
  pythia_image_only:
    classifier:
      params:
        img_hidden_dim: 5000
        text_hidden_dim: 300
      type: logit
    image_feature_dim: 2048
    image_feature_embeddings:
    - modal_combine:
        params:
          dropout: 0
          hidden_dim: 5000
        type: non_linear_element_multiply
      normalization: softmax
      transform:
        params:
          out_dim: 1
        type: linear
    image_feature_encodings:
    - params:
        bias_file: detectron/fc6/fc7_b.pkl
        model_data_dir: ../data/
        weights_file: detectron/fc6/fc7_w.pkl
      type: finetune_faster_rcnn_fpn_fc7
    - params:
        model_data_dir: ../data/
      type: default
    image_text_modal_combine:
      params:
        dropout: 0
        hidden_dim: 5000
      type: non_linear_element_multiply
    losses:
    - type: logit_bce
    metrics:
    - type: vqa_accuracy
    model_data_dir: ../data/
    text_embeddings:
    - params:
        conv1_out: 512
        conv2_out: 2
        dropout: 0
        embedding_dim: 300
        hidden_dim: 1024
        kernel_size: 1
        num_layers: 1
        padding: 0
      type: attention
  pythia_question_only:
    classifier:
      params:
        img_hidden_dim: 5000
        text_hidden_dim: 300
      type: logit
    image_feature_dim: 2048
    image_feature_embeddings:
    - modal_combine:
        params:
          dropout: 0
          hidden_dim: 5000
        type: non_linear_element_multiply
      normalization: softmax
      transform:
        params:
          out_dim: 1
        type: linear
    image_feature_encodings:
    - params:
        bias_file: detectron/fc6/fc7_b.pkl
        model_data_dir: ../data/
        weights_file: detectron/fc6/fc7_w.pkl
      type: finetune_faster_rcnn_fpn_fc7
    - params:
        model_data_dir: ../data/
      type: default
    image_text_modal_combine:
      params:
        dropout: 0
        hidden_dim: 5000
      type: non_linear_element_multiply
    losses:
    - type: logit_bce
    metrics:
    - type: vqa_accuracy
    model_data_dir: ../data/
    text_embeddings:
    - params:
        conv1_out: 512
        conv2_out: 2
        dropout: 0
        embedding_dim: 300
        hidden_dim: 1024
        kernel_size: 1
        num_layers: 1
        padding: 0
      type: attention
optimizer_attributes:
  params:
    eps: 1e-08
    lr: 0.01
    weight_decay: 0
  type: Adamax
task_attributes:
  vqa:
    dataset_attributes:
      vqa2:
        data_root_dir: ../data
        fast_read: False
        features_max_len: 100
        image_depth_first: False
        image_features:
          test:
          - coco/detectron_fix_100/fc6/test2015,coco/resnet152/test2015
          train:
          - coco/detectron_fix_100/fc6/train_val_2014,coco/resnet152/train_val_2014
          - coco/detectron_fix_100/fc6/train_val_2014,coco/resnet152/train_val_2014
          val:
          - coco/detectron_fix_100/fc6/train_val_2014,coco/resnet152/train_val_2014
        imdb_files:
          test:
          - imdb/vqa/imdb_test2015.npy
          train:
          - imdb/vqa/imdb_train2014.npy
          - imdb/vqa/imdb_val2014.npy
          val:
          - imdb/vqa/imdb_minival2014.npy
        processors:
          answer_processor:
            params:
              num_answers: 10
              preprocessor:
                params: None
                type: simple_word
              vocab_file: vocabs/answers_vqa.txt
            type: vqa_answer
          bbox_processor:
            params:
              max_length: 50
            type: bbox
          context_processor:
            params:
              max_length: 50
              model_file: .vector_cache/wiki.en.bin
            type: fasttext
          ocr_token_processor:
            params: None
            type: simple_word
          text_processor:
            params:
              max_length: 14
              vocab:
                embedding_name: glove.6B.300d
                type: intersected
                vocab_file: vocabs/vocabulary_100k.txt
              preprocessor:
                type: simple_sentence
                params: {}
            type: vocab
        return_info: True
        use_ocr: False
        use_ocr_info: False
    dataset_size_proportional_sampling: True
dataset_type: test
datasets: vqa2
tasks: vqa
training_parameters:
  batch_size: 128
  clip_gradients: True
  clip_norm_mode: all
  data_parallel: False
  device: cuda
  distributed: False
  evalai_inference: False
  experiment_name: run
  load_pretrained: False
  local_rank: None
  log_dir: ./logs
  log_interval: 100
  logger_level: info
  lr_ratio: 0.1
  lr_scheduler: True
  lr_steps:
  - 15000
  - 18000
  - 20000
  - 21000
  max_epochs: None
  max_grad_l2_norm: 0.25
  max_iterations: 22000
  metric_minimize: False
  monitored_metric: vqa_accuracy
  num_workers: 0
  patience: 4000
  pin_memory: False
  pretrained_mapping: None
  resume: False
  resume_file: /private/home/asg/pythia/pythia/checkpoint/apr22/pythia_vqa_train_val/vqa_vqa2_pythia_1234/pythia_train_val.pth
  run_type: inference
  save_dir: ./save
  seed: 1234
  should_early_stop: False
  should_not_log: False
  snapshot_interval: 1000
  task_size_proportional_sampling: True
  use_warmup: True
  verbose_dump: False
  warmup_factor: 0.2
  warmup_iterations: 1000
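For reference, this is the configuration that PythiaVQA._init_processors() (later in this gist) loads from model_data/pythia.yaml. The short sketch below only illustrates that access pattern and is not part of the gist:

# Minimal sketch of how the config above is consumed (see PythiaVQA._init_processors()).
import yaml
from pythia.utils.configuration import ConfigNode

with open("model_data/pythia.yaml") as f:
    config = ConfigNode(yaml.load(f, Loader=yaml.SafeLoader))

vqa2_config = config.task_attributes.vqa.dataset_attributes.vqa2
print(vqa2_config.processors.text_processor.params.max_length)   # 14
print(config.model_attributes.pythia.classifier.type)            # logit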
# pythia_dataset.py - VQA dataset wrapper: loads questions/annotations through the VQA API
# and prepares the ResNet and Detectron input tensors for a single question.
import sys
sys.path.append('pythia')
sys.path.append('vqa-maskrcnn-benchmark')

import cv2
import torch
import numpy as np
import torchvision.transforms as transforms
from PIL import Image
from vqaTools.vqa import VQA


class VQA_Dataset():
    def __init__(self, evalIds, annFile, quesFile, imgDir, dataSubType, target_image_size,
                 channel_mean, channel_std, splits, subset, checkpoint):
        # initialize VQA api for QA annotations
        self.vqa = VQA(annFile, quesFile)
        self.annIds = self.vqa.getQuesIds()
        self.annIds = self._get_evaluated_ids(evalIds)
        self.splits = splits
        self.subset = subset
        self.checkpoint = checkpoint
        self.subset_size = int(len(self.annIds) / self.splits)
        self.imgDir = imgDir
        self.dataSubType = dataSubType
        self.target_image_size = target_image_size
        self.channel_mean = channel_mean
        self.channel_std = channel_std

    def __len__(self):
        if self.subset < self.splits - 1:
            return self.subset_size - self.checkpoint
        else:
            left_over = len(self.annIds) - (self.subset_size * self.splits)
            return self.subset_size + left_over - self.checkpoint

    def _image_transform(self, img):
        im = np.array(img).astype(np.float32)
        im = im[:, :, ::-1]
        im -= np.array([102.9801, 115.9465, 122.7717])
        im_shape = im.shape
        im_size_min = np.min(im_shape[0:2])
        im_size_max = np.max(im_shape[0:2])
        im_scale = float(800) / float(im_size_min)
        # Prevent the biggest axis from being more than max_size
        if np.round(im_scale * im_size_max) > 1333:
            im_scale = float(1333) / float(im_size_max)
        im = cv2.resize(
            im,
            None,
            None,
            fx=im_scale,
            fy=im_scale,
            interpolation=cv2.INTER_LINEAR
        )
        img = torch.from_numpy(im).permute(2, 0, 1)
        return img, im_scale

    def __getitem__(self, index):
        index = self.subset * self.subset_size + index + self.checkpoint
        annId = int(self.annIds[index])
        #annId = index
        ann = self.vqa.loadQA(annId)[0]
        imgId = ann['image_id']
        imgFilename = 'COCO_' + self.dataSubType + '_' + str(imgId).zfill(12) + '.jpg'
        question = self.vqa.getQuestion(ann)
        image_path = self.imgDir + imgFilename
        img = Image.open(image_path)
        raw_image = cv2.imread(image_path)
        resnet_img = img.convert("RGB")
        data_transforms = transforms.Compose([
            transforms.Resize(self.target_image_size),
            transforms.ToTensor(),
            transforms.Normalize(self.channel_mean, self.channel_std),
        ])
        resnet_img = data_transforms(resnet_img)
        if len(np.shape(img)) == 2:
            img = img.convert("RGB")
        detectron_img, detectron_scale = self._image_transform(img)
        return {"annId": annId, "question": question, "resnet_img": resnet_img,
                "detectron_img": detectron_img, "detectron_scale": detectron_scale,
                "raw_image": raw_image}

    def _get_evaluated_ids(self, path):
        ids = np.load(path)
        ids = np.squeeze(ids[:, :1].astype(np.int32))
        return ids
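A quick way to sanity-check this wrapper is to pull a single item through a DataLoader, mirroring what predict() does. The snippet below is a hypothetical smoke test; the paths are abbreviated placeholders for the ones hard-coded in the driver script:

# Hypothetical smoke test for VQA_Dataset. Replace the "..." paths with the real
# annotation/question/image locations; batch_size stays 1 because images differ in size.
from torch.utils.data import DataLoader

dataset = VQA_Dataset(evalIds="evaluated_train_ids.npy",
                      annFile=".../mscoco_train2014_annotations.json",
                      quesFile=".../OpenEnded_mscoco_train2014_questions.json",
                      imgDir=".../train2014/",
                      dataSubType="train2014",
                      target_image_size=[448, 448],
                      channel_mean=[0.485, 0.456, 0.406],
                      channel_std=[0.229, 0.224, 0.225],
                      splits=4, subset=3, checkpoint=0)
loader = DataLoader(dataset, batch_size=1, shuffle=False)
batch = next(iter(loader))
print(batch["question"][0], batch["resnet_img"].shape, batch["detectron_img"].shape)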
#!/usr/bin/env python
# coding: utf-8
#
# Author: Kazuto Nakashima
# URL: http://kazuto1011.github.io
# Created: 2017-05-26
#
# pythia_grad_cam.py - Grad-CAM / guided backprop wrappers, adapted for the Pythia VQA model.

from collections import OrderedDict
from collections.abc import Sequence

import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
from tqdm import tqdm
import sys
from maskrcnn_benchmark.structures.bounding_box import BoxList
import copy


class _BaseWrapper(object):
    """
    Please modify forward() and backward() according to your task.
    """

    def __init__(self, model):
        super(_BaseWrapper, self).__init__()
        self.device = next(model.parameters()).device
        self.model = model
        self.handlers = []  # a set of hook function handlers

    def _encode_one_hot(self, ids):
        one_hot = torch.zeros_like(self.logits).to(self.device)
        one_hot.scatter_(1, ids, 1.0)
        return one_hot

    def forward(self, image):
        """
        Simple classification
        """
        self.model.zero_grad()
        self.logits = self.model.forward(image)
        #self.logits = self.model(image)["scores"]
        #return self.logits
        self.probs = F.softmax(self.logits, dim=1)
        return self.probs.sort(dim=1, descending=True)

    def backward(self, ids):
        """
        Class-specific backpropagation
        Either way works:
        1. self.logits.backward(gradient=one_hot, retain_graph=True)
        2. (self.logits * one_hot).sum().backward(retain_graph=True)
        """
        #print("backward")
        one_hot = self._encode_one_hot(ids)
        self.logits.backward(gradient=one_hot, retain_graph=True)

    def generate(self):
        raise NotImplementedError

    def remove_hook(self):
        """
        Remove all the forward/backward hook functions
        """
        for handle in self.handlers:
            handle.remove()
class BackPropagation(_BaseWrapper):
    def forward(self, image):
        image["resnet_img"].requires_grad_(True)
        image["detectron_img"].requires_grad_(True)
        image["detectron_scale"].requires_grad_(True)
        self.image = image  #.requires_grad_()
        return super(BackPropagation, self).forward(self.image)

    def generate(self):
        gradient = self.image.grad.clone()
        self.image.grad.zero_()
        return gradient


class GuidedBackPropagation(BackPropagation):
    """
    "Striving for Simplicity: the All Convolutional Net"
    https://arxiv.org/pdf/1412.6806.pdf
    Look at Figure 1 on page 8.
    """

    def __init__(self, model):
        super(GuidedBackPropagation, self).__init__(model)

        def backward_hook(module, grad_in, grad_out):
            # Cut off negative gradients
            if isinstance(module, nn.ReLU):
                return (torch.clamp(grad_in[0], min=0.0),)

        for module in self.model.named_modules():
            self.handlers.append(module[1].register_backward_hook(backward_hook))
        #pass


class Deconvnet(BackPropagation):
    """
    "Striving for Simplicity: the All Convolutional Net"
    https://arxiv.org/pdf/1412.6806.pdf
    Look at Figure 1 on page 8.
    """

    def __init__(self, model):
        super(Deconvnet, self).__init__(model)

        def backward_hook(module, grad_in, grad_out):
            # Cut off negative gradients and ignore ReLU
            if isinstance(module, nn.ReLU):
                return (torch.clamp(grad_out[0], min=0.0),)

        for module in self.model.named_modules():
            self.handlers.append(module[1].register_backward_hook(backward_hook))
def detach_output(output, depth):
    # Recursively detach tensors nested inside tuples/lists; leave dicts and BoxLists untouched.
    if not isinstance(output, torch.Tensor) and not isinstance(output, dict) and not isinstance(output, BoxList):
        tuple_list = []
        for item in output:
            if isinstance(item, torch.Tensor):
                #print("output: ", output.shape)
                tuple_list.append(item.detach())
            else:
                tuple_list.append(detach_output(item, depth + 1))
        return tuple_list
    elif isinstance(output, dict) or isinstance(output, BoxList):
        #print("Dict keys: ", output.keys())
        return output
    #print("output: ", output.shape)
    return output.detach()
class GradCAM(_BaseWrapper):
    """
    "Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization"
    https://arxiv.org/pdf/1610.02391.pdf
    Look at Figure 2 on page 4
    """

    def __init__(self, model, candidate_layers=None):
        super(GradCAM, self).__init__(model)
        self.fmap_pool = OrderedDict()
        self.grad_pool = OrderedDict()
        self.candidate_layers = candidate_layers  # list

        def forward_hook(key):
            def forward_hook_(module, input, output):
                # print("forward_hook_:")
                # Save featuremaps
                print("key forward: ", key)
                #output.register_hook(backward_hook(key))
                output = detach_output(output, 0)
                # Only cache plain tensor (or list-of-tensor) outputs; skip dicts and BoxLists
                if not isinstance(output, dict) and not isinstance(output, BoxList):
                    self.fmap_pool[key] = output

            return forward_hook_

        def backward_hook(key):
            #print("key2: ", key)
            def backward_hook_(module, grad_in, grad_out):
                # Save the gradients corresponding to the featuremaps
                print("key backward: ", key)
                self.grad_pool[key] = grad_out[0].detach()

            return backward_hook_

        #print(model.resnet152_model._modules.get("0"))
        #model.resnet152_model.layer4[0].conv2.register_backward_hook(backward_hook)
        # If no candidates are specified, the hooks are registered to all the layers.
        for name, module in self.model.named_modules():
            if self.candidate_layers is None or name in self.candidate_layers:
                #print("name: ", name)
                self.handlers.append(module.register_forward_hook(forward_hook(name)))
                #print("name: ", name)
                #module.retain_grad()
                #module.require_grad = True
                #print(name)
                self.handlers.append(module.register_backward_hook(backward_hook(name)))
        #print("self.handlers: ", self.handlers)
        # self.handlers.append(model.resnet152_model._modules.get("0").register_backward_hook(backward_hook('resnet152_model.0')))
        # my_layer = model.resnet152_model._modules.get("7")
        # my_layer_name = 'detection_model.backbone.fpn.fpn_layer4'
        # self.handlers.append(my_layer.register_backward_hook(backward_hook(my_layer_name)))

    def _find(self, pool, target_layer):
        #print(pool.keys())
        #print(list(self.model.named_modules()))
        if target_layer in pool.keys():
            return pool[target_layer]
        else:
            raise ValueError("Invalid layer name: {}".format(target_layer))

    def _compute_grad_weights(self, grads):
        # Global average pooling of the gradients over the spatial dimensions
        return F.adaptive_avg_pool2d(grads, 1)

    def forward(self, sample_list, image_shape):
        self.image_shape = image_shape
        return super(GradCAM, self).forward(sample_list)
    # def select_highest_layer(self):
    #     fmap_list, weight_list = [], []
    #     module_names = []
    #     for name, _ in self.model.named_modules():
    #         module_names.append(name)
    #     module_names.reverse()
    #
    #     for i in range(self.logits.shape[0]):
    #         counter = 0
    #         for layer in module_names:
    #             try:
    #                 print("Testing layer: {}".format(layer))
    #                 fmaps = self._find(self.fmap_pool, layer)
    #                 print("1")
    #                 np.shape(fmaps)  # Throws error without this line, I have no idea why...
    #                 print("2")
    #                 fmaps = fmaps[i]
    #                 print("3")
    #                 grads = self._find(self.grad_pool, layer)[i]
    #                 print("4")
    #                 import array_check
    #                 array_check.check(fmaps)
    #                 array_check.check(grads)
    #                 # print("counter: {}".format(counter))
    #                 # print("fmaps shape: {}".format(np.shape(fmaps)))
    #                 # print("grads shape: {}".format(np.shape(grads)))
    #                 nonzeros = np.count_nonzero(grads.detach().cpu().numpy())
    #                 # if True: #counter < 100:
    #                 #     print("counter: {}".format(counter))
    #                 #     #print("fmaps: {}".format(fmaps))
    #                 #     print("nonzeros: {}".format(nonzeros))
    #                 #     print("fmaps shape: {}".format(np.shape(fmaps)))
    #                 #     print("grads shape: {}".format(np.shape(grads)))
    #                 self._compute_grad_weights(grads)
    #                 if nonzeros == 0 or not isinstance(fmaps, torch.Tensor) or not isinstance(grads, torch.Tensor):
    #                     counter += 1
    #                     print("Skipped layer: {}".format(layer))
    #                     continue
    #                 print("Dismissed the last {} module layers (Note: This number can be inflated if the model contains many nested module layers)".format(counter))
    #                 print("Selected module layer: {}".format(layer))
    #                 fmap_list.append(self._find(self.fmap_pool, layer)[i])
    #                 grads = self._find(self.grad_pool, layer)[i]
    #                 weight_list.append(self._compute_grad_weights(grads))
    #                 break
    #             except ValueError:
    #                 counter += 1
    #             except RuntimeError:
    #                 counter += 1
    #
    #     return fmap_list, weight_list
    #
    # def generate_helper(self, fmaps, weights):
    #     gcam = torch.mul(fmaps, weights).sum(dim=1, keepdim=True)
    #     gcam = F.relu(gcam)
    #
    #     gcam = F.interpolate(
    #         gcam, self.image_shape, mode="bilinear", align_corners=False
    #     )
    #
    #     B, C, H, W = gcam.shape
    #     gcam = gcam.view(B, -1)
    #     gcam -= gcam.min(dim=1, keepdim=True)[0]
    #     gcam /= gcam.max(dim=1, keepdim=True)[0]
    #     gcam = gcam.view(B, C, H, W)
    #
    #     return gcam
    #
    # def generate(self, target_layer, dim=2):
    #     if target_layer == "auto":
    #         fmaps, weights = self.select_highest_layer()
    #         gcam = []
    #         for i in range(self.logits.shape[0]):
    #             gcam.append(self.generate_helper(fmaps[i].unsqueeze(0), weights[i].unsqueeze(0)))
    #     else:
    #         fmaps = self._find(self.fmap_pool, target_layer)
    #         grads = self._find(self.grad_pool, target_layer)
    #         weights = self._compute_grad_weights(grads)
    #         gcam_tensor = self.generate_helper(fmaps, weights)
    #         gcam = []
    #         for i in range(self.logits.shape[0]):
    #             tmp = gcam_tensor[i].unsqueeze(0)
    #             gcam.append(tmp)
    #     return gcam
    def generate(self, target_layer):
        fmaps = self._find(self.fmap_pool, target_layer)
        grads = self._find(self.grad_pool, target_layer)
        # fmaps = fmaps.unsqueeze(2).unsqueeze(3)
        # grads = grads.unsqueeze(2).unsqueeze(3)
        # if len(fmaps.shape) == 2:
        #     fmaps = fmaps.unsqueeze(2)
        #     grads = grads.unsqueeze(2)
        # if len(fmaps.shape) == 3:
        #     fmaps = fmaps.unsqueeze(3)
        #     grads = grads.unsqueeze(3)
        # print("fmaps.shape: ", fmaps.shape)
        # print("grads.shape: ", grads.shape)
        # Grad-CAM weights: global-average-pool the gradients over the spatial dimensions
        weights = self._compute_grad_weights(grads)
        # print("weights.shape: ", weights.shape)
        # print("fmaps.shape: ", fmaps.shape)
        # print("grads.shape: ", grads.shape)
        # print("weights.shape: ", weights.shape)
        # Weighted sum of the feature maps followed by ReLU (Eq. 1 and 2 of the Grad-CAM paper)
        gcam = torch.mul(fmaps, weights).sum(dim=1, keepdim=True)
        # print("gcam.shape: ", gcam.shape)
        gcam = F.relu(gcam)
        # print("gcam.shape: ", gcam.shape)
        # Upsample to the input resolution and normalize each map to [0, 1]
        gcam = F.interpolate(
            gcam, self.image_shape, mode="bilinear", align_corners=False
        )
        # print("gcam.shape: ", gcam.shape)
        B, C, H, W = gcam.shape
        gcam = gcam.view(B, -1)
        gcam -= gcam.min(dim=1, keepdim=True)[0]
        gcam /= gcam.max(dim=1, keepdim=True)[0]
        gcam = gcam.view(B, C, H, W)
        return gcam
def occlusion_sensitivity(
    model, batch, ids, mean=None, patch=35, stride=1, n_batches=128
):
    """
    "Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization"
    https://arxiv.org/pdf/1610.02391.pdf
    Look at Figure A5 on page 17
    Originally proposed in:
    "Visualizing and Understanding Convolutional Networks"
    https://arxiv.org/abs/1311.2901
    """
    torch.set_grad_enabled(False)
    model.eval()
    mean = mean if mean else 0
    patch_H, patch_W = patch if isinstance(patch, Sequence) else (patch, patch)
    pad_H, pad_W = patch_H // 2, patch_W // 2

    # Padded image
    batch["raw_image"] = F.pad(batch["raw_image"], (pad_W, pad_W, pad_H, pad_H), value=mean)
    B, _, H, W = batch["raw_image"].shape
    new_H = (H - patch_H) // stride + 1
    new_W = (W - patch_W) // stride + 1

    # Prepare sampling grids
    anchors = []
    grid_h = 0
    while grid_h <= H - patch_H:
        grid_w = 0
        while grid_w <= W - patch_W:
            grid_w += stride
            anchors.append((grid_h, grid_w))
        grid_h += stride

    # Baseline score without occlusion
    baseline = model(batch).detach().gather(1, ids)

    # Compute per-pixel logits
    scoremaps = []
    for i in tqdm(range(0, len(anchors), n_batches), leave=False):
        #print("Test 1")
        # batches = []
        # batch_ids = []
        for grid_h, grid_w in tqdm(anchors[i : i + n_batches]):
            #print("Test 2")
            batch_ = _batch_clone(batch)  #batch.clone()
            batch_["raw_image"][..., grid_h : grid_h + patch_H, grid_w : grid_w + patch_W] = mean
            score = model(batch_).detach().gather(1, ids)
            scoremaps.append(score)
            # batches.append(batch_)
            # batch_ids.append(ids)
        # batches = _batch_cat(batches)  #torch.cat(batches, dim=0)
        # batch_ids = torch.cat(batch_ids, dim=0)
        # scores = model(batches).detach().gather(1, batch_ids)
        # scoremaps += list(torch.split(scores, B))

    diffmaps = torch.cat(scoremaps, dim=1) - baseline
    diffmaps = diffmaps.view(B, new_H, new_W)
    return diffmaps
def _batch_clone(batch):
    clone = {}
    for key in batch.keys():
        if isinstance(batch[key], torch.Tensor):
            clone[key] = batch[key].clone()
        else:
            clone[key] = copy.deepcopy(batch[key])
    return clone


def _batch_cat(batch_list):
    cat_batch = {}
    for key in batch_list[0].keys():
        cat_batch[key] = [batch[key] for batch in batch_list]
        if isinstance(batch_list[0][key], torch.Tensor):
            cat_batch[key] = torch.cat(cat_batch[key], dim=0)
    return cat_batch
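The VQA-specific wiring above can obscure the basic wrapper pattern, so here is a minimal, self-contained sketch of the same GradCAM class applied to a plain torchvision classifier. The layer name "layer4", the 224x224 preprocessing, and the input file are assumptions for illustration only; the gist itself targets the layer 'resnet152_model.7' inside PythiaVQA:

# Minimal stand-alone sketch: the GradCAM wrapper above on a plain torchvision ResNet.
# "layer4", the 224x224 preprocessing and "example.jpg" are illustrative assumptions.
import torch
import torchvision.models as models
import torchvision.transforms as T
from PIL import Image

model = models.resnet152(pretrained=True).eval()
gcam = GradCAM(model=model, candidate_layers=["layer4"])

preprocess = T.Compose([T.Resize((224, 224)), T.ToTensor(),
                        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
img = preprocess(Image.open("example.jpg").convert("RGB")).unsqueeze(0)

probs, ids = gcam.forward(img, torch.Size([224, 224]))    # sorted class probabilities and ids
gcam.backward(ids=ids[:, [0]])                            # backprop the top-1 class
heatmap = gcam.generate(target_layer="layer4")            # (1, 1, 224, 224), values in [0, 1]

The call order is the same one predict() uses: forward() to get sorted class scores, backward() on the chosen class ids, then generate() for the normalized heatmap.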
# pythia_model.py - PythiaVQA wraps the Pythia VQA model together with its Detectron
# feature extractor and a ResNet-152 grid-feature extractor behind a single forward().
import sys
sys.path.append('pythia')
sys.path.append('vqa-maskrcnn-benchmark')

import yaml
import torch
import torch.nn.functional as F
import torchvision.models as models

from maskrcnn_benchmark.config import cfg
from maskrcnn_benchmark.layers import nms
from maskrcnn_benchmark.modeling.detector import build_detection_model
from maskrcnn_benchmark.structures.image_list import to_image_list
from maskrcnn_benchmark.utils.model_serialization import load_state_dict

from pythia.utils.configuration import ConfigNode
from pythia.tasks.processors import VocabProcessor, VQAAnswerProcessor
from pythia.models.pythia import Pythia
from pythia.common.registry import registry
from pythia.common.sample import Sample, SampleList

torch.backends.cudnn.enabled = False


class PythiaVQA(torch.nn.Module):
    def __init__(self, device):
        super(PythiaVQA, self).__init__()
        self.device = device
        self._init_processors()
        self.pythia_model = self._build_pythia_model()
        self.detection_model = self._build_detection_model()
        self.resnet152_model = self._build_resnet_model()

    def _init_processors(self):
        with open("model_data/pythia.yaml") as f:
            config = yaml.load(f, Loader=yaml.SafeLoader)
        config = ConfigNode(config)
        # Remove warning
        config.training_parameters.evalai_inference = True
        registry.register("config", config)
        self.config = config
        vqa_config = config.task_attributes.vqa.dataset_attributes.vqa2
        text_processor_config = vqa_config.processors.text_processor
        answer_processor_config = vqa_config.processors.answer_processor
        text_processor_config.params.vocab.vocab_file = "model_data/vocabulary_100k.txt"
        answer_processor_config.params.vocab_file = "model_data/answers_vqa.txt"
        # Add the preprocessor, as it is needed when questions come from the user
        self.text_processor = VocabProcessor(text_processor_config.params)
        self.answer_processor = VQAAnswerProcessor(answer_processor_config.params)
        registry.register("vqa2_text_processor", self.text_processor)
        registry.register("vqa2_answer_processor", self.answer_processor)
        registry.register("vqa2_num_final_outputs",
                          self.answer_processor.get_vocab_size())

    def _build_pythia_model(self):
        state_dict = torch.load('model_data/pythia.pth')
        model_config = self.config.model_attributes.pythia
        model_config.model_data_dir = "/visinf/home/vilab22/Documents/RemoteProjects/dlcv/pythia/"  #/content/
        model = Pythia(model_config)
        model.build()
        model.init_losses_and_metrics()
        if list(state_dict.keys())[0].startswith('module') and \
                not hasattr(model, 'module'):
            state_dict = self._multi_gpu_state_to_single(state_dict)
        model.load_state_dict(state_dict)
        model.to(self.device)
        model.eval()
        return model

    def _build_resnet_model(self):
        resnet152 = models.resnet152(pretrained=True)
        resnet152.eval()
        modules = list(resnet152.children())[:-2]
        resnet152_model = torch.nn.Sequential(*modules)
        resnet152_model.to(self.device)
        resnet152_model.eval()
        return resnet152_model

    def _multi_gpu_state_to_single(self, state_dict):
        new_sd = {}
        for k, v in state_dict.items():
            if not k.startswith('module.'):
                raise TypeError("Not a multiple GPU state of dict")
            k1 = k[7:]
            new_sd[k1] = v
        return new_sd

    def _build_detection_model(self):
        cfg.merge_from_file('model_data/detectron_model.yaml')
        cfg.freeze()
        model = build_detection_model(cfg)
        checkpoint = torch.load('model_data/detectron_model.pth',
                                map_location=torch.device("cpu"))
        load_state_dict(model, checkpoint.pop("model"))
        model.to(self.device)
        model.eval()
        return model

    def _process_feature_extraction(self, output,
                                    im_scales,
                                    feat_name='fc6',
                                    conf_thresh=0.2):
        batch_size = len(output[0]["proposals"])
        n_boxes_per_image = [len(_) for _ in output[0]["proposals"]]
        score_list = output[0]["scores"].split(n_boxes_per_image)
        score_list = [torch.nn.functional.softmax(x, -1) for x in score_list]
        feats = output[0][feat_name].split(n_boxes_per_image)
        cur_device = score_list[0].device
        feat_list = []
        for i in range(batch_size):
            dets = output[0]["proposals"][i].bbox / im_scales[i]
            scores = score_list[i]
            max_conf = torch.zeros((scores.shape[0])).to(cur_device)
            for cls_ind in range(1, scores.shape[1]):
                cls_scores = scores[:, cls_ind]
                keep = nms(dets, cls_scores, 0.5)
                max_conf[keep] = torch.where(cls_scores[keep] > max_conf[keep],
                                             cls_scores[keep],
                                             max_conf[keep])
            keep_boxes = torch.argsort(max_conf, descending=True)[:100]
            feat_list.append(feats[i][keep_boxes])
        return feat_list

    def masked_unk_softmax(self, x, dim, mask_idx):
        x1 = F.softmax(x, dim=dim)
        x1[:, mask_idx] = 0
        x1_sum = torch.sum(x1, dim=1, keepdim=True)
        y = x1 / x1_sum
        return y

    def get_resnet_features(self, img):
        if img.shape[0] == 1:
            img = img.expand(3, -1, -1)
        img = img.unsqueeze(0).to(self.device)
        features = self.resnet152_model.forward(img).permute(0, 2, 3, 1)
        features = features.view(196, 2048)
        return features

    def get_detectron_features(self, im, im_scale):
        current_img_list = to_image_list(im, size_divisible=32)
        current_img_list = current_img_list.to(self.device)
        with torch.no_grad():
            output = self.detection_model.forward(current_img_list)
        feat_list = self._process_feature_extraction(output, im_scale,
                                                     'fc6', 0.2)
        return feat_list[0]

    def forward(self, batch):
        question = batch['question'][0]
        detectron_img = batch['detectron_img'].squeeze()
        detectron_scale = [batch['detectron_scale'].item()]
        resnet_img = batch['resnet_img'].squeeze()
        detectron_features = self.get_detectron_features(detectron_img, detectron_scale)
        # with torch.no_grad():
        resnet_features = self.get_resnet_features(resnet_img)
        sample = Sample()
        processed_text = self.text_processor({"text": question})
        sample.text = processed_text["text"]
        sample.text_len = len(processed_text["tokens"])
        sample.image_feature_0 = detectron_features
        sample.image_info_0 = Sample({
            "max_features": torch.tensor(100, dtype=torch.long)
        })
        sample.image_feature_1 = resnet_features
        sample_list = SampleList([sample])
        sample_list = sample_list.to(self.device)
        scores = self.pythia_model.forward(sample_list)["scores"]
        return scores
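PythiaVQA loads everything from a local model_data/ directory. A small pre-flight check, with the file list taken directly from the paths referenced in this gist, avoids a long startup before the first missing-file error; the snippet is an optional addition, not part of the original code:

# Pre-flight check for the assets PythiaVQA loads at construction time.
import os

REQUIRED_FILES = [
    "model_data/pythia.yaml",
    "model_data/pythia.pth",
    "model_data/detectron_model.yaml",
    "model_data/detectron_model.pth",
    "model_data/vocabulary_100k.txt",
    "model_data/answers_vqa.txt",
]

missing = [f for f in REQUIRED_FILES if not os.path.isfile(f)]
if missing:
    raise FileNotFoundError("Missing model assets: {}".format(", ".join(missing)))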