Skip to content

Instantly share code, notes, and snippets.

@thetonus
Created March 3, 2023 18:11
Show Gist options
  • Save thetonus/cd878b3674c7a950d7afc4397a5dda7a to your computer and use it in GitHub Desktop.
Save thetonus/cd878b3674c7a950d7afc4397a5dda7a to your computer and use it in GitHub Desktop.
Create train/test/val subsets for a YOLO labeled dataset
""" Create train/test/val subsets for a YOLO labeled dataset for CV projects.
This is inspired by https://github.com/akashAD98/Train_val_Test_split/blob/main/train_val_test.py.
"""
import os
import random
from pathlib import Path
from tqdm import tqdm
# Seed the module-level RNG so repeated runs assign images to the same splits.
random.seed(0)

# Lower-case file extensions (without the leading dot) that count as images.
IMG_FORMATS = {
    'bmp',
    'dng',
    'jpeg',
    'jpg',
    'mpo',
    'png',
    'tif',
    'tiff',
    'webp',
    'pfm',
}
def img2label_paths(img_paths: list[str]):
    """Define label paths as a function of image paths.

    For each path, the last ``<sep>images<sep>`` component is replaced by
    ``<sep>labels<sep>`` and the file extension is replaced by ``.txt``.

    Args:
        img_paths: List of image files

    Returns:
        List of label file paths, one per input path, in the same order.
    """
    sa, sb = f'{os.sep}images{os.sep}', f'{os.sep}labels{os.sep}'  # /images/, /labels/ substrings
    # os.path.splitext only looks at the final path component, so a dot in a
    # directory name (e.g. "data.v1/") is never mistaken for the extension.
    return [os.path.splitext(sb.join(x.rsplit(sa, 1)))[0] + '.txt' for x in img_paths]
def main(dataset_dir: Path, weights: list[float], annotated_only: bool):
    """Autosplit a dataset into train/val/test splits and save autosplit_*.txt files.

    Images are searched recursively under ``dataset_dir / "images"``. Each
    image is randomly assigned to one split according to ``weights`` and its
    path, relative to ``dataset_dir``, is written to the matching
    ``autosplit_train.txt`` / ``autosplit_val.txt`` / ``autosplit_test.txt``
    file inside ``dataset_dir``. Existing autosplit files are removed first.

    Args:
        dataset_dir: Parent directory of the dataset (contains ``images``)
        weights: Train, val, test weights (list, tuple)
        annotated_only: Only use images with an annotated txt file

    Raises:
        AssertionError: If ``dataset_dir`` is not a directory, or ``weights``
            is not three values in [0, 1] that add up to 1.
    """
    import math  # local import: only needed for the tolerance check below

    weights = [float(x) for x in weights]
    assert dataset_dir.is_dir(), "Directory not found"
    assert len(weights) == 3, "Only need 3 values for weights"
    assert all(0 <= x <= 1 for x in weights), "All weights need to be between [0, 1]"
    # Exact equality rejects valid input such as [0.7, 0.2, 0.1], whose float
    # sum is 0.9999999999999999, so compare with a tolerance instead.
    assert math.isclose(sum(weights), 1.0), "All weights need to add up to 1"

    path = dataset_dir / "images"  # images dir
    files = sorted(x for x in path.rglob('*.*') if x.suffix[1:].lower() in IMG_FORMATS)  # image files only
    n = len(files)  # number of files
    indices = random.choices([0, 1, 2], weights=weights, k=n)  # assign each image to a split

    txt = ['autosplit_train.txt', 'autosplit_val.txt', 'autosplit_test.txt']  # 3 txt files
    for name in txt:
        (path.parent / name).unlink(missing_ok=True)  # remove existing

    print(f'Autosplitting images from {path}' + ', using *.txt labeled images only' * annotated_only)

    # Collect the lines per split first so each output file is opened once
    # rather than once per image.
    entries = [[] for _ in txt]
    for i, img in tqdm(zip(indices, files), total=n):
        if not annotated_only or Path(img2label_paths([str(img)])[0]).exists():  # check label
            entries[i].append(f'./{img.relative_to(path.parent).as_posix()}' + '\n')

    for name, lines in zip(txt, entries):
        if lines:  # a split that received no images gets no file
            with open(path.parent / name, 'w') as f:
                f.writelines(lines)
if __name__ == "__main__":
    from argparse import ArgumentParser

    # Command-line entry point: parse the options and forward them to main().
    parser = ArgumentParser()
    parser.add_argument("--dataset-dir", type=Path, required=True, help="Parent directory for YOLO-labeled dataset")
    # main() interprets the three weights in train, val, test order, so the
    # help text must list them in that order.
    parser.add_argument("--weights", nargs="+", default=[0.8, 0.1, 0.1], type=float, help="Dataset weights, (train, val, test)")
    parser.add_argument("--annotated-only", action="store_true", default=False, help="Only use images with an annotated txt file")
    kwargs = vars(parser.parse_args())
    main(**kwargs)
@thetonus
Copy link
Author

thetonus commented Mar 3, 2023

Example usage

$ python train_test_split.py --help
usage: train_test_split.py [-h] --dataset-dir DATASET_DIR
                           [--weights WEIGHTS [WEIGHTS ...]]
                           [--annotated-only]

options:
  -h, --help            show this help message and exit
  --dataset-dir DATASET_DIR
                        Parent directory for YOLO-labeled dataset
  --weights WEIGHTS [WEIGHTS ...]
                        Dataset weights, (train, test, val)
  --annotated-only      Only use images with an annotated txt file

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment