Created
October 9, 2017 05:47
-
-
Save wangheda/69c7771d29561461a33887854e882007 to your computer and use it in GitHub Desktop.
Converts caption annotations into an MSCOCO-style reference file (for validation in the image-captioning task on challenger.ai)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# coding: utf-8 | |
# python2.7 | |
from __future__ import absolute_import | |
from __future__ import division | |
from __future__ import print_function | |
import json | |
import os.path | |
import random | |
import sys | |
import hashlib | |
reload(sys) | |
sys.setdefaultencoding('utf8') | |
import jieba | |
import numpy as np | |
import tensorflow as tf | |
# Command-line flags: path of the input caption-annotation JSON and path of
# the MSCOCO-style reference JSON to write.
tf.flags.DEFINE_string("captions_file", "data/ai_challenger_caption_validation_20170910/caption_validation_annotations_20170910.json",
                       "the caption file")
tf.flags.DEFINE_string("output_file", "data/ai_challenger_caption_validation_20170910/reference.json", "The output file.")
FLAGS = tf.flags.FLAGS
def _process_caption_jieba(caption):
    """Tokenizes a Chinese caption string with jieba.

    Args:
        caption: A string caption.

    Returns:
        A list of strings: the caption segmented in jieba's accurate
        (non-`cut_all`) mode.
    """
    return list(jieba.cut(caption, cut_all=False))
def load_and_process_metadata(captions_file):
    """Loads caption annotations from a JSON file and tokenizes each caption.

    Args:
        captions_file: Path to a JSON file holding a list of
            {"image_id": ..., "caption": [...]} records.

    Returns:
        A list of (caption_id, image_hash, base_filename, tokenized_caption)
        tuples, one per non-empty caption.
    """
    image_ids = set()
    id_to_captions = {}
    with open(captions_file, 'r') as f:
        caption_data = json.load(f)
    for data in caption_data:
        # The annotation's image_id carries a file extension; drop it.
        image_name = data['image_id'].split('.')[0]
        if image_name not in image_ids:
            id_to_captions.setdefault(image_name, [])
            image_ids.add(image_name)
        for description in data['caption']:
            # Strip surrounding whitespace, the trailing Chinese full stop,
            # and embedded newlines; skip captions that end up empty.
            cleaned = description.strip().strip("。").replace('\n', '')
            if cleaned != '':
                id_to_captions[image_name].append(cleaned)
    print("Loaded caption metadata for %d images from %s and image_id num is %d" %
          (len(id_to_captions), captions_file, len(image_ids)))
    # Process the captions and combine the data into a list of tuples.
    print("Processing captions.")
    image_metadata = []
    num_captions = 0
    caption_id = 0  # renamed from `id` to avoid shadowing the builtin
    for base_filename in image_ids:
        # Stable numeric image id derived from the filename hash.
        # NOTE(review): python2-only — `sys.maxint` is gone in python3 and
        # sha256 would need bytes input there.
        image_hash = int(int(hashlib.sha256(base_filename).hexdigest(), 16) % sys.maxint)
        for c in id_to_captions[base_filename]:
            captions = _process_caption_jieba(c)
            caption_id += 1
            image_metadata.append((caption_id, image_hash, base_filename, captions))
            num_captions += len(captions)
    print("Finished processing %d captions for %d images in %s" %
          (num_captions, len(id_to_captions), captions_file))
    return image_metadata
def write_to_json(image_metadata, output_file):
    """Writes the metadata as an MSCOCO-style reference JSON file.

    Args:
        image_metadata: List of (caption_id, image_hash, basename, captions)
            tuples, as produced by load_and_process_metadata.
        output_file: Path of the JSON file to write.
    """
    # Tokens are re-joined with spaces so MSCOCO eval tools can split
    # the caption back into words.
    annotations = [
        {
            "caption": u" ".join(captions),
            "id": caption_id,
            "image_id": image_hash,
        }
        for caption_id, image_hash, _basename, captions in image_metadata
    ]
    # One entry per metadata tuple (i.e. per caption), duplicates included,
    # matching the original output exactly.
    images = [
        {"file_name": basename, "id": image_hash}
        for _caption_id, image_hash, basename, _captions in image_metadata
    ]
    results = {
        "annotations": annotations,
        "images": images,
        "type": "captions",
        "licenses": [{"url": "https://www.apache.org/licenses/LICENSE-2.0"}],
        "info": {"url": ""},
    }
    # Bug fix: honor the `output_file` argument instead of silently using
    # FLAGS.output_file, and close the file deterministically via `with`.
    with open(output_file, 'w') as output:
        json.dump(results, output, indent=4)
def main():
    """Builds the MSCOCO-style reference file from the annotation JSON."""
    metadata = load_and_process_metadata(FLAGS.captions_file)
    write_to_json(metadata, FLAGS.output_file)


if __name__ == "__main__":
    main()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/bin/bash
# Runs build_reference_file.py on the validation annotations to produce
# an MSCOCO-style reference.json alongside them.

# Directory containing this script, so it can be invoked from anywhere.
DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"

# Input: caption annotations for the validation split.
VALIDATE_CAPTIONS_FILE="${DIR}/../data/ai_challenger_caption_validation_20170910/caption_validation_annotations_20170910.json"

# Output: MSCOCO-style reference file.
VALIDATE_REFERENCE_FILE="${DIR}/../data/ai_challenger_caption_validation_20170910/reference.json"

# Quote every expansion so paths containing spaces survive word splitting.
python "${DIR}/build_reference_file.py" \
    --captions_file="$VALIDATE_CAPTIONS_FILE" \
    --output_file="$VALIDATE_REFERENCE_FILE"
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment