Last active
February 21, 2020 16:05
-
-
Save wbuchwalter/d65fc12fd19a6af7f98988a00b0c7ad0 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse
import csv
import json
import os
import time
from multiprocessing import Pool
from urllib.request import urlopen

import numpy as np
parser = argparse.ArgumentParser()
parser.add_argument('input', help='Path to GCC TSV input file')
parser.add_argument('output', help='Output directory')
parser.add_argument('-t', help='Number of threads, default: 20', default=20, type=int)
parser.add_argument('-f', help='Frequency of checkpoints, default: 1000', default=1000, type=int)
args = parser.parse_args()

# A trailing slash would make os.path.basename(args.output) return '' and
# break the derived annotation/image paths below.
assert args.output[-1] != '/', "Output path must not end by a trailing slash"

REQUEST_TIMEOUT = 0.5  # per-image download timeout, in seconds

# COCO caption format:
# {
#   "images":      [{"id": 12, "file_name": "00012.jpg"}],
#   "annotations": [{"id": 12, "image_id": 12, "caption": "some piece of text"}]
# }
anns = {
    'annotations': [],
    'images': []
}
cursor_start = 0  # row index to resume from (0 = fresh start)
cursor_pos = 0    # current row index in the TSV

# Output layout: <output>/captions_<name>.json and <output>/<name>/*.jpg
dest_ann_file = os.path.join(args.output, 'captions_{}.json'.format(os.path.basename(args.output)))
dest_img_dir = os.path.join(args.output, os.path.basename(args.output))

# If a checkpoint file already exists, resume: reload the accumulated
# annotations and restart from the id of the last saved image.
if os.path.isfile(dest_ann_file):
    # 'with' ensures the checkpoint file handle is closed after loading
    # (the original leaked it via json.load(open(...))).
    with open(dest_ann_file, 'r') as ann_file:
        anns = json.load(ann_file)
    cursor_start = anns['images'][-1]['id']
    print('Output destination already exists, resuming download from image # %i...' % cursor_start)
else:
    os.makedirs(dest_img_dir, exist_ok=True)
def fetch_image_data(url):
    """Download *url* and return the raw JPEG bytes.

    Raises:
        Exception: if the HTTP status is not 200, or the payload does not
            start with the JPEG magic bytes (e.g. a redirect page served
            with a 200 status for a broken image link).
    """
    # Context manager closes the HTTP response on every path
    # (the original never closed it, leaking the connection).
    with urlopen(url, timeout=REQUEST_TIMEOUT) as response:
        if response.status != 200:
            raise Exception('Bad status code')
        img_data = response.read()
    # JPEG files always start with FF D8 FF; anything else is not an image.
    if img_data[:3] != b'\xff\xd8\xff':
        raise Exception('Corrupted image')
    return img_data
def process_image(tup):
    """Pool worker: download one image and write it to dest_img_dir.

    Args:
        tup: (img_id, caption, url) triple from the TSV.

    Returns:
        (annotation_dict, image_dict) in COCO style on success,
        or None if the download failed (counted by the caller).
    """
    img_id, caption, url = tup
    try:
        img_data = fetch_image_data(url)
    except Exception:
        # Best-effort: broken URLs are expected, skip and let the caller tally.
        return None
    img_filename = "{0:07d}.jpg".format(img_id)
    # 'with' guarantees the image file is flushed and closed in the worker
    # process (the original leaked the handle via open(...).write(...)).
    with open(os.path.join(dest_img_dir, img_filename), 'wb') as img_file:
        img_file.write(img_data)
    ann = {"id": img_id, "image_id": img_id, "caption": caption}
    img = {"id": img_id, "file_name": img_filename}
    return (ann, img)
# Read every not-yet-processed (index, caption, url) row into memory.
with open(args.input, 'r') as tsv_file:
    reader = csv.reader(tsv_file, delimiter='\t')
    buffer = []
    for cursor_pos, (caption, url) in enumerate(reader):
        if cursor_pos < cursor_start:
            # Fast forward to cursor_start when resuming a download
            continue
        buffer.append((cursor_pos, caption, url))

processing_cursor = 0
nb_failed = 0
total_processing_time = 0
with Pool(args.t) as p:
    while processing_cursor < len(buffer):
        t0 = time.time()
        # One checkpoint batch of args.f rows, downloaded in parallel.
        batch = buffer[processing_cursor : processing_cursor + args.f]
        # Order within a batch does not matter, so take results as they land.
        valid_data = [res for res in p.imap_unordered(process_image, batch)
                      if res is not None]
        nb_failed_batch = len(batch) - len(valid_data)
        # BUGFIX: the original converted valid_data to np.array and sliced
        # [:, 0] / [:, 1], which raises IndexError when *every* image in a
        # batch fails (empty array has no second axis). Plain unpacking
        # handles the empty case and avoids the numpy round-trip entirely.
        for ann, img in valid_data:
            anns['annotations'].append(ann)
            anns['images'].append(img)
        # Checkpoint after each batch so a crash loses at most args.f rows;
        # 'with' ensures the JSON is flushed and the handle closed
        # (the original leaked it via json.dump(anns, open(...))).
        with open(dest_ann_file, 'w') as ann_file:
            json.dump(anns, ann_file)
        processing_cursor += len(batch)
        batch_time = time.time() - t0
        total_processing_time += batch_time
        t_per_sample = total_processing_time / processing_cursor
        abs_pos = processing_cursor + cursor_start
        # 3,300,000 ~ total row count of the GCC dataset — TODO confirm
        # against the actual input file.
        eta = t_per_sample * (3300000 - abs_pos) / 3600
        print("[Step]: {}, [Batch Time]: {:.1f}s., [ETA]: {:.2f} hours".format(abs_pos, batch_time, eta))
        print("%i images failed to download over %i for this batch" % (nb_failed_batch, len(batch)))
        nb_failed += nb_failed_batch
print("%i images failed to download" % nb_failed)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment