Skip to content

Instantly share code, notes, and snippets.

@babldev
Last active January 10, 2024 02:29
Show Gist options
  • Save babldev/69cb2ea67e521036634673ba4a1989e7 to your computer and use it in GitHub Desktop.
Save babldev/69cb2ea67e521036634673ba4a1989e7 to your computer and use it in GitHub Desktop.
Script for partitioning training, test, validation data files randomly (e.g. Yolo object detection)
#!/usr/bin/python3
"""
Randomly bucket files in a directory based on percentages for the purposes of machine learning training.
E.g. `python3 partition-data.py --data labels:**/*.txt images:**/*.png --partitions train:80 val:10 test:10`
will move files in the current directory to:
- train/ (80%)
- val/ (10%)
- test/ (10%)
"""
import os
import shutil
import argparse
import random
import functools
from dataclasses import dataclass
from typing import List, Optional
from glob import glob
@dataclass
class Partition:
    """A named output bucket together with the share of files it should receive."""

    # Partition name; it is also used as the name of the partition's directory.
    name: str
    # Share of the files that belongs to this partition, in percent.
    percentage: int
@dataclass
class DataGroup:
    """A named collection of file paths that belong to one kind of data."""

    # Label of the group (e.g. the part before ':' in a ``--data`` argument).
    name: str
    # Paths of the files that make up the group.
    files: List[str]
@dataclass
class FileAssignment:
    """Pairs a file name with the partition it has been assigned to."""

    # File name the assignment refers to.
    name: str
    # Name of the chosen partition; None until a partition has been picked.
    partition_name: Optional[str] = None
def filename_for_path(path: str) -> str:
    """Return the final component of ``path`` with its last extension removed."""
    base_name = os.path.basename(path)
    stem, _extension = os.path.splitext(base_name)
    return stem
def move_file(src, dest):
    """Move ``src`` to ``dest``, creating the parent directory of ``dest`` if needed.

    Args:
        src: Path of the file to move.
        dest: Path the file should end up at.
    """
    parent_dir = os.path.dirname(dest)
    # A bare file name has an empty dirname, and os.makedirs("") raises
    # FileNotFoundError, so only create the parent when there is one.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    shutil.move(src, dest)
def distribute_files(working_dir: str, partitions: List[Partition], data_groups: List[DataGroup]) -> dict:
    """Randomly assign files to partitions and move them into per-partition directories.

    Files are matched across data groups by their extension-less base name
    (see ``filename_for_path``). Only names that occur in every data group are
    distributed; all other files are left where they are. Every file sharing
    a name is moved to ``<working_dir>/<partition>/<original relative path>``.

    Args:
        working_dir: Directory that the paths in ``data_groups`` are relative to
            and under which the partition directories are created.
        partitions: Partitions to fill. Every partition but the last receives
            ``int(percentage * total / 100)`` names (capped at what is left);
            the last partition receives all remaining names.
        data_groups: Groups of file paths to distribute.

    Returns:
        A dict mapping each partition name to the number of shared names
        assigned to it.

    Raises:
        ValueError: If ``partitions`` is empty.
    """
    if not partitions:
        raise ValueError("at least one partition is required")

    # Names present in every data group; no data groups means nothing to move.
    name_sets = [
        {filename_for_path(path) for path in group.files}
        for group in data_groups
    ]
    common_names = set.intersection(*name_sets) if name_sets else set()

    # random.shuffle works in place, so it has to be given the list that is
    # used afterwards. Sorting first makes the outcome depend only on the
    # state of ``random`` and not on set iteration order.
    shuffled_names = sorted(common_names)
    random.shuffle(shuffled_names)

    # Number of names per partition. Counts are capped at what is still
    # available so that percentages adding up to more than 100 cannot run past
    # the end of the list; the last partition absorbs the rounding remainder.
    total_files = len(shuffled_names)
    distribution = {}
    remaining = total_files
    for partition in partitions[:-1]:
        wanted = int(partition.percentage * total_files / 100)
        count = max(0, min(wanted, remaining))
        distribution[partition.name] = distribution.get(partition.name, 0) + count
        remaining -= count
    last_name = partitions[-1].name
    distribution[last_name] = distribution.get(last_name, 0) + remaining

    # Hand out consecutive slices of the shuffled names.
    assignments = {}
    start = 0
    for partition_name, count in distribution.items():
        for file_name in shuffled_names[start:start + count]:
            assignments[file_name] = partition_name
        start += count

    # Create every partition directory, including ones that receive no files.
    for partition in partitions:
        os.makedirs(os.path.join(working_dir, partition.name), exist_ok=True)

    for group in data_groups:
        for file_path in group.files:
            partition_name = assignments.get(filename_for_path(file_path))
            if not partition_name:
                # Not shared by all data groups: leave the file in place.
                continue
            from_path = os.path.join(working_dir, file_path)
            to_path = os.path.join(working_dir, partition_name, file_path)
            move_file(from_path, to_path)

    return distribution
def main():
    """Command-line entry point: parse arguments, distribute files, print the counts.

    ``--partitions`` entries have the form ``name:percentage`` and ``--data``
    entries the form ``name:glob``. Globs are resolved relative to the working
    directory. Malformed arguments end the program through ``parser.error``.
    """
    # Only ``glob.glob`` is imported at file level, so fetch ``escape`` locally.
    from glob import escape as glob_escape

    parser = argparse.ArgumentParser(description='Distribute files into buckets.')
    parser.add_argument('-w', '--working-directory', type=str, nargs='?', default=os.getcwd(), dest='working_directory',
                        help='Directory with files to distribute (default: current directory)')
    # Both lists are iterated below, so they must be given.
    parser.add_argument('-d', '--data', nargs='+', type=str, required=True,
                        help='Data types and globs (e.g. labels:**/*.txt, images:**/*.png')
    parser.add_argument('-p', '--partitions', nargs='+', type=str, required=True,
                        help='Target directories and percentages (e.g., train:70 val:30)')
    args = parser.parse_args()

    # ``-w`` given without a value yields None; fall back to the current directory.
    working_dir = args.working_directory or os.getcwd()

    # Parse target directories and percentages
    partitions = []
    for spec in args.partitions:
        name, separator, percentage_text = spec.rpartition(':')
        try:
            percentage = int(percentage_text)
        except ValueError:
            percentage = None
        if not separator or not name or percentage is None:
            parser.error(f"invalid partition {spec!r}: expected NAME:PERCENTAGE")
        partitions.append(Partition(name, percentage))

    # Data groups
    data_groups = []
    for spec in args.data:
        # Split on the first ':' only so that the glob itself may contain colons.
        name, separator, pattern = spec.partition(':')
        if not separator or not pattern:
            parser.error(f"invalid data group {spec!r}: expected NAME:GLOB")
        # Resolve the glob inside the working directory (escaped so that special
        # characters in the directory name are taken literally) and convert the
        # matches back to relative paths, which is what distribute_files joins
        # onto the working directory.
        # NOTE(review): glob is called without recursive=True, so '**' is
        # matched like '*' (a single directory level) - confirm this is intended.
        matches = glob(os.path.join(glob_escape(working_dir), pattern))
        files = [os.path.relpath(match, working_dir) for match in matches]
        data_groups.append(DataGroup(name, files))

    # Distribute files
    distribution = distribute_files(
        working_dir=working_dir,
        partitions=partitions,
        data_groups=data_groups,
    )

    # Output results
    print("Distribution of files:")
    for target, count in distribution.items():
        print(f"{target}: {count} files")
# Run the command-line interface only when this file is executed as a script.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment