Skip to content

Instantly share code, notes, and snippets.

@lichengunc
Created August 10, 2018 05:56
Show Gist options
  • Save lichengunc/48380fa6bf3cc2af6201443180db69fc to your computer and use it in GitHub Desktop.
Save lichengunc/48380fa6bf3cc2af6201443180db69fc to your computer and use it in GitHub Desktop.
matt_dataset.py
"""
This module subclasses torch.utils.data.Dataset and implements the two methods
the abstract class requires:
1) __len__ : len(dataset) returns the number of images in the split
2) __getitem__: dataset[image_id] returns the sample built for that image
   (the dataset is indexed by image_id, not by position; see ImageSampler)
data_json has
0. refs: [{ref_id, ann_id, box, image_id, split, category_id, sent_ids, att_wds}]
1. images: [{image_id, ref_ids, ann_ids, file_name, width, height, h5_id}]
2. anns: [{ann_id, category_id, image_id, box, h5_id}]
3. sentences: [{sent_id, tokens, sent, h5_id}]
4. word_to_ix: {word: ix}
5. att_to_ix : {att_wd: ix}
6. att_to_cnt: {att_wd: cnt}
7. label_length: L
8. cat_to_ix: {cat: ix}
Note: every box is in [x, y, w, h] format.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os.path as osp
import numpy as np
import h5py
import json
import random
import collections
import torch
import torch.utils.data as data
from torch.autograd import Variable
def preprocess_sent(sent_str):
    """Remove punctuation from ``sent_str``, lower-case it and strip it.

    The removed characters are the ones in ``punctuations`` below, which
    include the backslash. Inner whitespace is kept; only leading and
    trailing whitespace is stripped.
    """
    # The backslash is escaped explicitly. The old literal relied on the
    # invalid escape sequence "\," (DeprecationWarning/SyntaxWarning on
    # Python 3), which happened to keep both "\" and ",".
    punctuations = frozenset('''!()-[]{};:'"\\,<>./?@#$%^&*_~''')
    # str.join avoids the quadratic cost of repeated string concatenation.
    no_punct = ''.join(char for char in sent_str if char not in punctuations)
    return no_punct.lower().strip()
def xywh_to_xyxy(boxes):
    """Convert an (N, 4) array of [x, y, w, h] boxes to [x1, y1, x2, y2].

    The bottom-right corner is pixel-inclusive: x2 = x + w - 1 and
    y2 = y + h - 1.
    """
    top_left = boxes[:, 0:2]
    bottom_right = top_left + boxes[:, 2:4] - 1
    return np.concatenate((top_left, bottom_right), axis=1)
def matt_collate(batch):
    """Collate a list of ``MAttDataset[i]`` samples into a batch.

    Nothing is stacked, because per-image sizes differ. Tensors, numpy
    objects, sequences and numbers are returned as a plain list, and dicts
    are collated key by key, recursively. An empty batch collates to ``[]``.

    Raises:
        TypeError: if the elements are of any other type.
    """
    import numbers
    try:
        # collections.Sequence/Mapping aliases were removed in Python 3.10
        from collections.abc import Mapping, Sequence
    except ImportError:  # Python 2
        from collections import Mapping, Sequence

    if len(batch) == 0:
        return []
    first = batch[0]
    # head / ref_labels / lfeats tensors, im_info arrays, id lists, image_id
    # and other scalars are all kept as plain lists
    if (torch.is_tensor(first)
            or type(first).__module__ == 'numpy'
            or isinstance(first, (Sequence, numbers.Number))):
        return list(batch)
    if isinstance(first, Mapping):
        return {key: matt_collate([d[key] for d in batch]) for key in first}
    error_msg = "batch must contain tensors, numbers, dicts or lists; found {}"
    raise TypeError(error_msg.format(type(first)))
class ImageSampler(data.sampler.Sampler):
    """Sampler that yields image_ids in the order they were given.

    MAttDataset is indexed by image_id rather than by position, so the
    train/val/test loaders iterate over image_ids through this sampler.
    """

    def __init__(self, image_ids):
        # a reference to the caller's sequence is kept (no copy is made)
        self.image_ids = image_ids

    def __len__(self):
        return len(self.image_ids)

    def __iter__(self):
        # the outermost iterable of a generator expression is evaluated
        # eagerly, exactly like iter(self.image_ids)
        return (image_id for image_id in self.image_ids)
class MAttDataset(data.Dataset):
    """MAttNet referring-expression dataset, indexed by image_id.

    Each sample gathers, for one image, its referred objects and sentences
    (the positives), sampled negative objects and sentences, location
    features and encoded labels. Iterate it with ``ImageSampler`` over
    ``split_image_ids``; integer positions are not valid indices.
    """

    def __init__(self, data_json, split, head_feats_dir, opt,
                 word_to_ix=None, att_to_ix=None, encode_from_phrase=False):
        """Load ``data_json`` and index the images of ``split``.

        Args:
            data_json: path to the json file described in the module docstring.
            split: name of the split to serve. An image belongs to the split
                of its first ref; images without refs are skipped.
            head_feats_dir: directory containing one ``<image_id>.h5`` file of
                head features per image.
            opt: dict of options. ``visual_sample_ratio`` is required;
                ``seq_per_ref`` (default 3), ``num_cxt`` (default 5) and
                ``with_st`` (default 1) are optional.
            word_to_ix: vocabulary to use instead of the one in ``data_json``.
            att_to_ix: attribute vocabulary to use instead of ``data_json``'s.
            encode_from_phrase: if True, labels are encoded from each
                sentence's raw ``sent`` string instead of its ``tokens``.
        """
        # load the json file which contains info about the dataset
        print('Dataset loading data.json: ', data_json)
        with open(data_json) as f:  # close the handle instead of leaking it
            self.info = json.load(f)
        self.word_to_ix = self.info['word_to_ix'] if word_to_ix is None else word_to_ix
        self.ix_to_word = {ix: wd for wd, ix in self.word_to_ix.items()}
        self.vocab_size = len(self.ix_to_word)
        print('vocab size is ', self.vocab_size)
        self.cat_to_ix = self.info['cat_to_ix']
        self.ix_to_cat = {ix: cat for cat, ix in self.cat_to_ix.items()}
        print('object cateogry size is ', len(self.ix_to_cat))
        self.images = self.info['images']
        self.anns = self.info['anns']
        self.refs = self.info['refs']
        self.sentences = self.info['sentences']
        print('we have %s images.' % len(self.images))
        print('we have %s anns.' % len(self.anns))
        print('we have %s refs.' % len(self.refs))
        print('we have %s sentences.' % len(self.sentences))
        self.label_length = self.info['label_length']
        print('label_length is ', self.label_length)

        # construct id -> record mappings
        self.Refs = {ref['ref_id']: ref for ref in self.refs}
        self.Images = {image['image_id']: image for image in self.images}
        self.Anns = {ann['ann_id']: ann for ann in self.anns}
        self.Sentences = {sent['sent_id']: sent for sent in self.sentences}
        self.annToRef = {ref['ann_id']: ref for ref in self.refs}
        self.sentToRef = {sent_id: ref for ref in self.refs for sent_id in ref['sent_ids']}

        # prepare attributes
        self.att_to_ix = self.info['att_to_ix'] if att_to_ix is None else att_to_ix
        self.ix_to_att = {ix: wd for wd, ix in self.att_to_ix.items()}
        self.att_to_cnt = self.info['att_to_cnt']
        self.attribute_size = len(self.att_to_ix)

        # image_ids of this split
        self.split = split
        self.split_image_ids = []
        for image_id, image in self.Images.items():
            ref_ids = image['ref_ids']
            # an image without refs has no split (previously an IndexError)
            if ref_ids and self.Refs[ref_ids[0]]['split'] == self.split:
                self.split_image_ids.append(image_id)
        print('assigned %d images to split %s' % (len(self.split_image_ids), self.split))

        # other options
        self.seq_per_ref = opt.get('seq_per_ref', 3)
        self.sample_ratio = opt['visual_sample_ratio']
        self.num_cxt = opt.get('num_cxt', 5)
        self.with_st = opt.get('with_st', 1)
        self.head_feats_dir = head_feats_dir
        self.encode_from_phrase = encode_from_phrase

    def __len__(self):
        """Return the number of images in this split."""
        return len(self.split_image_ids)

    def __getitem__(self, index):
        """Build the training sample of one image.

        Args:
            index (int): an image_id (not a position; see ImageSampler).

        Returns a dict where n = len(image's ref_ids) * seq_per_ref:
            image_id        : the input index
            head            : float (1, 1024, H, W), loaded from <image_id>.h5
            im_info         : ndarray float32 [[im_h, im_w, im_scale]], (1, 3)
            image_ann_ids   : N annotated objects in this image
            ref_ids         : n ref_ids, each of the image's refs repeated seq_per_ref times
            ref_ann_ids     : n positive ann_ids --> ref_pool5, ref_fc7
            ref_sent_ids    : n positive sent_ids
            ref_labels      : long (n, label_length)
            ref_cxt_ann_ids : (n, num_cxt) context ann_ids per positive, -1 padded
            ref_Feats       : {'lfeats': float (n, 5), 'dif_lfeats': float (n, 25),
                               'cxt_lfeats': float (n, num_cxt, 5)}
            neg_ann_ids     : n negative ann_ids (-1 when none is available)
            neg_sent_ids    : n negative sent_ids
            neg_labels      : long (n, label_length)
            neg_cxt_ann_ids : (n, num_cxt) context ann_ids per negative, -1 padded
            neg_Feats       : same layout as ref_Feats, for the negatives
        """
        # get image head feats
        image_id = index
        head, im_info = self.image_to_head(image_id)
        head = torch.from_numpy(head)
        image_ann_ids = self.Images[image_id]['ann_ids']

        # expand ref_ids by seq_per_ref
        ref_ids = self.Images[image_id]['ref_ids']
        image_ref_ids = self.expand_list(ref_ids, self.seq_per_ref)

        # sample positive and negative (ann_id, sent_id) pairs for every ref
        ref_ann_ids, ref_sent_ids = [], []
        neg_ann_ids, neg_sent_ids = [], []
        for ref_id in ref_ids:
            ref_ann_id = self.Refs[ref_id]['ann_id']
            # pos ids
            ref_ann_ids += [ref_ann_id] * self.seq_per_ref
            ref_sent_ids += self.fetch_sent_ids_by_ref_id(ref_id, self.seq_per_ref)
            # neg ids
            cur_ann_ids, cur_sent_ids = self.sample_neg_ids(ref_ann_id, self.seq_per_ref, self.sample_ratio)
            neg_ann_ids += cur_ann_ids
            neg_sent_ids += cur_sent_ids

        # compute all lfeats
        ref_lfeats = torch.from_numpy(self.compute_lfeats(ref_ann_ids))
        ref_dif_lfeats = torch.from_numpy(self.compute_dif_lfeats(ref_ann_ids))
        neg_lfeats = torch.from_numpy(self.compute_lfeats(neg_ann_ids))
        neg_dif_lfeats = torch.from_numpy(self.compute_dif_lfeats(neg_ann_ids))

        # fetch labels
        ref_labels = torch.from_numpy(self.fetch_labels(ref_sent_ids)).long()
        neg_labels = torch.from_numpy(self.fetch_labels(neg_sent_ids)).long()

        # fetch context info: cxt_lfeats, cxt_ann_ids
        ref_cxt_lfeats, ref_cxt_ann_ids = self.fetch_cxt_info(ref_ann_ids, self.num_cxt)
        ref_cxt_lfeats = torch.from_numpy(ref_cxt_lfeats)
        neg_cxt_lfeats, neg_cxt_ann_ids = self.fetch_cxt_info(neg_ann_ids, self.num_cxt)
        neg_cxt_lfeats = torch.from_numpy(neg_cxt_lfeats)

        # named `sample` so the torch.utils.data alias `data` is not shadowed
        sample = {}
        sample['image_id'] = image_id
        sample['head'] = head
        sample['im_info'] = im_info
        sample['image_ann_ids'] = image_ann_ids
        sample['ref_ids'] = image_ref_ids
        sample['ref_ann_ids'] = ref_ann_ids
        sample['ref_sent_ids'] = ref_sent_ids
        sample['ref_labels'] = ref_labels
        sample['ref_cxt_ann_ids'] = ref_cxt_ann_ids
        sample['ref_Feats'] = {'lfeats': ref_lfeats, 'dif_lfeats': ref_dif_lfeats, 'cxt_lfeats': ref_cxt_lfeats}
        sample['neg_ann_ids'] = neg_ann_ids
        sample['neg_sent_ids'] = neg_sent_ids
        sample['neg_labels'] = neg_labels
        sample['neg_cxt_ann_ids'] = neg_cxt_ann_ids
        sample['neg_Feats'] = {'lfeats': neg_lfeats, 'dif_lfeats': neg_dif_lfeats, 'cxt_lfeats': neg_cxt_lfeats}
        return sample

    def expand_list(self, L, n):
        """Repeat every element of ``L`` ``n`` times, e.g. [a, b], 3 -> [a, a, a, b, b, b]."""
        return [item for item in L for _ in range(n)]

    def fetch_sent_ids_by_ref_id(self, ref_id, num_sents):
        """Sample exactly ``num_sents`` sent_ids of ``ref_id``.

        If the ref has fewer sentences, random ones are repeated. Otherwise a
        random subset is taken without replacement.
        """
        sent_ids = list(self.Refs[ref_id]['sent_ids'])
        if len(sent_ids) < num_sents:
            append_sent_ids = [random.choice(sent_ids) for _ in range(num_sents - len(sent_ids))]
            sent_ids += append_sent_ids
        else:
            random.shuffle(sent_ids)
            sent_ids = sent_ids[:num_sents]
        assert len(sent_ids) == num_sents
        return sent_ids

    def sample_neg_ids(self, ann_id, seq_per_ref, sample_ratio):
        """Sample ``seq_per_ref`` negatives for the object ``ann_id``.

        Returns:
            neg_ann_ids: ann_ids of negative objects (-1 when the object is
                alone in its image).
            neg_sent_ids: sent_ids describing negative objects.
        Same-type neighbours are chosen with probability ``sample_ratio``;
        otherwise different-type neighbours are used when available.
        """
        st_ref_ids, st_ann_ids, dt_ref_ids, dt_ann_ids = self.fetch_neighbour_ids(ann_id)
        neg_ann_ids, neg_sent_ids = [], []
        for _ in range(seq_per_ref):
            # neg_ann_id for negative visual representation: mainly from same-type objects
            if len(st_ann_ids) > 0 and np.random.uniform(0, 1, 1) < sample_ratio:
                neg_ann_id = random.choice(st_ann_ids)
            elif len(dt_ann_ids) > 0:
                neg_ann_id = random.choice(dt_ann_ids)
            else:
                # awkward case: sample from any neighbour, or -1 if there is none
                pool = st_ann_ids + dt_ann_ids
                neg_ann_id = random.choice(pool) if pool else -1
            neg_ann_ids.append(neg_ann_id)

            # neg_ref_id for negative language representations: mainly from same-type "referred" objects
            if len(st_ref_ids) > 0 and np.random.uniform(0, 1, 1) < sample_ratio:
                neg_ref_id = random.choice(st_ref_ids)
            elif len(dt_ref_ids) > 0:
                neg_ref_id = random.choice(dt_ref_ids)
            else:
                # no other referred object in the image: pick any ref (possibly
                # the positive one). dict keys are not indexable on Python 3,
                # so they are materialized first.
                neg_ref_id = random.choice(list(self.Refs))
            neg_sent_ids.append(random.choice(self.Refs[neg_ref_id]['sent_ids']))
        return neg_ann_ids, neg_sent_ids

    def fetch_neighbour_ids(self, ref_ann_id):
        """Return the neighbours of ``ref_ann_id`` in its image.

        Returns (st_ref_ids, st_ann_ids, dt_ref_ids, dt_ann_ids):
            st_*: same-category neighbours (excluding ref_ann_id itself)
            dt_*: different-category neighbours
        The *_ref_ids lists only contain neighbours that are referred objects.
        All lists are ordered by increasing centre distance to ref_ann_id;
        ties keep the order of the image's ann_ids.
        """
        ref_ann = self.Anns[ref_ann_id]
        x, y, w, h = ref_ann['box']
        rx, ry = x + w / 2, y + h / 2

        def sq_dist(ann_id):
            # squared distance between box centres (closer --> former)
            ax, ay, aw, ah = self.Anns[ann_id]['box']
            return (rx - (ax + aw / 2)) ** 2 + (ry - (ay + ah / 2)) ** 2

        image = self.Images[ref_ann['image_id']]
        # sorted() returns a new list, so image['ann_ids'] is left untouched;
        # key= replaces the Python-2-only cmp= argument
        ann_ids = sorted(image['ann_ids'], key=sq_dist)
        st_ref_ids, st_ann_ids, dt_ref_ids, dt_ann_ids = [], [], [], []
        for ann_id in ann_ids:
            if ann_id == ref_ann_id:
                continue
            if self.Anns[ann_id]['category_id'] == ref_ann['category_id']:
                st_ann_ids.append(ann_id)
                if ann_id in self.annToRef:
                    st_ref_ids.append(self.annToRef[ann_id]['ref_id'])
            else:
                dt_ann_ids.append(ann_id)
                if ann_id in self.annToRef:
                    dt_ref_ids.append(self.annToRef[ann_id]['ref_id'])
        return st_ref_ids, st_ann_ids, dt_ref_ids, dt_ann_ids

    @staticmethod
    def _relative_lfeats(rbox, cbox):
        """Return the 5-d location of box ``cbox`` relative to box ``rbox``.

        Both boxes are [x, y, w, h]. The result is
        [(x1-rcx)/rw, (y1-rcy)/rh, (x1+w-rcx)/rw, (y1+h-rcy)/rh, area ratio],
        where (rcx, rcy) is the centre of rbox. 1e-5 is added to rw and rh to
        avoid division by zero for degenerate boxes.
        """
        rcx, rcy = rbox[0] + rbox[2] / 2, rbox[1] + rbox[3] / 2
        rw, rh = rbox[2] + 1e-5, rbox[3] + 1e-5
        cx1, cy1, cw, ch = cbox[0], cbox[1], cbox[2], cbox[3]
        return np.array([(cx1 - rcx) / rw, (cy1 - rcy) / rh,
                         (cx1 + cw - rcx) / rw, (cy1 + ch - rcy) / rh,
                         cw * ch / (rw * rh)])

    def fetch_cxt_info(self, ann_ids, topK):
        """Return context features of each object in ``ann_ids``.

        Returns:
            cxt_lfeats : ndarray float32 (#ann_ids, topK, 5), zero padded
            cxt_ann_ids: list[[ann_id]] of size (#ann_ids, topK), -1 padded
        Context objects are the nearest different-type neighbours, followed
        by same-type ones when ``self.with_st > 0``. Rows for -1 ann_ids stay
        padded.
        """
        cxt_lfeats = np.zeros((len(ann_ids), topK, 5), dtype=np.float32)
        cxt_ann_ids = [[-1] * topK for _ in range(len(ann_ids))]  # (#ann_ids, topK)
        for i, ref_ann_id in enumerate(ann_ids):
            if ref_ann_id == -1:
                continue
            rbox = self.Anns[ref_ann_id]['box']
            _, st_ann_ids, _, dt_ann_ids = self.fetch_neighbour_ids(ref_ann_id)
            cand_ann_ids = dt_ann_ids + st_ann_ids if self.with_st > 0 else dt_ann_ids
            for j, cand_ann_id in enumerate(cand_ann_ids[:topK]):
                cxt_lfeats[i, j, :] = self._relative_lfeats(rbox, self.Anns[cand_ann_id]['box'])
                cxt_ann_ids[i][j] = cand_ann_id
        return cxt_lfeats, cxt_ann_ids

    def compute_lfeats(self, ann_ids):
        """Return ndarray float32 (#ann_ids, 5) of image-normalized locations.

        Each row is [x1/iw, y1/ih, x2/iw, y2/ih, area/(iw*ih)], with x2 = x+w-1
        and y2 = y+h-1. Rows for -1 ann_ids stay zero.
        """
        lfeats = np.zeros((len(ann_ids), 5), dtype=np.float32)
        for ix, ann_id in enumerate(ann_ids):
            if ann_id == -1:
                continue
            ann = self.Anns[ann_id]
            image = self.Images[ann['image_id']]
            x, y, w, h = ann['box']
            ih, iw = image['height'], image['width']
            lfeats[ix] = np.array([x / iw, y / ih, (x + w - 1) / iw, (y + h - 1) / ih, w * h / (iw * ih)], np.float32)
        return lfeats

    def compute_dif_lfeats(self, ann_ids, topK=5):
        """Return ndarray float32 (#ann_ids, 5*topK) of offsets to same-type neighbours.

        Row i concatenates the relative locations of the topK nearest
        same-category neighbours of ann_ids[i], zero padded. Rows for -1
        ann_ids stay zero.
        """
        dif_lfeats = np.zeros((len(ann_ids), 5 * topK), dtype=np.float32)
        for i, ref_ann_id in enumerate(ann_ids):
            if ref_ann_id == -1:
                continue
            rbox = self.Anns[ref_ann_id]['box']
            _, st_ann_ids, _, _ = self.fetch_neighbour_ids(ref_ann_id)
            for j, cand_ann_id in enumerate(st_ann_ids[:topK]):
                dif_lfeats[i, j * 5:(j + 1) * 5] = self._relative_lfeats(rbox, self.Anns[cand_ann_id]['box'])
        return dif_lfeats

    def get_attribute_weights(self, scale=10):
        """Return per-attribute loss weights, ordered by attribute index.

        Weights are 1/sqrt(cnt), min-max rescaled to [1, scale]. When all
        counts are equal, or there are no attributes, every weight is 1
        (previously this produced NaN or raised).
        """
        cnts = np.array([self.att_to_cnt[self.ix_to_att[ix]] for ix in range(self.attribute_size)],
                        dtype=np.float64)
        weights = 1 / (cnts ** 0.5 + 1e-5)
        if weights.size > 0 and weights.max() > weights.min():
            weights = (weights - weights.min()) / (weights.max() - weights.min())
        else:
            weights = np.zeros_like(weights)
        weights = weights * (scale - 1) + 1
        return torch.from_numpy(weights).float()

    def fetch_attribute_label(self, ref_ann_ids):
        """Return multi-hot attribute labels for the referred objects.

        Returns:
            labels    : Variable float (N, attribute_size) on the GPU
            select_ixs: Variable long of the row indices whose ref has at
                least one attribute word, on the GPU
        Every ann_id must belong to a ref (it is looked up in annToRef).
        """
        labels = np.zeros((len(ref_ann_ids), self.attribute_size))
        select_ixs = []
        for i, ref_ann_id in enumerate(ref_ann_ids):
            att_wds = self.annToRef[ref_ann_id]['att_wds']
            if len(att_wds) > 0:
                select_ixs.append(i)
            for wd in att_wds:
                labels[i, self.att_to_ix[wd]] = 1
        return Variable(torch.from_numpy(labels).float().cuda()), Variable(torch.LongTensor(select_ixs).cuda())

    def image_to_head(self, image_id):
        """Load the cached head features of ``image_id``.

        Returns:
            head: float32 (1, 1024, H, W)
            im_info: float32 [[im_h, im_w, im_scale]]
        """
        feats_h5 = osp.join(self.head_feats_dir, str(image_id) + '.h5')
        # copy both datasets into memory before the file is closed
        with h5py.File(feats_h5, 'r') as feats:
            head = np.array(feats['head'])
            im_info = np.array(feats['im_info'])
        return head, im_info

    def _word_index(self, w):
        """Return the vocabulary index of ``w``, falling back to ``<UNK>``."""
        ix = self.word_to_ix.get(w)
        # '<UNK>' is only looked up when needed, as in the original code
        return ix if ix is not None else self.word_to_ix['<UNK>']

    def encode_phrases(self, phrase_list):
        """Encode raw phrases into np.int64 (num_phrases, label_length), zero padded.

        Phrases go through preprocess_sent and are truncated to label_length
        words. Out-of-vocabulary words become <UNK>. A phrase that is empty
        after preprocessing is encoded as a single <UNK>.
        """
        L = np.zeros((len(phrase_list), self.label_length), dtype=np.int64)
        for i, raw_phrase in enumerate(phrase_list):
            words = preprocess_sent(raw_phrase).split()
            if words:
                for j, w in enumerate(words[:self.label_length]):
                    L[i, j] = self._word_index(w)
            else:
                # void phrase: presumably so the label is never all padding
                L[i, 0] = self.word_to_ix['<UNK>']
        return L

    def encode_labels(self, tokens_list):
        """Encode token lists into np.int64 (num_tokens_list, label_length), zero padded.

        Tokens beyond label_length are dropped. Out-of-vocabulary tokens
        (possible when a custom word_to_ix is passed) map to <UNK> instead of
        raising KeyError, consistently with encode_phrases.
        """
        L = np.zeros((len(tokens_list), self.label_length), dtype=np.int64)
        for i, tokens in enumerate(tokens_list):
            for j, w in enumerate(tokens[:self.label_length]):
                L[i, j] = self._word_index(w)
        return L

    def fetch_labels(self, sent_ids):
        """Return int64 (num_sents, label_length) labels for ``sent_ids``.

        Labels come from each sentence's raw 'sent' when encode_from_phrase
        is set, and from its 'tokens' otherwise.
        """
        if self.encode_from_phrase:
            phrase_list = [self.Sentences[sent_id]['sent'] for sent_id in sent_ids]
            labels = self.encode_phrases(phrase_list)
        else:
            tokens_list = [self.Sentences[sent_id]['tokens'] for sent_id in sent_ids]
            labels = self.encode_labels(tokens_list)
        return labels

    def decode_labels(self, labels):
        """Decode (n, label_length) zero-padded labels into n sentence strings."""
        decoded_sent_strs = []
        for row in range(labels.shape[0]):
            label = labels[row].tolist()
            # separate name from the row index (list comps leak names on Python 2)
            sent_str = ' '.join([self.ix_to_word[int(ix)] for ix in label if ix != 0])
            decoded_sent_strs.append(sent_str)
        return decoded_sent_strs

    def fetch_box_feats(self, mrcn, boxes, net_conv, im_info):
        """
        Return:
        - pool5 (n, 1024, 7, 7)
        - fc7 (n, 2048, 7, 7)
        """
        pool5, fc7 = mrcn.box_to_spatial_fc7(net_conv, im_info, boxes)
        return pool5, fc7

    def fetch_feats(self, mrcn, head, im_info, image_id, image_ann_ids, ref_ann_ids, cxt_ann_ids):
        """
        Inputs
        - mrcn : mask faster-rcnn instance, has to be loaded from outside for the re-use of train/val/test
        - head : float (1, 1024, H, W)
        - im_info : ndarray float32 [[im_h, im_w, im_scale]], (1, 3)
        - image_ann_ids : N annotated ann_ids in the image
        - ref_ann_ids : n referred objects (-1 rows stay zero)
        - cxt_ann_ids : (n, num_cxt) context objects (-1 entries stay zero), or None
        Returns
        - ref_fc7 : Variable float (n, 2048, 7, 7)
        - ref_pool5 : Variable float (n, 1024, 7, 7)
        - cxt_fc7 : Variable float (n, num_cxt, 2048) or None
        Don't forget to detach these feats.
        """
        # get all anns' fc7 and pool5
        ann_boxes = xywh_to_xyxy(np.vstack([self.Anns[ann_id]['box'] for ann_id in image_ann_ids]))
        pool5, fc7 = mrcn.box_to_spatial_fc7(head, im_info, ann_boxes)
        # index_select ref_ann_ids
        ref_pool5 = Variable(head.data.new(len(ref_ann_ids), 1024, 7, 7).zero_())
        ref_fc7 = Variable(head.data.new(len(ref_ann_ids), 2048, 7, 7).zero_())
        for i, ann_id in enumerate(ref_ann_ids):
            if ann_id == -1:
                continue
            ix = image_ann_ids.index(ann_id)
            ref_pool5[i], ref_fc7[i] = pool5[ix], fc7[ix]
        # index_select cxt_fc7
        if cxt_ann_ids is not None:
            cxt_fc7 = Variable(head.data.new(len(cxt_ann_ids), self.num_cxt, 2048).zero_())
            for i in range(len(cxt_ann_ids)):
                for j in range(self.num_cxt):
                    cxt_ann_id = cxt_ann_ids[i][j]
                    if cxt_ann_id != -1:
                        ix = image_ann_ids.index(cxt_ann_id)
                        # spatially average-pool the context fc7 map
                        cxt_fc7[i, j, :] = fc7[ix].mean(2).mean(1)
        else:
            # we don't use cxt_fc7 when testing attributes
            cxt_fc7 = None
        return ref_fc7, ref_pool5, cxt_fc7

    def getAttributeBatch(self, image_id):
        """Return head features and location features of the image's referred objects (on GPU)."""
        image = self.Images[image_id]
        head, im_info = self.image_to_head(image_id)
        head = Variable(torch.from_numpy(head).cuda())
        # fetch ann_ids owning attributes
        ref_ids = image['ref_ids']
        ann_ids = [self.Refs[ref_id]['ann_id'] for ref_id in ref_ids]
        # (an unused xyxy box array used to be computed here; removed)
        lfeats = Variable(torch.from_numpy(self.compute_lfeats(ann_ids)).cuda())
        dif_lfeats = Variable(torch.from_numpy(self.compute_dif_lfeats(ann_ids)).cuda())
        batch = {}
        batch['image_id'] = image_id
        batch['head'] = head
        batch['im_info'] = im_info
        batch['ref_ids'] = ref_ids
        batch['ann_ids'] = ann_ids
        batch['Feats'] = {'lfeats': lfeats, 'dif_lfeats': dif_lfeats}
        return batch

    def getImageBatch(self, image_id, sent_ids=None):
        """
        Args:
          image_id : int
          sent_ids : optional sent_ids to evaluate; defaults to all sentences
                     of all refs in the image
        Returns:
          image_id : same as input
          head : float (1, 1024, H, W)
          im_info : ndarray float32 [[im_h, im_w, im_scale]], (1, 3)
          ann_ids : N annotated objects in the image
          cxt_ann_ids : (N, num_cxt) context ann_ids, -1 padded
          sent_ids : M sent_ids used in this image
          Feats : - lfeats float (N, 5)
                  - dif_lfeats float (N, 25)
                  - cxt_lfeats float (N, num_cxt, 5)
          labels : long (M, label_length)
          gd_ixs : list of M indices into ann_ids, the ground-truth object per sentence
        """
        image = self.Images[image_id]
        head, im_info = self.image_to_head(image_id)
        head = Variable(torch.from_numpy(head).cuda())
        ann_ids = image['ann_ids']
        # compute all lfeats
        lfeats = Variable(torch.from_numpy(self.compute_lfeats(ann_ids)).cuda())
        dif_lfeats = Variable(torch.from_numpy(self.compute_dif_lfeats(ann_ids)).cuda())
        # get cxt info
        cxt_lfeats, cxt_ann_ids = self.fetch_cxt_info(ann_ids, self.num_cxt)
        cxt_lfeats = Variable(torch.from_numpy(cxt_lfeats).cuda())
        # fetch sent_ids and labels
        gd_ixs = []
        if sent_ids is None:
            sent_ids = []
            for ref_id in image['ref_ids']:
                ref = self.Refs[ref_id]
                for sent_id in ref['sent_ids']:
                    sent_ids.append(sent_id)
                    gd_ixs.append(ann_ids.index(ref['ann_id']))
        else:
            # given sent_id, we find the gd_ix
            for sent_id in sent_ids:
                ref = self.sentToRef[sent_id]
                gd_ixs.append(ann_ids.index(ref['ann_id']))
        labels = Variable(torch.from_numpy(self.fetch_labels(sent_ids)).long().cuda())
        batch = {}
        batch['image_id'] = image_id
        batch['head'] = head
        batch['im_info'] = im_info
        batch['ann_ids'] = ann_ids
        batch['cxt_ann_ids'] = cxt_ann_ids
        batch['sent_ids'] = sent_ids
        batch['gd_ixs'] = gd_ixs
        batch['Feats'] = {'lfeats': lfeats, 'dif_lfeats': dif_lfeats, 'cxt_lfeats': cxt_lfeats}
        batch['labels'] = labels
        return batch
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment