Last active
April 2, 2020 14:37
-
-
Save wbuchwalter/2983221a5a94b98edc787797832e2c96 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import pdb | |
import time | |
import argparse | |
import os | |
import csv | |
import json | |
from urllib.request import urlopen | |
from multiprocessing import Pool | |
import numpy as np | |
# Command-line interface: a TSV of (caption, url) rows in, a dataset
# directory out.
parser = argparse.ArgumentParser()
parser.add_argument('input', help='Path to GCC TSV input file')
parser.add_argument('output', help='Output directory')
# NOTE: -t is the size of the multiprocessing Pool used for downloads, so it
# counts worker processes even though the help text says "threads".
parser.add_argument('-t', help='Number of threads, default: 20', default=20, type=int)
# -f is the number of rows downloaded between two writes of the annotation
# file.
parser.add_argument('-f', help='Frequency of checkpoints, default: 1000', default=1000, type=int)
args = parser.parse_args()

# A trailing slash makes os.path.basename() return '', which would leave the
# annotation file and the image directory without a name. The checks go
# through parser.error() instead of assert so they still run under
# `python -O`, and so an empty path is reported instead of raising IndexError.
if not args.output:
    parser.error("Output path must not be empty")
if args.output.endswith('/'):
    parser.error("Output path must not end by a trailing slash")

# The pool needs at least one worker, and the batching loop slices the work
# list in steps of -f, so a step below 1 would never make progress.
if args.t < 1:
    parser.error("-t must be a positive integer")
if args.f < 1:
    parser.error("-f must be a positive integer")
# Timeout handed to urlopen() for each request, in seconds.
REQUEST_TIMEOUT = 0.5

# In-memory annotation index. It is dumped as-is to the annotation file and
# has this layout:
# {
#   "images": [{"id": 12, "filename": "00012.jpg"}],
#   "annotations": [{"id": 12, "image_id": 12, "caption": "some piece of text"}]
# }
anns = dict(annotations=[], images=[])

# Row cursors into the input TSV: the first row to process, and the row
# currently being read.
cursor_start = cursor_pos = 0

# The annotation file and the image sub-directory are both named after the
# last component of the output path.
_dataset_name = os.path.basename(args.output)
dest_img_dir = os.path.join(args.output, _dataset_name)
dest_ann_file = os.path.join(args.output, 'captions_{}.json'.format(_dataset_name))
# Resume support: when an annotation file from a previous run exists, reload
# it and restart right after the last image it recorded. Rows whose index is
# below cursor_start are skipped when the TSV is read.
if os.path.isfile(dest_ann_file):
    # The context manager closes the checkpoint before it is rewritten later.
    with open(dest_ann_file, 'r') as ann_file:
        anns = json.load(ann_file)
    # Image ids are TSV row indices. The row of the last recorded image is
    # already on disk, so resume at the row after it; restarting on that row
    # would download it again and duplicate its entries. A checkpoint with no
    # images means nothing was saved yet, so start from the first row.
    cursor_start = anns['images'][-1]['id'] + 1 if anns['images'] else 0
    print('Output destination already exists, resuming download from image # %i...' % cursor_start)

# Created in both cases: a resumed run must not fail on its first write just
# because the image directory was removed between runs.
os.makedirs(dest_img_dir, exist_ok=True)
def fetch_image_data(url):
    """Download *url* and return the response body as bytes.

    Args:
        url: Address of the image to download.

    Returns:
        The full response body, which starts with the JPEG byte prefix.

    Raises:
        Exception: If the response status is not 200, or if the body does
            not start with the bytes ``FF D8 FF`` (for example a redirect to
            a placeholder page that still answers 200).

    Any error raised by ``urlopen()`` or ``read()`` itself (timeout, DNS
    failure, HTTP error, ...) propagates unchanged.
    """
    # The context manager closes the connection on every path, including the
    # early raise below; otherwise each failed download leaks a socket.
    with urlopen(url, timeout=REQUEST_TIMEOUT) as response:
        if response.status != 200:
            raise Exception('Bad status code')
        img_data = response.read()

    # Only the first three bytes are checked; the stricter JFIF test was
    # deliberately left out by the original author.
    if not img_data.startswith(b'\xff\xd8\xff'):
        raise Exception('Corrupted image')
    return img_data
def process_image(tup):
    """Download one image, store it on disk and build its annotation records.

    Args:
        tup: A ``(img_id, caption, url)`` tuple. ``img_id`` is an integer; it
            is used as the record id and, zero-padded to seven digits, as
            the file name.

    Returns:
        ``(ann, img)``, where ``ann`` is the annotation record and ``img``
        the image record, or ``None`` when the image could not be fetched.
    """
    img_id, caption, url = tup
    try:
        img_data = fetch_image_data(url)
    except Exception:
        # Deliberate best-effort: any download or validation failure just
        # drops this image. The caller counts the None results.
        return None

    img_filename = "{0:07d}.jpg".format(img_id)
    # Close the file explicitly so the bytes are flushed to disk before the
    # image is reported as saved.
    with open(os.path.join(dest_img_dir, img_filename), 'wb') as img_file:
        img_file.write(img_data)

    ann = {"id": img_id, "image_id": img_id, "caption": caption}
    img = {"id": img_id, "filename": img_filename}
    return (ann, img)
# Main driver. The guard keeps the worker processes from re-running the
# download loop when the start method re-imports this module (spawn).
if __name__ == '__main__':
    # Rows still to download, as (row_index, caption, url) tuples.
    buffer = []
    nb_failed = 0
    nb_malformed = 0

    # newline='' lets the csv module handle line endings itself.
    with open(args.input, 'r', newline='') as tsv_file:
        for cursor_pos, row in enumerate(csv.reader(tsv_file, delimiter='\t')):
            if cursor_pos < cursor_start:
                # Fast forward to cursor_start when resuming a download
                continue
            if len(row) != 2:
                # A blank or malformed line cannot be unpacked into
                # (caption, url); skip it instead of aborting the whole run.
                nb_malformed += 1
                continue
            caption, url = row
            buffer.append((cursor_pos, caption, url))

    if nb_malformed:
        print("%i malformed rows skipped" % nb_malformed)

    with Pool(args.t) as p:
        # Download in slices of args.f rows and checkpoint after each slice.
        for batch_start in range(0, len(buffer), args.f):
            t0 = time.time()
            batch = buffer[batch_start:batch_start + args.f]
            res = p.map(process_image, batch)

            # Discard failed downloads (404, invalid headers etc). Plain
            # lists are used because a batch in which every download failed
            # gives an empty result, which cannot be column-indexed as a
            # 2-D array.
            valid_data = [item for item in res if item is not None]
            nb_failed_batch = len(res) - len(valid_data)
            anns['annotations'].extend(ann for ann, _ in valid_data)
            anns['images'].extend(img for _, img in valid_data)

            # Write the checkpoint to a temporary file and swap it in, so an
            # interruption mid-write cannot leave a truncated annotation
            # file that a later resume would fail to parse.
            tmp_ann_file = dest_ann_file + '.tmp'
            with open(tmp_ann_file, 'w') as ann_file:
                json.dump(anns, ann_file)
            os.replace(tmp_ann_file, dest_ann_file)

            ti = time.time() - t0
            print("Time for a batch:", ti)
            print("%i images failed to download" % nb_failed_batch)
            nb_failed += nb_failed_batch

    # Total over all batches.
    print("%i images failed to download" % nb_failed)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment