""" | |
We will load: | |
1) collected data = [{id, dataset, dataset_image_id, file_path, sent, split, bbox(xywh/None)}] | |
2) VQA questions = [{question, question_id, image_id}] | |
3) VG (on coco) questions = [{question, question_id, image_id}] | |
to make a full token_to_ix table. | |
Besides, we use the above to tokenize collected sentences, and make the follows: | |
1) token_to_ix : token -> ix | |
2) split_to_iids : split -> img_ids | |
3) iid_to_sents : img_id -> [{sent, bbox}] | |
4) iid_to_meta : img_id -> {dataset, dataset_image_id, file_path, split} | |
5) img_ids : list of img_ids | |
The above will be saved to cache/prepro.json | |
""" | |
import _init_paths
from cfgs.base_cfgs import Cfgs
from lib.data.data_utils import tokenize
import json
import os.path as osp
import time
import collections
import re
import argparse

def build_vocab(sents, min_token_instances, verbose=True):
    """
    Builds the vocabulary list, keeping only tokens that appear at least
    min_token_instances times. 'PAD' and 'UNK' are always included.
    """
    vocab = ['PAD', 'UNK']
    token_counter = collections.Counter()
    for sent in sents:
        words = tokenize(sent)
        for wd in words:
            token_counter[wd] += 1
    for token, count in token_counter.items():
        if count >= min_token_instances:
            vocab.append(token)
    if verbose:
        bad_words = [wd for wd, cnt in token_counter.items() if cnt < min_token_instances]
        bad_count = sum(cnt for wd, cnt in token_counter.items() if cnt < min_token_instances)
        total_words = sum(token_counter.values())
        print('Keeping %d/%d tokens with enough instances (>=%d).' % (len(vocab), len(token_counter), min_token_instances))
        print('number of bad words: %d/%d = %.2f%%' % (len(bad_words), len(token_counter), len(bad_words)*100./len(token_counter)))
        print('number of UNK words: %d/%d = %.2f%%' % (bad_count, total_words, bad_count*100./total_words))
    return vocab
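# A quick (hypothetical) example of build_vocab, assuming tokenize() from
# lib.data.data_utils does simple lowercased word splitting:
#
#   >>> build_vocab(['a man riding a horse', 'a man on a bike'],
#   ...             min_token_instances=2, verbose=False)
#   ['PAD', 'UNK', 'a', 'man']
#
# 'riding', 'horse', 'on' and 'bike' each appear only once, so they fall below
# the threshold and will later be mapped to 'UNK'.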
def prepro_data(collected, token_to_ix):
    """
    Inputs:
    - collected   : [{id, dataset, dataset_image_id, file_path, sent, split, bbox(xywh/None)}]
    - token_to_ix : token -> ix
    Outputs:
    - split_to_iids : split -> img_ids
    - iid_to_sents  : img_id -> [{sent, bbox}]
    - iid_to_meta   : img_id -> {file_path, split}
    - img_ids       : list of img_ids
    Note: img_id = dataset + '_' + dataset_image_id
    """
    iid_to_meta = {}
    iid_to_sents = collections.OrderedDict()
    for item in collected:
        # img_id
        img_id = item['dataset'] + '_' + str(item['dataset_image_id'])
        # add to iid_to_sents
        if img_id not in iid_to_sents:
            iid_to_sents[img_id] = []
        sent = ' '.join(tokenize(item['sent'], token_to_ix))
        iid_to_sents[img_id].append({'sent': sent, 'bbox': item['bbox']})
        # add to iid_to_meta
        iid_to_meta[img_id] = {'file_path': item['file_path'], 'split': item['split']}
    img_ids = list(iid_to_sents.keys())
    split_to_iids = {}
    for img_id in img_ids:
        split = iid_to_meta[img_id]['split']
        if split not in split_to_iids:
            split_to_iids[split] = []
        split_to_iids[split].append(img_id)
    return split_to_iids, iid_to_sents, iid_to_meta, img_ids
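# Walking one (hypothetical) collected item through prepro_data:
#   {'dataset': 'coco', 'dataset_image_id': 123, 'file_path': 'xxx.jpg',
#    'sent': 'A man riding a horse.', 'split': 'train', 'bbox': [10, 20, 50, 80]}
# yields img_id 'coco_123', appends {'sent': <tokenized sent>, 'bbox': [10, 20, 50, 80]}
# to iid_to_sents['coco_123'], records {'file_path': 'xxx.jpg', 'split': 'train'}
# in iid_to_meta['coco_123'], and adds 'coco_123' to split_to_iids['train'].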
def main(args):
    __C = Cfgs()
    all_sents = []

    # Load collected data
    tic = time.time()
    collected = json.load(open(args.collected_json))
    print('collected data loaded in %.2f seconds.' % (time.time()-tic))
    for data in collected:
        all_sents.append(data['sent'])

    # Load VQA and VG (on coco) questions
    tic = time.time()
    vqa_data = \
        json.load(open(osp.join(__C.DATASET_DIR['vqa'], 'v2_OpenEnded_mscoco_train2014_questions.json'), 'r'))['questions'] + \
        json.load(open(osp.join(__C.DATASET_DIR['vqa'], 'v2_OpenEnded_mscoco_val2014_questions.json'), 'r'))['questions'] + \
        json.load(open(osp.join(__C.DATASET_DIR['vqa'], 'VG_questions.json'), 'r'))['questions']
    print('coco + vg(coco)\'s qa data loaded in %.2f seconds.' % (time.time()-tic))
    for ques in vqa_data:
        all_sents.append(ques['question'])

    # Build the vocabulary and the token_to_ix table
    vocab = build_vocab(all_sents, args.min_token_instances)
    token_to_ix = {wd: ix for ix, wd in enumerate(vocab)}

    # Preprocess the collected data
    split_to_iids, iid_to_sents, iid_to_meta, img_ids = prepro_data(collected, token_to_ix)

    # save
    # with open(args.token_json, 'w') as io:
    #     json.dump(token_to_ix, io)
    with open(args.prepro_json, 'w') as io:
        json.dump({'split_to_iids': split_to_iids,
                   'iid_to_sents': iid_to_sents,
                   'iid_to_meta': iid_to_meta,
                   'img_ids': img_ids,
                   'token_to_ix': token_to_ix},
                  io)
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--collected_json', default='cache/collected.json')
    parser.add_argument('--min_token_instances', default=5, type=int)
    parser.add_argument('--prepro_json', default='cache/prepro.json')
    args = parser.parse_args()
    main(args)
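
# Example invocation (assuming this script is saved as prepro.py; the paths
# below are just the argparse defaults):
#   python prepro.py --collected_json cache/collected.json \
#                    --min_token_instances 5 \
#                    --prepro_json cache/prepro.json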