Created
January 16, 2025 05:00
-
-
Save sadimanna/533dab74e78b20371ab0b4528a7d210d to your computer and use it in GitHub Desktop.
Split miniImageNet data into train and test sets (can also be used for any dataset organised as one folder per class inside the root directory).
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import shutil | |
import numpy as np | |
import random | |
def split_dataset(root, train_split, train_dir='train', test_dir='test', seed=None):
    """Split every class folder under ``root`` into train and test subtrees.

    For each class directory ``root/<cls>``, up to ``train_split`` randomly
    chosen entries are moved to ``root/<train_dir>/<cls>`` and every
    remaining entry is moved to ``root/<test_dir>/<cls>``.

    Args:
        root: Directory that contains one sub-directory per class.
        train_split: Maximum number of entries per class moved to the
            training tree. A class with fewer entries contributes all of
            them to training and none to testing.
        train_dir: Name of the training output folder created inside ``root``.
        test_dir: Name of the test output folder created inside ``root``.
        seed: Optional seed for a private random generator, making the split
            reproducible. ``None`` uses the global ``random`` state.

    Returns:
        A dict mapping each class name to ``(n_train, n_test)``, the number
        of entries moved into each tree.
    """
    train_root = os.path.join(root, train_dir)
    test_root = os.path.join(root, test_dir)
    rng = random.Random(seed) if seed is not None else random

    # Collect class folders BEFORE creating the output folders, and skip the
    # output folders explicitly so that re-running on an already split
    # dataset does not treat 'train'/'test' as classes. Stray files in the
    # root (label CSVs, READMEs, ...) are ignored as well.
    classes = sorted(
        name for name in os.listdir(root)
        if name not in (train_dir, test_dir)
        and os.path.isdir(os.path.join(root, name))
    )

    os.makedirs(train_root, exist_ok=True)
    os.makedirs(test_root, exist_ok=True)

    counts = {}
    for cls in classes:
        class_dir = os.path.join(root, cls)
        # os.listdir order is arbitrary; sorting makes seeded runs repeatable.
        files = sorted(os.listdir(class_dir))

        # Cap the sample size: random.sample raises ValueError when asked for
        # more items than the population holds.
        train_files = rng.sample(files, min(train_split, len(files)))
        chosen = set(train_files)  # O(1) membership for the remainder filter
        test_files = [f for f in files if f not in chosen]

        for dst_root, names in ((train_root, train_files), (test_root, test_files)):
            dst_dir = os.path.join(dst_root, cls)
            os.makedirs(dst_dir, exist_ok=True)
            for f in names:
                shutil.move(os.path.join(class_dir, f), os.path.join(dst_dir, f))

        counts[cls] = (len(train_files), len(test_files))
    return counts


if __name__ == '__main__':
    DATA_ROOT = './miniImageNet'
    # 500 images per class go to training; the remainder (100 per class for
    # miniImageNet, which has 600 images per class) becomes the test set.
    TRAIN_SPLIT = 500
    split_dataset(DATA_ROOT, TRAIN_SPLIT)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment