--[[
Implement beam search
]]
function layer:sample_beam(imgs, opt)
  local beam_size = utils.getopt(opt, 'beam_size', 10)
  local batch_size, feat_dim = imgs:size(1), imgs:size(2)
  local function compare(a,b) return a.p > b.p end -- used downstream
  local function compare_ppl(a, b) return a.ppl < b.ppl end -- used upstream
  assert(beam_size <= self.vocab_size+1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road. can be dealt with in future if needed')
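
The preview above stops before the search loop itself. As a rough illustration of the general technique only (not the author's Lua implementation), here is a minimal beam search sketch in Python. The names `beam_search`, `step_logprobs`, `toy_model`, and the choice of token 0 as the end-of-sequence marker are all assumptions made for this example.

```python
import math


def beam_search(step_logprobs, beam_size, max_len, eos):
    """Generic beam search over a next-token scoring function.

    step_logprobs(prefix) is a hypothetical callback that returns one
    log-probability per vocabulary token, given the tokens chosen so far.
    """
    beams = [(0.0, [])]  # (cumulative log-prob, token list)
    finished = []
    for _ in range(max_len):
        # Expand every live hypothesis by every vocabulary token.
        candidates = []
        for score, prefix in beams:
            for tok, lp in enumerate(step_logprobs(prefix)):
                candidates.append((score + lp, prefix + [tok]))
        # Keep only the beam_size highest-scoring expansions.
        candidates.sort(key=lambda c: c[0], reverse=True)
        beams = []
        for score, seq in candidates[:beam_size]:
            if seq[-1] == eos:
                finished.append((score, seq))
            else:
                beams.append((score, seq))
        if not beams:
            break
    finished.extend(beams)  # hypotheses cut off at max_len are kept as-is
    finished.sort(key=lambda c: c[0], reverse=True)
    return finished


def toy_model(prefix):
    # Assumed toy distribution over 3 tokens, independent of the prefix.
    return [math.log(0.5), math.log(0.3), math.log(0.2)]


result = beam_search(toy_model, beam_size=2, max_len=3, eos=0)
best_score, best_seq = result[0]
```

On the first step there are only as many candidates as vocabulary entries, which is one reason an implementation might require the beam size not to exceed the vocabulary size, as the assertion above does.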
import numpy as np
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F

class RNNEncoder(nn.Module):
    def __init__(self, vocab_size, word_embedding_size, word_vec_size, hidden_size, bidirectional=False,
                 input_dropout_p=0, dropout_p=0, n_layers=1, rnn_type='lstm', variable_lengths=True):
        super(RNNEncoder, self).__init__()

"""
LanguageRankingCriterion: takes [logp0, logp1] as input and computes the ranking loss.
"""
class LanguageRankingCriterion(nn.Module):
    def __init__(self, margin=1.):
        super(LanguageRankingCriterion, self).__init__()
        self.margin = margin

    def forward(self, logprobs, target):
        """
def _get_best_yaw_obj_from_pos(self, obj_id, grid_pos, height=1.0, use_iou=True):
    obj = self.objects[obj_id]
    obj_fine_class = obj['fine_class']
    cx, cy = self.env.house.to_coor(grid_pos[0], grid_pos[1])
    self.env.cam.pos.x = cx
    self.env.cam.pos.y = height
    self.env.cam.pos.z = cy
    best_yaw, best_coverage, best_mask = None, 0, None
def get_best_view_points(h3d, obj_id, args):
    # obj info
    obj = h3d.objects[obj_id]
    h3d.set_target_object(obj)
    obj_conn_map = h3d.env.house.connMapDict[obj_id][0]
    obj_point_cands = np.argwhere((obj_conn_map > args.min_conn_dist) & (obj_conn_map <= args.max_conn_dist))
    # cap the number of candidate points to save time
    if obj_point_cands.shape[0] > args.num_samples:
        perm = np.random.permutation(obj_point_cands.shape[0])[:args.num_samples]
        obj_point_cands = obj_point_cands[perm]

"""
This code subclasses torch.utils.data.Dataset and implements its two essential
methods:
1) __len__    : len(dataset) returns the size of the dataset
2) __getitem__: dataset[i] returns the i-th sample
data_json has
0. refs: [{ref_id, ann_id, box, image_id, split, category_id, sent_ids, att_wds}]
1. images: [{image_id, ref_ids, file_name, width, height, h5_id}]
2. anns: [{ann_id, category_id, image_id, box, h5_id}]
import cv2
import csv
import copy
import sys
import time
import pickle
import os
import os.path as osp
import itertools
import numpy as np

"""
We will load:
1) collected data = [{id, dataset, dataset_image_id, file_path, sent, split, bbox(xywh/None)}]
2) VQA questions = [{question, question_id, image_id}]
3) VG (on COCO) questions = [{question, question_id, image_id}]
to build a full token_to_ix table.
We also use the above to tokenize the collected sentences and build the following:
1) token_to_ix : token -> ix
2) split_to_iids : split -> img_ids
import re

def tokenize(sent, token_to_ix=None):
    # Lowercase, strip punctuation, and treat '-' and '/' as word separators.
    words = re.sub(r"([.,'!?\"()*#:;])",
                   '',
                   sent.lower()
                   ).replace('-', ' ').replace('/', ' ').split()
    if token_to_ix:
        # Map out-of-vocabulary words to the 'UNK' token.
        return [wd if wd in token_to_ix else 'UNK' for wd in words]
    else:
        return words
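
A short usage sketch for `tokenize`, assuming a vocabulary is built from the tokenized sentences themselves. The example sentences, the `corpus` list, and the choice of index 0 for `'UNK'` are hypothetical and are not taken from the original scripts.

```python
import re


def tokenize(sent, token_to_ix=None):
    # Same logic as the function above, repeated so this sketch runs on its own.
    words = re.sub(r"([.,'!?\"()*#:;])", '', sent.lower()
                   ).replace('-', ' ').replace('/', ' ').split()
    if token_to_ix:
        return [wd if wd in token_to_ix else 'UNK' for wd in words]
    return words


# Hypothetical mini-corpus used to build a token -> index table.
corpus = ["The left-most cat.", "A dog/cat on the sofa!"]
token_to_ix = {'UNK': 0}  # assumption: reserve index 0 for unknown words
for sent in corpus:
    for wd in tokenize(sent):
        token_to_ix.setdefault(wd, len(token_to_ix))

# With a vocabulary supplied, unseen words come back as 'UNK'.
tokens = tokenize("The zebra's left-most paw?", token_to_ix)
print(tokens)  # ['the', 'UNK', 'left', 'most', 'UNK']
```

Because the apostrophe is deleted rather than replaced by a space, "zebra's" becomes the single word "zebras", which is why it falls out of vocabulary here.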