Last active
January 10, 2024 02:29
-
-
Save babldev/69cb2ea67e521036634673ba4a1989e7 to your computer and use it in GitHub Desktop.
Script for randomly partitioning data files into training, validation, and test sets (e.g. for YOLO object detection)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
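#!/usr/bin/python3
"""
Randomly bucket files in a directory based on percentages for the purposes of machine learning training.

E.g. `python3 partition-data.py --data labels:**/*.txt images:**/*.png --partitions train:80 val:10 test:10`

Files are paired across data groups by base name (without extension); only names that appear in
every data group are moved. Each file keeps its relative path below the partition directory
(e.g. labels/a.txt -> train/labels/a.txt). Files in the working directory are moved to:
- train/ (80%)
- val/ (10%)
- test/ (10%)
"""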
import os | |
import shutil | |
import argparse | |
import random | |
import functools | |
from dataclasses import dataclass | |
from typing import List, Optional | |
from glob import glob | |
@dataclass
class Partition:
    """A named output bucket together with the share of files it should receive."""

    # Partition name; also used as the name of the destination directory.
    name: str
    # Share of the files for this partition, expressed as a whole-number percent.
    percentage: int
@dataclass
class DataGroup:
    """A labelled collection of file paths (for example all label files or all images)."""

    # Label for the group, e.g. "labels" or "images".
    name: str
    # Paths of the files that belong to this group.
    files: List[str]
@dataclass
class FileAssignment:
    """Records which partition a shared base file name has been assigned to."""

    # Base file name (no directory, no extension) shared across the data groups.
    name: str
    # Name of the partition chosen for this file; None until one is assigned.
    partition_name: Optional[str] = None
def filename_for_path(path: str) -> str:
    """Return the final component of ``path`` with its last extension removed."""
    leaf = os.path.basename(path)
    stem, _extension = os.path.splitext(leaf)
    return stem
def move_file(src, dest):
    """Move ``src`` to ``dest``, creating ``dest``'s parent directory when needed.

    Args:
        src: Path of the file to move.
        dest: Destination path, including the file name.
    """
    parent_dir = os.path.dirname(dest)
    # dirname() is '' for a bare file name, and os.makedirs('') raises
    # FileNotFoundError, so only create the parent when there is one.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    shutil.move(src, dest)
def distribute_files(working_dir: str, partitions: List[Partition], data_groups: List[DataGroup]):
    """Randomly assign files shared by every data group to partitions and move them.

    Files are matched across data groups by base name without extension (see
    ``filename_for_path``). Only names present in *every* data group are
    assigned; all other files are left where they are. Each assigned file is
    moved from ``working_dir/<path>`` to ``working_dir/<partition>/<path>``.

    Args:
        working_dir: Directory that the paths in ``data_groups`` are relative to
            and under which the partition directories are created.
        partitions: Partitions with their integer percentages. The last
            partition also receives whatever is left after rounding down.
        data_groups: Groups of file paths that must stay paired by base name.

    Returns:
        Dict mapping partition name to the number of base names assigned to it.

    Raises:
        ValueError: If the percentages would assign more files than exist.
    """
    if not partitions or not data_groups:
        # Nothing to assign to, or nothing to assign: report zero for every partition.
        return {partition.name: 0 for partition in partitions}

    # Find shared filenames across data groups
    filename_sets = [
        {filename_for_path(path) for path in data_group.files}
        for data_group in data_groups
    ]
    shared_names = functools.reduce(lambda a, b: a.intersection(b), filename_sets)

    # Shuffle a real list in place (shuffling a temporary copy would be a no-op).
    # Sorting first makes the result reproducible when the caller seeds `random`,
    # because set iteration order is not stable between runs.
    ordered_names = sorted(shared_names)
    random.shuffle(ordered_names)
    file_assignments = [FileAssignment(filename) for filename in ordered_names]

    # Calculate the number of files for each target based on the percentage
    total_files = len(file_assignments)
    distribution = {p.name: int(p.percentage * total_files / 100) for p in partitions}
    # Adjust the last target to take the remainder to avoid rounding issues
    last_target = partitions[-1].name
    distribution[last_target] = total_files - sum(distribution.values()) + distribution[last_target]

    # The counts always sum to total_files, so a negative count means some other
    # partition asks for more files than exist. Fail before touching any file.
    if any(count < 0 for count in distribution.values()):
        raise ValueError(
            f"Partition percentages over-allocate the {total_files} available files: {distribution}"
        )

    # Distribute files: hand out consecutive slices of the shuffled list
    start = 0
    for name, count in distribution.items():
        for file_assignment in file_assignments[start:start + count]:
            file_assignment.partition_name = name
        start += count

    # Index file assignments by name
    assignments = {a.name: a.partition_name for a in file_assignments}

    # Make directories for each partition
    for partition in partitions:
        os.makedirs(os.path.join(working_dir, partition.name), exist_ok=True)

    for data_group in data_groups:
        for file_path in data_group.files:
            assignment = assignments.get(filename_for_path(file_path))
            if not assignment:
                # Not shared by every data group, so it stays in place.
                continue
            from_path = os.path.join(working_dir, file_path)
            to_path = os.path.join(working_dir, assignment, file_path)
            move_file(from_path, to_path)

    return distribution
def main():
    """Parse command-line arguments, partition the matching files, and print the counts.

    ``--data`` takes ``NAME:GLOB`` items and ``--partitions`` takes
    ``NAME:PERCENT`` items. Globs are resolved relative to the working
    directory, which is also where the partition directories are created.
    """
    # Function-scope import: only the `glob` function is imported at module level.
    from glob import escape as glob_escape

    parser = argparse.ArgumentParser(description='Distribute files into buckets.')
    parser.add_argument('-w', '--working-directory', type=str, nargs='?', default=os.getcwd(), dest='working_directory',
                        help='Directory with files to distribute (default: current directory)')
    # Both options are mandatory; without `required` a missing one surfaces as a TypeError on None.
    parser.add_argument('-d', '--data', nargs='+', type=str, required=True,
                        help='Data types and globs (e.g. labels:**/*.txt, images:**/*.png')
    parser.add_argument('-p', '--partitions', nargs='+', type=str, required=True,
                        help='Target directories and percentages (e.g., train:70 val:30)')
    args = parser.parse_args()

    # `-w` given without a value yields None (nargs='?'); fall back to the current directory.
    working_dir = args.working_directory or os.getcwd()

    # Parse target directories and percentages
    partitions = []
    for spec in args.partitions:
        name, separator, perc = spec.partition(':')
        try:
            percentage = int(perc)
        except ValueError:
            percentage = None
        if not separator or percentage is None:
            parser.error(f"invalid partition '{spec}': expected NAME:PERCENT (e.g. train:80)")
        partitions.append(Partition(name, percentage))

    # Data groups
    data_groups = []
    for spec in args.data:
        name, separator, pattern = spec.partition(':')
        if not separator:
            parser.error(f"invalid data group '{spec}': expected NAME:GLOB (e.g. labels:**/*.txt)")
        # Resolve the glob inside the working directory (not the process cwd) and
        # convert matches back to paths relative to it, which is what
        # distribute_files joins onto working_dir.
        # NOTE(review): glob is called without recursive=True, so `**` behaves like `*`
        # (exactly one directory level). Kept as-is: enabling recursion would also
        # match files already moved into partition directories on a rerun.
        matches = glob(os.path.join(glob_escape(working_dir), pattern))
        files = [os.path.relpath(match, working_dir) for match in matches]
        data_groups.append(DataGroup(name, files))

    # Distribute files
    distribution = distribute_files(
        working_dir=working_dir,
        partitions=partitions,
        data_groups=data_groups,
    )

    # Output results
    print("Distribution of files:")
    for target, count in distribution.items():
        print(f"{target}: {count} files")


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment