Last active: July 11, 2019 04:10
-
-
Save serihiro/4bdfa7e81f8385d41196a04014bc165b to your computer and use it in GitHub Desktop.
Google Landmarks dataset TSV file generator (dataset: https://github.com/cvdfoundation/google-landmark)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import concurrent.futures | |
import argparse | |
import os | |
import glob | |
import csv | |
# Directory names used at every nesting level of the image tree: the sixteen
# lowercase hexadecimal digits '0'-'9' followed by 'a'-'f'.
SUB_DIRECTORIES = list('0123456789abcdef')
def generate_list(sub_directory, labels=None, sub_directories=None):
    """Map every ``*.jpg`` two levels below *sub_directory* to its label.

    Images are looked up at ``sub_directory/<a>/<b>/<image_id>.jpg`` where
    ``<a>`` and ``<b>`` range over *sub_directories*.

    Args:
        sub_directory: Top-level directory to scan.
        labels: Mapping from image id (file name without the ``.jpg``
            extension) to label. Defaults to the module-level ``label_list``
            filled in by ``main``; pass it explicitly when that global is not
            available in the calling process.
        sub_directories: Directory names to descend into at each of the two
            levels. Defaults to ``SUB_DIRECTORIES``.

    Returns:
        Dict mapping each image path to its label, inserted in sorted path
        order so the output is deterministic.

    Raises:
        KeyError: If an image id has no entry in *labels*.
    """
    if labels is None:
        labels = label_list
    if sub_directories is None:
        sub_directories = SUB_DIRECTORIES

    result = {}
    for level2 in sub_directories:
        for level3 in sub_directories:
            # Join the raw components exactly once. Re-joining already-joined
            # paths duplicates the prefix whenever sub_directory is relative.
            base = os.path.join(sub_directory, level2, level3)
            # Escape the directory part so characters like '[' in the path
            # are matched literally rather than as glob syntax.
            pattern = os.path.join(glob.escape(base), '*.jpg')
            for path in sorted(glob.glob(pattern)):
                image_id = os.path.splitext(os.path.basename(path))[0]
                result[path] = labels[image_id]
    return result
def _init_worker(labels):
    """Install *labels* as the module-level ``label_list`` in a pool worker.

    ``generate_list`` reads the global ``label_list``. Workers started with
    the ``spawn`` method do not inherit the parent's globals, so the mapping
    is shipped to each worker explicitly.
    """
    global label_list
    label_list = labels


def main(base_directory, label_list_path, concurrency, output_path):
    """Write a tab-separated ``<image path>\\t<label>`` row for every image.

    Args:
        base_directory: Root of the image tree; each ``SUB_DIRECTORIES`` entry
            below it is scanned as one task by ``generate_list``.
        label_list_path: CSV file with a header row; column 0 is taken as the
            image id and column 2 as its label.
        concurrency: Maximum number of worker processes.
        output_path: Destination TSV file; overwritten if it exists.
    """
    global label_list
    label_list = {}
    # The csv module expects files opened with newline=''.
    with open(label_list_path, 'r', newline='') as f:
        reader = csv.reader(f, delimiter=',')
        # Skip the header row; an empty file simply yields no labels.
        next(reader, None)
        for row in reader:
            # NOTE(review): presumably (id, url, landmark_id) as in the
            # google-landmark train.csv -- confirm against the input file.
            label_list[row[0]] = row[2]

    sub_directories = [os.path.join(base_directory, s) for s in SUB_DIRECTORIES]
    # The initializer hands label_list to every worker once, so this works
    # whether workers are forked or spawned.
    with concurrent.futures.ProcessPoolExecutor(
            max_workers=concurrency,
            initializer=_init_worker,
            initargs=(label_list,)) as executor:
        results = list(executor.map(generate_list, sub_directories))

    with open(output_path, 'w', newline='') as f:
        writer = csv.writer(f, delimiter='\t')
        for result in results:
            writer.writerows(result.items())
    print('done')
if __name__ == '__main__':
    # Command-line entry point. Each flag's destination name matches a
    # parameter of main(), so the parsed namespace is forwarded as keywords.
    cli = argparse.ArgumentParser()
    cli.add_argument('--base_directory', '-b', type=str, required=True)
    cli.add_argument('--concurrency', '-c', type=int, default=1)
    cli.add_argument('--label_list_path', '-l', type=str, required=True)
    cli.add_argument('--output_path', '-o', type=str, required=True)
    main(**vars(cli.parse_args()))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.