Created
October 12, 2021 03:06
-
-
Save InputBlackBoxOutput/f64011b145833020d770ba7efb913524 to your computer and use it in GitHub Desktop.
Split an image dataset into train, validation, and test sets for the Keras ImageDataGenerator
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import glob | |
import random | |
import shutil | |
# Parse through the respective class directories and make a list of image file paths
def collect_image_paths(class_dir, patterns=("*.png",)):
    """Return the sorted, de-duplicated image paths found in *class_dir*.

    Args:
        class_dir: Directory that holds the images of one class.
        patterns: Glob patterns (relative to *class_dir*) to match, e.g.
            ``("*.png", "*.jpg")`` to pick up several extensions.

    Returns:
        A list of matching paths. Each path appears exactly once, even when
        several patterns match the same file, so a single image can never be
        placed in more than one of the train/validation/test splits.
    """
    found = set()
    for pattern in patterns:
        found.update(glob.glob(os.path.join(class_dir, pattern)))
    # Sorted so the starting order is reproducible before any shuffling.
    return sorted(found)


class_1 = collect_image_paths("class_1")
class_2 = collect_image_paths("class_2")
images = [class_1, class_2]
# Build the output tree: one folder per split, each holding one subfolder per class.
split_names = ("train", "test", "validation")
class_names = ("class_1", "class_2")
for split_name in split_names:
    # exist_ok=True lets the script be re-run over an existing output tree.
    os.makedirs(split_name, exist_ok=True)
    for class_name in class_names:
        os.makedirs(f"{split_name}/{class_name}", exist_ok=True)
# Shuffle, split and copy images
def split_destination(split_name, image_path):
    """Return ``<split_name>/<class folder>/<file name>`` for *image_path*.

    The class folder is the name of the directory that directly contains the
    image. The path is taken apart with ``os.path`` rather than by splitting
    on "/", so the class subfolder is kept on every platform.
    """
    class_name = os.path.basename(os.path.dirname(image_path))
    return os.path.join(split_name, class_name, os.path.basename(image_path))


for each in images:
    # Shuffles the class list in place, then cuts it 70% / 10% / 20%.
    random.shuffle(each)
    length = len(each)
    train_end = int(length * 0.7)
    validation_end = int(length * 0.8)

    train = each[:train_end]
    validation = each[train_end:validation_end]
    test = each[validation_end:]

    for split_name, split_images in (
        ("train", train),
        ("test", test),
        ("validation", validation),
    ):
        for each_image in split_images:
            destination = split_destination(split_name, each_image)
            # Make sure the class subfolder exists before copying into it.
            os.makedirs(os.path.dirname(destination), exist_ok=True)
            shutil.copy(each_image, destination)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment