Skip to content

Instantly share code, notes, and snippets.

@InputBlackBoxOutput
Created October 12, 2021 03:06
Show Gist options
  • Save InputBlackBoxOutput/f64011b145833020d770ba7efb913524 to your computer and use it in GitHub Desktop.
Split an image dataset into train, validation and test sets for the Keras ImageDataGenerator
import os
import glob
import random
import shutil
# Split each class directory into train / validation / test folders laid out
# as <split>/<class>/<image>, i.e. one sub-directory per class inside every
# split folder.

# Class sub-directories to split, looked up under the source root.
CLASSES = ("class_1", "class_2")

# Image extensions collected from every class directory. Each extension is
# globbed exactly once and the results are de-duplicated, so no file is listed
# (and therefore copied into more than one split) twice.
EXTENSIONS = ("png", "jpg", "jpeg")

# Cumulative split boundaries: the first 70% of each shuffled class goes to
# train, the next 10% (up to 80%) to validation and the remaining 20% to test.
TRAIN_END = 0.7
VALIDATION_END = 0.8

# Fixed shuffle seed so that re-running the script reproduces the same split
# instead of copying images into a different split on every run.
SEED = 42


def collect_images(class_dir, extensions=EXTENSIONS):
    """Return the sorted, de-duplicated image paths directly inside *class_dir*.

    Args:
        class_dir: Directory containing the images of one class.
        extensions: File extensions (without the dot) to collect.

    Returns:
        A sorted list of file paths. Sorting makes the later seeded shuffle
        reproducible, because ``glob`` returns files in an order that depends
        on the filesystem.
    """
    # Escape the directory so names containing glob metacharacters such as
    # "[" are matched literally.
    pattern_dir = glob.escape(class_dir)
    found = set()
    for ext in extensions:
        found.update(glob.glob(os.path.join(pattern_dir, f"*.{ext}")))
    return sorted(found)


def split_paths(paths, train_end=TRAIN_END, validation_end=VALIDATION_END, rng=None):
    """Shuffle *paths* and return them as ``(train, validation, test)`` lists.

    Args:
        paths: Items to split. The input sequence itself is not modified.
        train_end: Cumulative fraction marking the end of the train slice.
        validation_end: Cumulative fraction marking the end of the validation
            slice. Everything after it goes to test.
        rng: Object with a ``shuffle`` method, e.g. ``random.Random``. The
            module-level ``random`` generator is used when this is None.

    Returns:
        A 3-tuple of lists. With ``n`` items, train is ``[0, int(n*train_end))``,
        validation is ``[int(n*train_end), int(n*validation_end))`` and test is
        the remainder.

    Raises:
        ValueError: If the boundaries are not ``0 <= train_end <=
            validation_end <= 1``.
    """
    if not 0.0 <= train_end <= validation_end <= 1.0:
        raise ValueError("require 0 <= train_end <= validation_end <= 1")
    shuffled = list(paths)
    (random if rng is None else rng).shuffle(shuffled)
    n = len(shuffled)
    train_cut = int(n * train_end)
    validation_cut = int(n * validation_end)
    return (
        shuffled[:train_cut],
        shuffled[train_cut:validation_cut],
        shuffled[validation_cut:],
    )


def copy_split(paths, split_dir, class_name):
    """Copy every file in *paths* into ``<split_dir>/<class_name>/``.

    The destination directory is created even when *paths* is empty, so every
    split always has a folder for every class. Files are copied under their
    base name; an existing file with the same name is overwritten.
    """
    destination = os.path.join(split_dir, class_name)
    os.makedirs(destination, exist_ok=True)
    for path in paths:
        shutil.copy(path, os.path.join(destination, os.path.basename(path)))


def main(source_root=".", output_root=".", classes=CLASSES, seed=SEED):
    """Split each class under *source_root* into train/validation/test folders.

    Args:
        source_root: Directory holding one sub-directory per class.
        output_root: Directory in which ``train``, ``validation`` and ``test``
            are created, each containing one sub-directory per class.
        classes: Names of the class sub-directories to split.
        seed: Seed for the shuffle, so that a re-run gives the same split.

    Copied files overwrite existing files of the same name but stale files are
    never removed, so clear the split folders before re-running with different
    data or a different *seed*.
    """
    rng = random.Random(seed)
    for class_name in classes:
        images = collect_images(os.path.join(source_root, class_name))
        train, validation, test = split_paths(images, rng=rng)
        for split_name, split in (
            ("train", train),
            ("validation", validation),
            ("test", test),
        ):
            copy_split(split, os.path.join(output_root, split_name), class_name)


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment