Created September 16, 2018 14:00
Save beatobongco/e66dde2568bafb68d25b3712753a09e4 to your computer and use it in GitHub Desktop.
Split datasets
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import random
import shutil
from itertools import islice
from pathlib import Path
from typing import List, Optional
def dataset_splitter(output_dir: str = 'output', classes: Optional[List[str]] = None, num_train=0, num_validation=0,
                     input_dir: Optional[str] = None, seed: Optional[int] = None):
    """Randomly copy per-class files into ``train``/``validation`` folders of a new output directory.

    A file belongs to a class when it is a regular file in the input directory whose
    name starts with the class name (e.g. ``cat.123.jpg`` for class ``cat``). For each
    class, the matching files are shuffled once. The first ``num_train`` are copied into
    ``<output>/train/<class>/`` and the next ``num_validation`` into
    ``<output>/validation/<class>/``, so the two splits never share a file.

    The output directory is created as a sibling of the input directory, i.e. at
    ``<input_dir>/../<output_dir>``.

    TODO: add support for classes in folders already to begin with

    Args:
        output_dir: Name of the output folder created next to the input directory.
        classes: Class-name prefixes to look for. ``None`` (the default) means no
            classes; only the output directory itself is created.
        num_train: Maximum number of files per class copied into ``train``.
            ``None`` copies every remaining file.
        num_validation: Maximum number of files per class copied into ``validation``
            (taken from the files left over after the train split). ``None`` copies
            every remaining file.
        input_dir: Directory holding the source files. Defaults to the current
            working directory.
        seed: Seed for a private RNG, making the split reproducible. When ``None``,
            the module-level ``random`` generator is used.

    If a class has fewer than ``num_train + num_validation`` files, the later split
    simply receives fewer files; no error is raised.

    Raises:
        ValueError: If ``num_train`` or ``num_validation`` is negative.
    """
    nums = {
        'train': num_train,
        'validation': num_validation,
    }
    # Validate before touching the filesystem, so bad arguments leave no partial output.
    for split, count in nums.items():
        if count is not None and count < 0:
            raise ValueError('num_{0} must be >= 0 or None, got {1!r}'.format(split, count))

    classes = [] if classes is None else classes
    in_path = Path.cwd() if input_dir is None else Path(input_dir)
    # The output folder is a sibling of the input folder, named ``output_dir``.
    out_path = in_path / '..' / output_dir
    out_path.mkdir(exist_ok=True)

    rng = random if seed is None else random.Random(seed)

    # One shuffled iterator per class. The train split consumes from the front and
    # validation continues where train stopped, so the splits are disjoint.
    # Only the paths (not the file contents) are held in memory.
    iterators = {}
    for cls in classes:
        # Sorting first makes a given seed produce the same split regardless of
        # directory-listing order. Directories matching the prefix are skipped
        # because shutil.copy cannot copy them.
        images = sorted(p for p in in_path.glob('{0}*'.format(cls)) if p.is_file())
        rng.shuffle(images)
        iterators[cls] = iter(images)

    for dataset_type in ('train', 'validation'):
        for cls in classes:
            class_directory = out_path / dataset_type / cls
            class_directory.mkdir(parents=True, exist_ok=True)
            # islice stops *before* pulling another item, so no file is consumed
            # and discarded between the train and validation splits.
            for index, image_path in enumerate(islice(iterators[cls], nums[dataset_type])):
                print(index, str(class_directory))
                shutil.copy(str(image_path), str(class_directory))
""" | |
OUTPUT: | |
train/ | |
cat/ | |
...1k random cat images | |
dog/ | |
... | |
validation/ | |
cat/ | |
...400 random cat images | |
dog/ | |
... | |
""" | |
dataset_splitter(output_dir='cats_vs_dogs_1k', classes=['cat', 'dog'], num_train=1000, num_validation=400) |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.