Created
March 3, 2023 18:11
-
-
Save thetonus/cd878b3674c7a950d7afc4397a5dda7a to your computer and use it in GitHub Desktop.
Create train/val/test subsets for a YOLO-labeled dataset
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" Create train/test/val subsets for a YOLO labeled dataset for CV projects. | |
This is inspired by https://github.com/akashAD98/Train_val_Test_split/blob/main/train_val_test.py. | |
""" | |
import math
import os
import random
from pathlib import Path

from tqdm import tqdm
# Seed the global RNG once at import time so repeated runs produce the same split.
random.seed(0)

# Lower-case image file extensions (without the leading dot) that are considered images.
IMG_FORMATS = {'bmp', 'dng', 'jpeg', 'jpg', 'mpo', 'pfm', 'png', 'tif', 'tiff', 'webp'}
def img2label_paths(img_paths: list[str]):
    """Map image paths to the YOLO label files that accompany them.

    For each path, the last ``<sep>images<sep>`` component is replaced by
    ``<sep>labels<sep>`` (paths without it are left as-is), then everything from
    the final ``.`` onwards is replaced by ``.txt``. If there is no ``.``, ``.txt``
    is simply appended.

    Args:
        img_paths: List of image files, as strings using ``os.sep``.

    Returns:
        List of label file paths, in the same order as ``img_paths``.
    """
    images_part = os.sep + 'images' + os.sep
    labels_part = os.sep + 'labels' + os.sep

    label_files = []
    for img_path in img_paths:
        head, found, tail = img_path.rpartition(images_part)
        # Only swap the directory component if it actually occurs in the path.
        swapped = head + labels_part + tail if found else tail
        stem, dot, _ext = swapped.rpartition('.')
        label_files.append((stem if dot else swapped) + '.txt')
    return label_files
def main(dataset_dir: Path, weights: list[float], annotated_only: bool):
    """Autosplit a dataset into train/val/test splits and save autosplit_*.txt files.

    Every image found (recursively) under ``dataset_dir / "images"`` whose suffix is in
    ``IMG_FORMATS`` is randomly assigned to one split according to ``weights``. Its path,
    relative to ``dataset_dir`` and prefixed with ``./``, is written (UTF-8) to
    ``autosplit_train.txt``, ``autosplit_val.txt`` or ``autosplit_test.txt`` inside
    ``dataset_dir``. Existing autosplit files are deleted first. A split that receives
    no images gets no file.

    Args:
        dataset_dir: Dataset root directory; must contain an ``images`` subdirectory.
        weights: Train, val, test weights (list or tuple of three numbers in [0, 1]
            that sum to 1).
        annotated_only: Only use images that have a label .txt file, located via
            ``img2label_paths``.

    Raises:
        NotADirectoryError: If ``dataset_dir`` or ``dataset_dir / "images"`` is not a
            directory.
        ValueError: If ``weights`` is not three values in [0, 1] that sum to 1.
    """
    weights = [float(x) for x in weights]
    # Explicit raises instead of `assert`, which is stripped when Python runs with -O.
    if not dataset_dir.is_dir():
        raise NotADirectoryError(f"Directory not found: {dataset_dir}")
    if len(weights) != 3:
        raise ValueError("Exactly 3 weights are needed: train, val, test")
    if not all(0 <= x <= 1 for x in weights):
        raise ValueError("All weights need to be between [0, 1]")
    # Tolerant comparison: e.g. 0.7 + 0.2 + 0.1 == 0.9999999999999999 in floating point,
    # which an exact `== 1` check would wrongly reject.
    if not math.isclose(sum(weights), 1.0):
        raise ValueError("All weights need to add up to 1")

    path = dataset_dir / "images"  # images dir
    if not path.is_dir():
        raise NotADirectoryError(f"Images directory not found: {path}")
    out_dir = path.parent  # autosplit_*.txt files are written next to images/

    files = sorted(x for x in path.rglob('*.*') if x.suffix[1:].lower() in IMG_FORMATS)
    indices = random.choices([0, 1, 2], weights=weights, k=len(files))  # split per image
    label_paths = img2label_paths([str(x) for x in files])

    txt = ['autosplit_train.txt', 'autosplit_val.txt', 'autosplit_test.txt']
    for name in txt:
        (out_dir / name).unlink(missing_ok=True)  # remove results of a previous run

    suffix = ', using *.txt labeled images only' if annotated_only else ''
    print(f'Autosplitting images from {path}' + suffix)

    # Collect lines per split and write each file once, rather than reopening a file
    # in append mode for every image.
    splits: list[list[str]] = [[], [], []]
    for i, img, label in tqdm(zip(indices, files, label_paths), total=len(files)):
        if not annotated_only or Path(label).exists():  # check label
            splits[i].append(f'./{img.relative_to(out_dir).as_posix()}\n')

    for name, lines in zip(txt, splits):
        if lines:  # a split with no images produces no file
            with open(out_dir / name, 'w', encoding='utf-8') as f:
                f.writelines(lines)
if __name__ == "__main__":
    from argparse import ArgumentParser

    parser = ArgumentParser(description="Create train/val/test subsets for a YOLO-labeled dataset.")
    parser.add_argument("--dataset-dir", type=Path, required=True,
                        help="Parent directory for YOLO-labeled dataset (must contain images/)")
    # Exactly three weights, in the order main() uses them: index 0 = train,
    # 1 = val, 2 = test. nargs=3 lets argparse reject a wrong count up front.
    parser.add_argument("--weights", nargs=3, default=[0.8, 0.1, 0.1], type=float,
                        metavar=("TRAIN", "VAL", "TEST"),
                        help="Dataset weights, (train, val, test)")
    parser.add_argument("--annotated-only", action="store_true", default=False,
                        help="Only use images with an annotated txt file")
    # Argument dests (dataset_dir, weights, annotated_only) match main()'s parameter names.
    kwargs = vars(parser.parse_args())
    main(**kwargs)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Example usage