Last active
May 17, 2022 09:07
-
-
Save Antrikshy/0d14ae39fb1445ffcab0 to your computer and use it in GitHub Desktop.
Simple, configurable Python script to split a single-file dataset into training, testing and validation sets
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import random
import math

# Configure paths to your dataset files here
DATASET_FILE = 'data.csv'
FILE_TRAIN = 'train.csv'
FILE_VALID = 'validation.csv'
FILE_TESTS = 'test.csv'

# Set to True if you want to copy the first line of the main
# file into each split (like a CSV header)
IS_CSV = True

# Must add up to 100; split_rows() raises ValueError otherwise
PERCENT_TRAIN = 50
PERCENT_VALID = 25
PERCENT_TESTS = 25


def split_rows(rows, percentages):
    """Randomly partition *rows* into one list per entry of *percentages*.

    Args:
        rows: Sequence of data rows (the header, if any, already removed).
            The sequence itself is not modified.
        percentages: Sequence of numbers that must add up to 100, one per
            desired split.

    Returns:
        A list of lists, in the same order as *percentages*. Every split
        except the last receives ``int(percent / 100 * len(rows))`` rows;
        the last split receives all remaining rows, so rows lost to the
        truncation are kept instead of being dropped.

    Raises:
        ValueError: If *percentages* is empty or does not add up to 100.
    """
    if not percentages or not math.isclose(sum(percentages), 100.0):
        raise ValueError(
            'split percentages must add up to 100, got %r' % (percentages,)
        )

    # Shuffle a copy once and slice it, rather than repeatedly popping
    # random indices (which costs O(n) per pop).
    shuffled = list(rows)
    random.shuffle(shuffled)

    total = len(shuffled)
    splits = []
    start = 0
    for percent in percentages[:-1]:
        count = int((percent / 100.0) * total)
        splits.append(shuffled[start:start + count])
        start += count
    # Whatever is left belongs to the final split.
    splits.append(shuffled[start:])
    return splits


def main():
    """Read DATASET_FILE and write the train/validation/test split files."""
    with open(DATASET_FILE, 'r') as dataset:
        rows = dataset.readlines()

    # A final line without a newline would be merged with the following
    # row once it is shuffled into the middle of an output file.
    if rows and not rows[-1].endswith('\n'):
        rows[-1] += '\n'

    header = None
    if IS_CSV and rows:
        header, rows = rows[0], rows[1:]

    # Split before opening any output file so that a bad configuration
    # does not leave truncated outputs behind.
    splits = split_rows(rows, (PERCENT_TRAIN, PERCENT_VALID, PERCENT_TESTS))

    for path, split in zip((FILE_TRAIN, FILE_VALID, FILE_TESTS), splits):
        with open(path, 'w') as out:
            if header is not None:
                out.write(header)
            out.writelines(split)


if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
This should work to shuffle the rows while keeping the header line first:
line = data.pop(0)
random.shuffle(data)
data.insert(0, line)