Skip to content

Instantly share code, notes, and snippets.

@serihiro
Created March 26, 2019 01:39
Show Gist options
  • Save serihiro/1fb78b775043d9df307f0fb1aa4afa4f to your computer and use it in GitHub Desktop.
Save serihiro/1fb78b775043d9df307f0fb1aa4afa4f to your computer and use it in GitHub Desktop.
ILSVRC2012 dataset subset generator
import os
import sys
import argparse
import shutil
def main(train_root, output, img_per_label):
    """Copy at most ``img_per_label`` images of every label into ``output``.

    ``train_root`` is expected to hold one sub-directory per label, each
    containing that label's image files. For every such label directory the
    same-named directory is created under ``output`` and populated with up to
    ``img_per_label`` files copied (with metadata, via ``shutil.copy2``) from
    the input directory.

    Labels and images are visited in sorted name order, so repeated runs over
    the same input always select the same subset.

    Args:
        train_root: Directory containing one sub-directory per label.
            Entries that are not directories are ignored.
        output: Destination directory; created if it does not exist.
        img_per_label: Maximum number of images copied per label. Values
            ``<= 0`` copy nothing (label directories are still created).
    """
    os.makedirs(output, exist_ok=True)

    # Only directories are labels; stray files (archives, checksums,
    # .DS_Store, ...) would otherwise make os.listdir() below raise.
    with os.scandir(train_root) as entries:
        labels = sorted(entry.name for entry in entries if entry.is_dir())

    limit = max(img_per_label, 0)  # a negative slice bound would mean "all but N"
    for label in labels:
        print(label)
        input_label_directory = os.path.join(train_root, label)
        output_label_directory = os.path.join(output, label)
        os.makedirs(output_label_directory, exist_ok=True)

        # os.scandir()/os.listdir() order is arbitrary; sort so the same images
        # are chosen on every run. Skip anything that is not a regular file,
        # since shutil.copy2 cannot copy a directory.
        with os.scandir(input_label_directory) as entries:
            images = sorted(entry.name for entry in entries if entry.is_file())

        for img in images[:limit]:
            shutil.copy2(
                os.path.join(input_label_directory, img),
                os.path.join(output_label_directory, img),
            )
if __name__ == '__main__':
    # Command-line entry point: every option maps 1:1 onto a main() parameter.
    parser = argparse.ArgumentParser(
        description='ILSVRC2012 dataset subset generator: copy up to '
                    'IMG_PER_LABEL images from every label directory of '
                    'TRAIN_ROOT into OUTPUT.')
    parser.add_argument('--train_root', type=str, required=True,
                        help='directory containing one sub-directory per label')
    parser.add_argument('--output', type=str, required=True,
                        help='destination directory (created if missing)')
    parser.add_argument('--img_per_label', type=int, default=1000,
                        help='maximum number of images copied per label '
                             '(default: %(default)s)')
    args = parser.parse_args()
    # A negative count is meaningless; fail fast with a usage message instead
    # of silently producing empty label directories.
    if args.img_per_label < 0:
        parser.error('--img_per_label must be >= 0')
    main(**vars(args))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment