Skip to content

Instantly share code, notes, and snippets.

@sadimanna
Created January 16, 2025 05:00
Show Gist options
  • Save sadimanna/533dab74e78b20371ab0b4528a7d210d to your computer and use it in GitHub Desktop.
Save sadimanna/533dab74e78b20371ab0b4528a7d210d to your computer and use it in GitHub Desktop.
Split the miniImageNet data into train and test sets (can also be used for any dataset whose root directory contains one folder per class).
import os
import shutil
import numpy as np
import random
def split_dataset(root, train_dir, test_dir, train_split, test_split=None,
                  seed=None):
    """Move the files of every class folder under *root* into train/test folders.

    Each sub-directory of *root* (other than *train_dir* and *test_dir*
    themselves) is treated as one class. For every class, up to
    *train_split* randomly chosen entries are moved to
    ``train_dir/<class>/`` and the remaining entries are moved to
    ``test_dir/<class>/``.

    Args:
        root: Directory that contains one folder per class.
        train_dir: Destination directory for the training split.
        test_dir: Destination directory for the test split.
        train_split: Maximum number of entries per class moved to the
            training split. Classes with fewer entries contribute all of
            them to training instead of raising.
        test_split: Optional cap on the number of entries per class moved
            to the test split. ``None`` (the default) moves every entry
            that was not sampled for training; with a cap, surplus
            entries are left in place.
        seed: Optional seed for a private random generator, making the
            split reproducible. ``None`` uses the global ``random`` state.

    Returns:
        A dict mapping each class name to a ``(n_train, n_test)`` tuple
        with the number of entries moved into each split.

    Raises:
        ValueError: If *train_split* or *test_split* is negative.
    """
    if train_split < 0 or (test_split is not None and test_split < 0):
        raise ValueError('split sizes must be non-negative')

    sampler = random if seed is None else random.Random(seed)

    # The output folders may live inside `root` (and already exist on a
    # re-run); they must never be mistaken for class folders.
    reserved = {os.path.abspath(train_dir), os.path.abspath(test_dir)}
    classes = sorted(
        name for name in os.listdir(root)
        if os.path.isdir(os.path.join(root, name))
        and os.path.abspath(os.path.join(root, name)) not in reserved
    )

    os.makedirs(train_dir, exist_ok=True)
    os.makedirs(test_dir, exist_ok=True)

    counts = {}
    for c in classes:
        class_dir = os.path.join(root, c)
        train_class_dir = os.path.join(train_dir, c)
        test_class_dir = os.path.join(test_dir, c)
        os.makedirs(train_class_dir, exist_ok=True)
        os.makedirs(test_class_dir, exist_ok=True)

        # Sorted so that a given seed always yields the same split,
        # independent of the order os.listdir happens to return.
        files = sorted(os.listdir(class_dir))

        # Clamp the sample size: random.sample raises ValueError when asked
        # for more items than the population holds.
        train_files = sampler.sample(files, min(train_split, len(files)))
        picked = set(train_files)  # O(1) membership for the filter below
        test_files = [f for f in files if f not in picked]
        if test_split is not None and len(test_files) > test_split:
            test_files = sampler.sample(test_files, test_split)

        for f in train_files:
            shutil.move(os.path.join(class_dir, f),
                        os.path.join(train_class_dir, f))
        for f in test_files:
            shutil.move(os.path.join(class_dir, f),
                        os.path.join(test_class_dir, f))

        counts[c] = (len(train_files), len(test_files))

    return counts


if __name__ == '__main__':
    DATA_ROOT = './miniImageNet/'
    TRAIN_FOLDER = './miniImageNet/train/'
    TEST_FOLDER = './miniImageNet/test/'
    TRAIN_SPLIT = 500
    # NOTE: TEST_SPLIT is not enforced. As before, every file that is not
    # sampled for training is moved to the test folder. Pass
    # test_split=TEST_SPLIT below to cap the test set per class instead.
    TEST_SPLIT = 100

    split_dataset(DATA_ROOT, TRAIN_FOLDER, TEST_FOLDER, TRAIN_SPLIT)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment