Last active
November 5, 2021 09:49
-
-
Save bertcarremans/679624f369ed9270472e37f8333244f5 to your computer and use it in GitHub Desktop.
Split images randomly over train or validation folder
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Copyright 2014-2017 Bert Carremans
# Author: Bert Carremans <bertcarremans.be>
#
# License: BSD 3 clause
import os | |
import random | |
from shutil import copyfile | |
def _ensure_dir(path):
    """Create the directory ``path`` (and any missing parents) if it is absent."""
    if not os.path.exists(path):
        os.makedirs(path)


def img_train_test_split(img_source_dir, train_size):
    """
    Randomly splits images over a train and validation folder, while preserving the folder structure

    Every immediate subdirectory of ``img_source_dir`` is treated as one class.
    Its ``.jpg`` / ``.png`` files are copied to either ``data/train/<subdir>``
    or ``data/validation/<subdir>`` (relative to the current working
    directory). Copies are renamed to a running counter per destination
    folder (``0.jpg``, ``1.png``, ...), keeping the original extension.

    Parameters
    ----------
    img_source_dir : string
        Path to the folder with the images to be split. Can be absolute or relative path

    train_size : float
        Proportion of the original images that need to be copied in the subdirectory in the train folder

    Raises
    ------
    AttributeError
        If ``img_source_dir`` is not a string or ``train_size`` is not a float.
    OSError
        If ``img_source_dir`` does not exist.
    """
    if not isinstance(img_source_dir, str):
        raise AttributeError('img_source_dir must be a string')

    if not os.path.exists(img_source_dir):
        raise OSError('img_source_dir does not exist')

    if not isinstance(train_size, float):
        raise AttributeError('train_size must be a float')

    # Set up empty folder structure if not exists. Both target folders are
    # created unconditionally, so they exist even on the very first run.
    _ensure_dir('data/train')
    _ensure_dir('data/validation')

    # Get the subdirectories in the main image folder
    subdirs = [subdir for subdir in os.listdir(img_source_dir)
               if os.path.isdir(os.path.join(img_source_dir, subdir))]

    for subdir in subdirs:
        subdir_fullpath = os.path.join(img_source_dir, subdir)
        filenames = os.listdir(subdir_fullpath)
        if not filenames:
            print(subdir_fullpath + ' is empty')
            # Skip only this folder; the remaining ones must still be split.
            continue

        train_subdir = os.path.join('data/train', subdir)
        validation_subdir = os.path.join('data/validation', subdir)

        # Create subdirectories in train and validation folders
        _ensure_dir(train_subdir)
        _ensure_dir(validation_subdir)

        train_counter = 0
        validation_counter = 0

        # Randomly assign an image to train or validation folder
        for filename in filenames:
            if not filename.endswith(('.jpg', '.png')):
                continue

            source_path = os.path.join(subdir_fullpath, filename)
            if not os.path.isfile(source_path):
                # A directory that merely looks like an image cannot be copied.
                continue

            # Take the text after the LAST dot so names such as "a.b.jpg"
            # keep their real extension.
            extension = '.' + filename.rsplit('.', 1)[1]

            if random.uniform(0, 1) <= train_size:
                copyfile(source_path, os.path.join(train_subdir, str(train_counter) + extension))
                train_counter += 1
            else:
                copyfile(source_path, os.path.join(validation_subdir, str(validation_counter) + extension))
                validation_counter += 1

        print('Copied ' + str(train_counter) + ' images to data/train/' + subdir)
        print('Copied ' + str(validation_counter) + ' images to data/validation/' + subdir)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
saved a lot of time
thanks