Skip to content

Instantly share code, notes, and snippets.

@bdqnghi
Last active November 2, 2021 17:25
Show Gist options
  • Save bdqnghi/d98a4ba7fb192ce659e50489c9cee0dc to your computer and use it in GitHub Desktop.
Save bdqnghi/d98a4ba7fb192ce659e50489c9cee0dc to your computer and use it in GitHub Desktop.
Split train test val script
import shutil
import os
import random
from random import shuffle
from shutil import copyfile
from concurrent.futures import ThreadPoolExecutor
# ROOT = "/home/nghibui/codes/bi-tbcnn/"
src_dir = "train"
tgt_dir = "train_val"
algo_directories = os.listdir(src_dir)
# names = ["train","test","val"]
# for name in names:
def copy_parts(purpose, files, algo_name):
for file in files:
old_file_path = os.path.join(src_dir, algo_name, file)
new_file_path = os.path.join(tgt_dir, purpose, algo_name, file)
new_dir_path = os.path.join(tgt_dir, purpose, algo)
if not os.path.exists(new_dir_path):
os.makedirs(new_dir_path)
copyfile(old_file_path, new_file_path)
for i, algo in enumerate(algo_directories):
algo_directory = os.path.join(src_dir,algo)
algo_directory_splits = algo_directory.split("/")
files = os.listdir(algo_directory)
shuffle(files)
shuffle(files)
shuffle(files)
# Assume there is 500 files totally for the PKU data
# train =
train_count = int((len(files)*80)/100)
test_count = int((len(files)*20)/100)
valid_count = int((len(files)*0)/100)
print("Num train instances : " + str(train_count))
print("Num test instances : " + str(test_count))
print("Num validation instances : " + str(valid_count))
train_start_index = 0
train_end_index = train_count
valid_start_index = train_end_index
valid_end_index = valid_start_index + valid_count
test_start_index = valid_end_index
test_end_index = len(files)
print(train_start_index, train_end_index)
print(valid_start_index, valid_end_index)
print(test_start_index, test_end_index)
train = files[train_start_index:train_end_index]
test = files[test_start_index:test_end_index]
val = files[valid_start_index:valid_end_index]
copy_parts("train",train,algo)
copy_parts("test",test,algo)
copy_parts("val",val,algo)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment