# matt_dataset.py
""" | |
This code will take torch.utils.data.Dataset as abstract class and implement the two important | |
functions: | |
1) __len__ : len(dataset) returns the size of the dataset | |
2) __getitem__: dataset[i] can be used to get i-th sample | |
data_json has | |
0. refs: [{ref_id, ann_id, box, image_id, split, category_id, sent_ids, att_wds}] | |
1. images: [{image_id, ref_ids, file_name, width, height, h5_id}] | |
2. anns: [{ann_id, category_id, image_id, box, h5_id}] | |
3. sentences: [{sent_id, tokens, h5_id}] | |
4. word_to_ix: {word: ix} | |
5. att_to_ix : {att_wd: ix} | |
6. att_to_cnt: {att_wd: cnt} | |
7. label_length: L | |
Note, box in [xywh] format | |
data_h5 has | |
/labels is (M, max_length) uint32 array of encoded labels, zeros padded | |
""" | |
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os.path as osp
import numpy as np
import h5py
import json
import random

import torch
import torch.utils.data as data
from torch.autograd import Variable

# Sequence/Mapping live in collections.abc on Python 3.3+
try:
    from collections.abc import Sequence, Mapping
except ImportError:
    from collections import Sequence, Mapping
def preprocess_sent(sent_str):
    """
    Remove punctuation, lower-case and strip the sentence.
    """
    punctuations = '''!()-[]{};:'"\,<>./?@#$%^&*_~'''
    no_punct = ""
    for char in sent_str:
        if char not in punctuations:
            no_punct += char
    no_punct = no_punct.lower().strip()
    return no_punct
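# e.g., preprocess_sent("The man's hat, on the left.") -> "the mans hat on the left"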
def xywh_to_xyxy(boxes):
    """Convert [x y w h] box format to [x1 y1 x2 y2] format."""
    return np.hstack((boxes[:, 0:2], boxes[:, 0:2] + boxes[:, 2:4] - 1))
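# e.g., xywh_to_xyxy(np.array([[10, 20, 30, 40]])) -> array([[10, 20, 39, 59]])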
def matt_collate(batch):
    """
    Used to collate mattnet_dataset[i]'s into a batch.
    """
    error_msg = "batch must contain tensors, numbers, dicts or lists; found {}"
    if torch.is_tensor(batch[0]):
        # collate head, ref_labels, lfeats, dif_lfeats, cxt_lfeats
        return [_ for _ in batch]
    elif type(batch[0]).__module__ == 'numpy':
        # collate im_info
        return [_ for _ in batch]
    elif isinstance(batch[0], Sequence):
        # collate list
        return [_ for _ in batch]
    elif isinstance(batch[0], int):
        # collate image_id
        return [_ for _ in batch]
    elif isinstance(batch[0], Mapping):
        # collate dict by recursing on each key
        return {key: matt_collate([d[key] for d in batch]) for key in batch[0]}
    raise TypeError(error_msg.format(type(batch[0])))
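# Unlike torch's default_collate, this keeps variable-sized items as plain lists.
# e.g., matt_collate([{'image_id': 1, 'ref_ids': [3, 4]},
#                     {'image_id': 2, 'ref_ids': [7]}])
#       -> {'image_id': [1, 2], 'ref_ids': [[3, 4], [7]]}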
class ImageSampler(data.sampler.Sampler):
    """
    Used to sample train/val/test images.
    Note we will iterate the dataset using image_ids!
    """
    def __init__(self, image_ids):
        self.image_ids = image_ids

    def __iter__(self):
        return iter(self.image_ids)

    def __len__(self):
        return len(self.image_ids)
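# Because __getitem__ below treats its index as an image_id, this sampler lets a
# DataLoader iterate the dataset image by image; see the usage sketch at the
# bottom of this file.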
class MAttDataset(data.Dataset):

    def __init__(self, data_json, split, head_feats_dir, opt,
                 word_to_ix=None, att_to_ix=None, encode_from_phrase=False):
        # load the json file which contains info about the dataset
        print('Dataset loading data.json: ', data_json)
        self.info = json.load(open(data_json))
        self.word_to_ix = self.info['word_to_ix'] if word_to_ix is None else word_to_ix
        self.ix_to_word = {ix: wd for wd, ix in self.word_to_ix.items()}
        self.vocab_size = len(self.ix_to_word)
        print('vocab size is ', self.vocab_size)
        self.cat_to_ix = self.info['cat_to_ix']
        self.ix_to_cat = {ix: cat for cat, ix in self.cat_to_ix.items()}
        print('object category size is ', len(self.ix_to_cat))
        self.images = self.info['images']
        self.anns = self.info['anns']
        self.refs = self.info['refs']
        self.sentences = self.info['sentences']
        print('we have %s images.' % len(self.images))
        print('we have %s anns.' % len(self.anns))
        print('we have %s refs.' % len(self.refs))
        print('we have %s sentences.' % len(self.sentences))
        self.label_length = self.info['label_length']
        print('label_length is ', self.label_length)
        # construct mapping
        self.Refs = {ref['ref_id']: ref for ref in self.refs}
        self.Images = {image['image_id']: image for image in self.images}
        self.Anns = {ann['ann_id']: ann for ann in self.anns}
        self.Sentences = {sent['sent_id']: sent for sent in self.sentences}
        self.annToRef = {ref['ann_id']: ref for ref in self.refs}
        self.sentToRef = {sent_id: ref for ref in self.refs for sent_id in ref['sent_ids']}
        # prepare attributes
        self.att_to_ix = self.info['att_to_ix'] if att_to_ix is None else att_to_ix
        self.ix_to_att = {ix: wd for wd, ix in self.att_to_ix.items()}
        self.att_to_cnt = self.info['att_to_cnt']
        self.attribute_size = len(self.att_to_ix)
        # image_ids of each split
        self.split = split
        self.split_image_ids = []
        for image_id, image in self.Images.items():
            if self.Refs[image['ref_ids'][0]]['split'] == self.split:
                self.split_image_ids += [image_id]
        print('assigned %d images to split %s' % (len(self.split_image_ids), self.split))
        # other options
        self.seq_per_ref = opt.get('seq_per_ref', 3)
        self.sample_ratio = opt['visual_sample_ratio']
        self.num_cxt = opt.get('num_cxt', 5)
        self.with_st = opt.get('with_st', 1)
        self.head_feats_dir = head_feats_dir
        self.encode_from_phrase = encode_from_phrase

    def __len__(self):
        return len(self.split_image_ids)
    def __getitem__(self, index):
        """
        Args:
            index (int): image_id
        Returns:
            image_id        : current index
            head            : float (1, 1024, H, W)
            im_info         : ndarray float32 [[im_h, im_w, im_scale]], (1, 3)
            image_ann_ids   : N annotated objects in this image
            ref_ids         : n expanded ref_ids, each ref repeated seq_per_ref times
            ref_ann_ids     : n positive anns --> ref_pool5, ref_fc7
            ref_sent_ids    : n positive sent_ids
            ref_labels      : long (n, label_length)
            ref_cxt_ann_ids : (n, num_cxt) ann_ids for each of ref_id, -1 padded
            ref_Feats       : ref_lfeats     : float (n, 5)
                              ref_dif_lfeats : float (n, 25)
                              ref_cxt_lfeats : float (n, num_cxt, 5)
            neg_ann_ids     : n negative anns --> ref_pool5, ref_fc7
            neg_sent_ids    : n negative sent_ids
            neg_labels      : long (n, label_length)
            neg_cxt_ann_ids : (n, num_cxt) ann_ids for each of neg_id, -1 padded
            neg_Feats       : neg_lfeats     : float (n, 5)
                              neg_dif_lfeats : float (n, 25)
                              neg_cxt_lfeats : float (n, num_cxt, 5)
        """
        # get image head feats
        image_id = index
        head, im_info = self.image_to_head(image_id)
        head = torch.from_numpy(head)
        image_ann_ids = self.Images[image_id]['ann_ids']
        # expand ref_ids by seq_per_ref
        ref_ids = self.Images[image_id]['ref_ids']
        image_ref_ids = self.expand_list(ref_ids, self.seq_per_ref)
        # sample all ids
        ref_ann_ids, ref_sent_ids = [], []
        neg_ann_ids, neg_sent_ids = [], []
        for ref_id in ref_ids:
            ref_ann_id = self.Refs[ref_id]['ann_id']
            # pos ids
            ref_ann_ids += [ref_ann_id] * self.seq_per_ref
            ref_sent_ids += self.fetch_sent_ids_by_ref_id(ref_id, self.seq_per_ref)
            # neg ids
            cur_ann_ids, cur_sent_ids = self.sample_neg_ids(ref_ann_id, self.seq_per_ref, self.sample_ratio)
            neg_ann_ids += cur_ann_ids
            neg_sent_ids += cur_sent_ids
        # compute all lfeats
        ref_lfeats = torch.from_numpy(self.compute_lfeats(ref_ann_ids))
        ref_dif_lfeats = torch.from_numpy(self.compute_dif_lfeats(ref_ann_ids))
        neg_lfeats = torch.from_numpy(self.compute_lfeats(neg_ann_ids))
        neg_dif_lfeats = torch.from_numpy(self.compute_dif_lfeats(neg_ann_ids))
        # fetch labels
        ref_labels = torch.from_numpy(self.fetch_labels(ref_sent_ids)).long()
        neg_labels = torch.from_numpy(self.fetch_labels(neg_sent_ids)).long()
        # fetch context info: cxt_lfeats, cxt_ann_ids
        ref_cxt_lfeats, ref_cxt_ann_ids = self.fetch_cxt_info(ref_ann_ids, self.num_cxt)
        ref_cxt_lfeats = torch.from_numpy(ref_cxt_lfeats)
        neg_cxt_lfeats, neg_cxt_ann_ids = self.fetch_cxt_info(neg_ann_ids, self.num_cxt)
        neg_cxt_lfeats = torch.from_numpy(neg_cxt_lfeats)
        # return
        data = {}
        data['image_id'] = image_id
        data['head'] = head
        data['im_info'] = im_info
        data['image_ann_ids'] = image_ann_ids
        data['ref_ids'] = image_ref_ids
        data['ref_ann_ids'] = ref_ann_ids
        data['ref_sent_ids'] = ref_sent_ids
        data['ref_labels'] = ref_labels
        data['ref_cxt_ann_ids'] = ref_cxt_ann_ids
        data['ref_Feats'] = {'lfeats': ref_lfeats, 'dif_lfeats': ref_dif_lfeats, 'cxt_lfeats': ref_cxt_lfeats}
        data['neg_ann_ids'] = neg_ann_ids
        data['neg_sent_ids'] = neg_sent_ids
        data['neg_labels'] = neg_labels
        data['neg_cxt_ann_ids'] = neg_cxt_ann_ids
        data['neg_Feats'] = {'lfeats': neg_lfeats, 'dif_lfeats': neg_dif_lfeats, 'cxt_lfeats': neg_cxt_lfeats}
        return data
    # expand list by n, e.g., [a, b] with n=3 -> [a, a, a, b, b, b]
    def expand_list(self, L, n):
        out = []
        for l in L:
            out += [l] * n
        return out
    def fetch_sent_ids_by_ref_id(self, ref_id, num_sents):
        """
        Sample num_sents sent_ids for the given ref_id, with replacement if it
        owns fewer than num_sents sentences.
        """
        sent_ids = list(self.Refs[ref_id]['sent_ids'])
        if len(sent_ids) < num_sents:
            append_sent_ids = [random.choice(sent_ids) for _ in range(num_sents - len(sent_ids))]
            sent_ids += append_sent_ids
        else:
            random.shuffle(sent_ids)
            sent_ids = sent_ids[:num_sents]
        assert len(sent_ids) == num_sents
        return sent_ids
    def sample_neg_ids(self, ann_id, seq_per_ref, sample_ratio):
        """Return
        - neg_ann_ids : list of ann_ids that are negative to target ann_id
        - neg_sent_ids: list of sent_ids that are negative to target ann_id
        """
        st_ref_ids, st_ann_ids, dt_ref_ids, dt_ann_ids = self.fetch_neighbour_ids(ann_id)
        # neg ann
        neg_ann_ids, neg_sent_ids = [], []
        for k in range(seq_per_ref):
            # neg_ann_id for negative visual representation: mainly from same-type objects
            if len(st_ann_ids) > 0 and np.random.uniform(0, 1) < sample_ratio:
                neg_ann_id = random.choice(st_ann_ids)
            elif len(dt_ann_ids) > 0:
                neg_ann_id = random.choice(dt_ann_ids)
            else:
                # fallback: sample from any neighbour if there is one, otherwise pad with -1
                if len(st_ann_ids + dt_ann_ids) > 0:
                    neg_ann_id = random.choice(st_ann_ids + dt_ann_ids)
                else:
                    neg_ann_id = -1
            neg_ann_ids += [neg_ann_id]
            # neg_ref_id for negative language representation: mainly from same-type "referred" objects
            if len(st_ref_ids) > 0 and np.random.uniform(0, 1) < sample_ratio:
                neg_ref_id = random.choice(st_ref_ids)
            elif len(dt_ref_ids) > 0:
                neg_ref_id = random.choice(dt_ref_ids)
            else:
                neg_ref_id = random.choice(list(self.Refs.keys()))
            neg_sent_id = random.choice(self.Refs[neg_ref_id]['sent_ids'])
            neg_sent_ids += [neg_sent_id]
        return neg_ann_ids, neg_sent_ids
    def fetch_neighbour_ids(self, ref_ann_id):
        """
        For a given ref_ann_id, we return
        - st_ref_ids: same-type neighbouring ref_ids (not including itself)
        - st_ann_ids: same-type neighbouring ann_ids (not including itself)
        - dt_ref_ids: different-type neighbouring ref_ids
        - dt_ann_ids: different-type neighbouring ann_ids
        Each list is ordered by distance to the input ann_id.
        """
        ref_ann = self.Anns[ref_ann_id]
        x, y, w, h = ref_ann['box']
        rx, ry = x+w/2, y+h/2

        def squared_distance(ann_id):
            # squared distance between box centres; closer -> earlier in the sort
            x, y, w, h = self.Anns[ann_id]['box']
            ax, ay = x+w/2, y+h/2
            return (rx-ax)**2 + (ry-ay)**2

        image = self.Images[ref_ann['image_id']]
        ann_ids = list(image['ann_ids'])  # copy in case the raw list is changed
        ann_ids = sorted(ann_ids, key=squared_distance)
        st_ref_ids, st_ann_ids, dt_ref_ids, dt_ann_ids = [], [], [], []
        for ann_id in ann_ids:
            if ann_id != ref_ann_id:
                if self.Anns[ann_id]['category_id'] == ref_ann['category_id']:
                    st_ann_ids += [ann_id]
                    if ann_id in self.annToRef:
                        st_ref_ids += [self.annToRef[ann_id]['ref_id']]
                else:
                    dt_ann_ids += [ann_id]
                    if ann_id in self.annToRef:
                        dt_ref_ids += [self.annToRef[ann_id]['ref_id']]
        return st_ref_ids, st_ann_ids, dt_ref_ids, dt_ann_ids
    def fetch_cxt_info(self, ann_ids, topK):
        """
        Return
        - cxt_lfeats : ndarray float32 (#ann_ids, topK, 5), zero-padded
        - cxt_ann_ids: list[[ann_id]] of size (#ann_ids, topK), -1 padded
        Note we use the topK nearest different-type neighbours (plus same-type
        neighbours if self.with_st) as context objects.
        """
        cxt_lfeats = np.zeros((len(ann_ids), topK, 5), dtype=np.float32)
        cxt_ann_ids = [[-1 for _ in range(topK)] for _ in range(len(ann_ids))]  # (#ann_ids, topK)
        for i, ref_ann_id in enumerate(ann_ids):
            if ref_ann_id == -1:
                continue
            # reference box
            rbox = self.Anns[ref_ann_id]['box']
            rcx, rcy, rw, rh = rbox[0]+rbox[2]/2, rbox[1]+rbox[3]/2, rbox[2], rbox[3]
            rw += 1e-5
            rh += 1e-5
            # candidate boxes
            _, st_ann_ids, _, dt_ann_ids = self.fetch_neighbour_ids(ref_ann_id)
            if self.with_st > 0:
                cand_ann_ids = dt_ann_ids + st_ann_ids
            else:
                cand_ann_ids = dt_ann_ids
            cand_ann_ids = cand_ann_ids[:topK]
            for j, cand_ann_id in enumerate(cand_ann_ids):
                cand_ann = self.Anns[cand_ann_id]
                cbox = cand_ann['box']
                cx1, cy1, cw, ch = cbox[0], cbox[1], cbox[2], cbox[3]
                cxt_lfeats[i, j, :] = np.array([(cx1-rcx)/rw, (cy1-rcy)/rh, (cx1+cw-rcx)/rw, (cy1+ch-rcy)/rh, cw*ch/(rw*rh)])
                cxt_ann_ids[i][j] = cand_ann_id
        return cxt_lfeats, cxt_ann_ids
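    # Worked example for the 5-d relative feature above (assuming a reference
    # box [0, 0, 100, 100] and a candidate box [50, 50, 20, 20]):
    #   rcx, rcy, rw, rh = 50, 50, 100, 100  (ignoring the 1e-5 padding)
    #   [(50-50)/100, (50-50)/100, (50+20-50)/100, (50+20-50)/100, 20*20/(100*100)]
    #   = [0.0, 0.0, 0.2, 0.2, 0.04]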
    def compute_lfeats(self, ann_ids):
        # return ndarray float32 (#ann_ids, 5)
        lfeats = np.zeros((len(ann_ids), 5), dtype=np.float32)
        for ix, ann_id in enumerate(ann_ids):
            if ann_id == -1:
                continue
            ann = self.Anns[ann_id]
            image = self.Images[ann['image_id']]
            x, y, w, h = ann['box']
            ih, iw = image['height'], image['width']
            lfeats[ix] = np.array([x/iw, y/ih, (x+w-1)/iw, (y+h-1)/ih, w*h/(iw*ih)], np.float32)
        return lfeats
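    # Worked example: in a 640x480 image, box [64, 48, 320, 240] encodes to
    #   [64/640, 48/480, (64+320-1)/640, (48+240-1)/480, 320*240/(640*480)]
    #   ~ [0.1, 0.1, 0.598, 0.598, 0.25]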
    def compute_dif_lfeats(self, ann_ids, topK=5):
        # return ndarray float32 (#ann_ids, 5*topK)
        dif_lfeats = np.zeros((len(ann_ids), 5*topK), dtype=np.float32)
        for i, ref_ann_id in enumerate(ann_ids):
            if ref_ann_id == -1:
                continue
            # reference box
            rbox = self.Anns[ref_ann_id]['box']
            rcx, rcy, rw, rh = rbox[0]+rbox[2]/2, rbox[1]+rbox[3]/2, rbox[2], rbox[3]
            rw += 1e-5
            rh += 1e-5
            # candidate boxes: nearest same-type neighbours
            _, st_ann_ids, _, _ = self.fetch_neighbour_ids(ref_ann_id)
            for j, cand_ann_id in enumerate(st_ann_ids[:topK]):
                cbox = self.Anns[cand_ann_id]['box']
                cx1, cy1, cw, ch = cbox[0], cbox[1], cbox[2], cbox[3]
                dif_lfeats[i, j*5:(j+1)*5] = \
                    np.array([(cx1-rcx)/rw, (cy1-rcy)/rh, (cx1+cw-rcx)/rw, (cy1+ch-rcy)/rh, cw*ch/(rw*rh)])
        return dif_lfeats
    # weights = 1/sqrt(cnt), min-max rescaled to [1, scale]
    def get_attribute_weights(self, scale=10):
        # return weights for each attribute, ordered by attribute ix
        cnts = [self.att_to_cnt[self.ix_to_att[ix]] for ix in range(self.attribute_size)]
        cnts = np.array(cnts)
        weights = 1/(cnts**0.5 + 1e-5)
        weights = (weights - np.min(weights)) / (np.max(weights) - np.min(weights))
        weights = weights * (scale-1) + 1
        return torch.from_numpy(weights).float()
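    # Worked example with counts [100, 4, 1] and scale=10:
    #   1/sqrt(cnt)        -> [0.1, 0.5, 1.0]
    #   min-max normalized -> [0.0, 0.444, 1.0]
    #   * (scale-1) + 1    -> [1.0, 5.0, 10.0]
    # so the rarest attribute is weighted ~scale times more than the most frequent one.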
    def fetch_attribute_label(self, ref_ann_ids):
        """Return
        - labels    : Variable float (N, attribute_size)
        - select_ixs: Variable long, ixs of the refs that own attribute words
        """
        labels = np.zeros((len(ref_ann_ids), self.attribute_size))
        select_ixs = []
        for i, ref_ann_id in enumerate(ref_ann_ids):
            ref = self.annToRef[ref_ann_id]
            if len(ref['att_wds']) > 0:
                select_ixs += [i]
                for wd in ref['att_wds']:
                    labels[i, self.att_to_ix[wd]] = 1
        return Variable(torch.from_numpy(labels).float().cuda()), Variable(torch.LongTensor(select_ixs).cuda())
    def image_to_head(self, image_id):
        """Returns
        head   : float32 (1, 1024, H, W)
        im_info: float32 [[im_h, im_w, im_scale]]
        """
        feats_h5 = osp.join(self.head_feats_dir, str(image_id)+'.h5')
        with h5py.File(feats_h5, 'r') as feats:  # close the file after reading
            return np.array(feats['head']), np.array(feats['im_info'])
    def encode_phrases(self, phrase_list):
        # encode to np.int64 (num_phrase_list, label_length)
        L = np.zeros((len(phrase_list), self.label_length), dtype=np.int64)
        for i, raw_phrase in enumerate(phrase_list):
            phrase = preprocess_sent(raw_phrase)
            if len(phrase) > 0:
                for j, w in enumerate(phrase.split()):
                    if j < self.label_length:
                        L[i, j] = self.word_to_ix[w] if w in self.word_to_ix else self.word_to_ix['<UNK>']
            else:
                # in case phrase is void, fall back to a single <UNK>
                L[i, 0] = self.word_to_ix['<UNK>']
        return L

    def encode_labels(self, tokens_list):
        # encode to np.int64 (num_tokens_list, label_length)
        L = np.zeros((len(tokens_list), self.label_length), dtype=np.int64)
        for i, tokens in enumerate(tokens_list):
            for j, w in enumerate(tokens):
                if j < self.label_length:
                    L[i, j] = self.word_to_ix[w]
        return L
    def fetch_labels(self, sent_ids):
        """
        Return: int64 (num_sents, label_length)
        """
        if self.encode_from_phrase:
            phrase_list = [self.Sentences[sent_id]['sent'] for sent_id in sent_ids]
            labels = self.encode_phrases(phrase_list)
        else:
            tokens_list = [self.Sentences[sent_id]['tokens'] for sent_id in sent_ids]
            labels = self.encode_labels(tokens_list)
        return labels

    def decode_labels(self, labels):
        """
        labels: int32 (n, label_length) zero-padded at the end
        return: list of sents in string format
        """
        decoded_sent_strs = []
        num_sents = labels.shape[0]
        for i in range(num_sents):
            label = labels[i].tolist()
            sent_str = ' '.join([self.ix_to_word[int(wd_ix)] for wd_ix in label if wd_ix != 0])
            decoded_sent_strs.append(sent_str)
        return decoded_sent_strs
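    # Round-trip example (the vocabulary indices here are hypothetical):
    #   encode_labels([['man', 'in', 'red']]) with label_length=5
    #   -> array([[12, 7, 43, 0, 0]]) if word_to_ix maps man->12, in->7, red->43
    #   decode_labels on that array strips the zero padding back to ['man in red']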
    def fetch_box_feats(self, mrcn, boxes, net_conv, im_info):
        """
        Return:
        - pool5 (n, 1024, 7, 7)
        - fc7   (n, 2048, 7, 7)
        """
        pool5, fc7 = mrcn.box_to_spatial_fc7(net_conv, im_info, boxes)
        return pool5, fc7
    def fetch_feats(self, mrcn, head, im_info, image_id, image_ann_ids, ref_ann_ids, cxt_ann_ids):
        """
        Inputs
        - mrcn          : mask faster-rcnn instance, loaded from outside so it can be re-used across train/val/test
        - head          : float (1, 1024, H, W)
        - im_info       : ndarray float32 [[im_h, im_w, im_scale]], (1, 3)
        - image_ann_ids : N annotated ann_ids in the image
        - ref_ann_ids   : n referred objects
        - cxt_ann_ids   : (n, num_cxt) context objects
        Returns
        - ref_fc7       : Variable float (n, 2048, 7, 7)
        - ref_pool5     : Variable float (n, 1024, 7, 7)
        - cxt_fc7       : Variable float (n, num_cxt, 2048) or None
        Don't forget to detach these feats.
        """
        # get all anns' fc7 and pool5
        ann_boxes = xywh_to_xyxy(np.vstack([self.Anns[ann_id]['box'] for ann_id in image_ann_ids]))
        pool5, fc7 = mrcn.box_to_spatial_fc7(head, im_info, ann_boxes)
        # index_select ref_ann_ids
        ref_pool5 = Variable(head.data.new(len(ref_ann_ids), 1024, 7, 7).zero_())
        ref_fc7 = Variable(head.data.new(len(ref_ann_ids), 2048, 7, 7).zero_())
        for i, ann_id in enumerate(ref_ann_ids):
            if ann_id == -1:
                continue
            ix = image_ann_ids.index(ann_id)
            ref_pool5[i], ref_fc7[i] = pool5[ix], fc7[ix]
        # index_select cxt_fc7
        if cxt_ann_ids is not None:
            cxt_fc7 = Variable(head.data.new(len(cxt_ann_ids), self.num_cxt, 2048).zero_())
            for i in range(len(cxt_ann_ids)):
                for j in range(self.num_cxt):
                    cxt_ann_id = cxt_ann_ids[i][j]
                    if cxt_ann_id != -1:
                        ix = image_ann_ids.index(cxt_ann_id)
                        cxt_fc7[i, j, :] = fc7[ix].mean(2).mean(1)
        else:
            # we don't use cxt_fc7 when testing attributes
            cxt_fc7 = None
        # return
        return ref_fc7, ref_pool5, cxt_fc7
    def getAttributeBatch(self, image_id):
        image = self.Images[image_id]
        head, im_info = self.image_to_head(image_id)
        head = Variable(torch.from_numpy(head).cuda())
        # fetch ann_ids owning attributes
        ref_ids = image['ref_ids']
        ann_ids = [self.Refs[ref_id]['ann_id'] for ref_id in ref_ids]
        ann_boxes = xywh_to_xyxy(np.vstack([self.Anns[ann_id]['box'] for ann_id in ann_ids]))
        lfeats = Variable(torch.from_numpy(self.compute_lfeats(ann_ids)).cuda())
        dif_lfeats = Variable(torch.from_numpy(self.compute_dif_lfeats(ann_ids)).cuda())
        # return data
        data = {}
        data['image_id'] = image_id
        data['head'] = head
        data['im_info'] = im_info
        data['ref_ids'] = ref_ids
        data['ann_ids'] = ann_ids
        data['Feats'] = {'lfeats': lfeats, 'dif_lfeats': dif_lfeats}
        return data
    def getImageBatch(self, image_id, sent_ids=None):
        """
        Args:
            image_id    : int
        Returns:
            image_id    : same as input
            head        : float (1, 1024, H, W)
            im_info     : ndarray float32 [[im_h, im_w, im_scale]], (1, 3)
            ann_ids     : N annotated objects in the image
            cxt_ann_ids : (N, num_cxt) context ann_ids, -1 padded
            sent_ids    : M sent_ids used in this image
            Feats       : - lfeats     float (N, 5)
                          - dif_lfeats float (N, 25)
                          - cxt_lfeats float (N, num_cxt, 5)
            labels      : long (M, label_length)
            gd_ixs      : list of M ground-truth ixs into ann_ids
        """
        image = self.Images[image_id]
        head, im_info = self.image_to_head(image_id)
        head = Variable(torch.from_numpy(head).cuda())
        ann_ids = self.Images[image_id]['ann_ids']
        # compute all lfeats
        lfeats = Variable(torch.from_numpy(self.compute_lfeats(ann_ids)).cuda())
        dif_lfeats = Variable(torch.from_numpy(self.compute_dif_lfeats(ann_ids)).cuda())
        # get cxt info
        cxt_lfeats, cxt_ann_ids = self.fetch_cxt_info(ann_ids, self.num_cxt)
        cxt_lfeats = Variable(torch.from_numpy(cxt_lfeats).cuda())
        # fetch sent_ids and labels
        gd_ixs = []
        if sent_ids is None:
            sent_ids = []
            for ref_id in image['ref_ids']:
                ref = self.Refs[ref_id]
                for sent_id in ref['sent_ids']:
                    sent_ids += [sent_id]
                    gd_ixs += [ann_ids.index(ref['ann_id'])]
        else:
            # given sent_ids, find the gd_ix of each one
            for sent_id in sent_ids:
                ref = self.sentToRef[sent_id]
                gd_ixs += [ann_ids.index(ref['ann_id'])]
        labels = Variable(torch.from_numpy(self.fetch_labels(sent_ids)).long().cuda())
        # return data
        data = {}
        data['image_id'] = image_id
        data['head'] = head
        data['im_info'] = im_info
        data['ann_ids'] = ann_ids
        data['cxt_ann_ids'] = cxt_ann_ids
        data['sent_ids'] = sent_ids
        data['gd_ixs'] = gd_ixs
        data['Feats'] = {'lfeats': lfeats, 'dif_lfeats': dif_lfeats, 'cxt_lfeats': cxt_lfeats}
        data['labels'] = labels
        return data