-
-
Save MkDierz/82084ac2a140a5b2ae6c8fa1ac5c3dc7 to your computer and use it in GitHub Desktop.
Split images randomly over train or validation folder
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Copyright 2014-2017 Bert Carremans
# Author: Bert Carremans <bertcarremans.be>
#
# License: BSD 3 clause
import os | |
import random | |
from shutil import copyfile | |
def img_train_test_split(img_source_dir, train_size, dest_dir='data',
                         extensions=('.jpg', '.png'), seed=None):
    """
    Randomly splits images over a train and validation folder, while preserving the folder structure

    Every immediate subdirectory of ``img_source_dir`` is mirrored as
    ``<dest_dir>/train/<subdir>`` and ``<dest_dir>/validation/<subdir>``.
    Each matching image file in a subdirectory is copied to exactly one of
    the two folders and renamed to a per-folder running counter that keeps
    the matched extension (``0.jpg``, ``1.png``, ...). Files placed directly
    in ``img_source_dir`` and deeper nested folders are ignored. Empty
    subdirectories are reported and skipped.

    Parameters
    ----------
    img_source_dir : string
        Path to the folder with the images to be split. Can be absolute or relative path
    train_size : float
        Proportion of the original images that need to be copied in the subdirectory in the train folder.
        Each image independently goes to train with this probability, so the
        realised split is only approximately ``train_size``.
    dest_dir : string, optional
        Root of the output tree. Defaults to ``'data'`` (relative to the
        current working directory), as in earlier versions.
    extensions : tuple of str or str, optional
        Case-sensitive filename suffixes, including the leading dot, that are
        treated as images. Defaults to ``('.jpg', '.png')``.
    seed : int or None, optional
        If given, a private ``random.Random(seed)`` drives the split so runs
        are reproducible. If None, the global ``random`` module is used.

    Raises
    ------
    AttributeError
        If ``img_source_dir`` is not a string or ``train_size`` is not a
        float (kept for backward compatibility with existing callers).
    OSError
        If ``img_source_dir`` does not exist.
    """
    # Validate before touching the filesystem so bad calls leave no output dirs.
    if not isinstance(img_source_dir, str):
        raise AttributeError('img_source_dir must be a string')
    if not os.path.exists(img_source_dir):
        raise OSError('img_source_dir does not exist')
    if not isinstance(train_size, float):
        raise AttributeError('train_size must be a float')

    if isinstance(extensions, str):
        # A bare string would otherwise be split into single characters.
        extensions = (extensions,)
    extensions = tuple(extensions)
    rng = random if seed is None else random.Random(seed)

    # Always create both roots; previously they were only created when the
    # destination root already existed.
    train_root = os.path.join(dest_dir, 'train')
    validation_root = os.path.join(dest_dir, 'validation')
    os.makedirs(train_root, exist_ok=True)
    os.makedirs(validation_root, exist_ok=True)

    # Sorted so that a fixed seed gives the same split regardless of the
    # (arbitrary) order os.listdir returns entries in.
    subdirs = sorted(
        entry for entry in os.listdir(img_source_dir)
        if os.path.isdir(os.path.join(img_source_dir, entry))
    )
    for subdir in subdirs:
        subdir_fullpath = os.path.join(img_source_dir, subdir)
        entries = sorted(os.listdir(subdir_fullpath))
        if not entries:
            print(subdir_fullpath + ' is empty')
            # Skip only this folder; `break` here used to abandon all the rest.
            continue

        train_subdir = os.path.join(train_root, subdir)
        validation_subdir = os.path.join(validation_root, subdir)
        os.makedirs(train_subdir, exist_ok=True)
        os.makedirs(validation_subdir, exist_ok=True)

        train_counter = 0
        validation_counter = 0
        # Randomly assign each image to the train or validation folder.
        for filename in entries:
            src = os.path.join(subdir_fullpath, filename)
            if not filename.endswith(extensions) or not os.path.isfile(src):
                continue
            # Use the suffix that actually matched: split('.')[1] picked the
            # wrong part for names such as 'img.v2.jpg'.
            ext = next(e for e in extensions if filename.endswith(e))
            if rng.uniform(0, 1) <= train_size:
                copyfile(src, os.path.join(train_subdir, str(train_counter) + ext))
                train_counter += 1
            else:
                copyfile(src, os.path.join(validation_subdir, str(validation_counter) + ext))
                validation_counter += 1
        print('Copied ' + str(train_counter) + ' images to ' + train_subdir)
        print('Copied ' + str(validation_counter) + ' images to ' + validation_subdir)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment