Skip to content

Instantly share code, notes, and snippets.

@ugai
Created November 23, 2022 03:26
Show Gist options
  • Save ugai/25e2e43bfaf1ead186ec95d83885ec6c to your computer and use it in GitHub Desktop.
Save ugai/25e2e43bfaf1ead186ec95d83885ec6c to your computer and use it in GitHub Desktop.
import os
import shutil
from argparse import ArgumentParser, Namespace
import numpy as np
from loguru import logger
def print_args(args: Namespace) -> None:
    """Emit one DEBUG log line per parsed command-line argument.

    Each line has the form ``args.<name>: <value>`` and lines appear in the
    order the attributes are stored on the namespace.
    """
    for arg_name, arg_value in vars(args).items():
        logger.debug(f"args.{arg_name}: {arg_value}")
def get_folder_names(parent_dir: str) -> list[str]:
    """Return the names (not paths) of the direct subdirectories of *parent_dir*.

    Regular files are skipped. The order is whatever ``os.listdir`` yields,
    which Python documents as arbitrary.
    """
    folder_names: list[str] = []
    for entry_name in os.listdir(parent_dir):
        entry_path = os.path.join(parent_dir, entry_name)
        if os.path.isdir(entry_path):
            folder_names.append(entry_name)
    return folder_names
def get_file_names(parent_dir: str) -> list[str]:
    """Return the names (not paths) of the regular files directly inside *parent_dir*.

    Subdirectories are skipped. The order is whatever ``os.listdir`` yields,
    which Python documents as arbitrary.
    """

    def _is_regular_file(entry_name: str) -> bool:
        # os.listdir gives bare names, so resolve each against parent_dir.
        return os.path.isfile(os.path.join(parent_dir, entry_name))

    return list(filter(_is_regular_file, os.listdir(parent_dir)))
def create_target_and_child_dirs(target_dir: str, child_names: list[str]):
    """Ensure *target_dir* and one subdirectory per entry of *child_names* exist.

    Missing intermediate directories are created. Directories that already
    exist are left alone (``exist_ok=True``), so repeated calls are harmless.
    """
    os.makedirs(target_dir, exist_ok=True)
    child_paths = [os.path.join(target_dir, child_name) for child_name in child_names]
    for child_path in child_paths:
        os.makedirs(child_path, exist_ok=True)
def clamp(value, min_value, max_value):
    """Restrict *value* to the closed range [*min_value*, *max_value*].

    The upper bound is applied first and the lower bound second, so if
    *min_value* > *max_value* the result is *min_value*.
    """
    capped_above = min(value, max_value)
    return max(capped_above, min_value)
def clamp01(value):
    """Restrict *value* to the unit interval [0.0, 1.0]."""
    # Same bound order as clamp(value, 0.0, 1.0): cap at 1.0, then floor at 0.0.
    return max(min(value, 1.0), 0.0)
def main():
    """Copy a one-folder-per-class dataset into ``train``/``test`` splits.

    Command-line arguments:
        --input-dir:   directory whose immediate subdirectories are the classes;
                       only regular files directly inside each class folder are used.
        --output-dir:  destination; ``train/<class>/`` and ``test/<class>/`` are
                       created under it.
        --train-ratio: fraction of each class's files copied to ``train``
                       (default 0.7). Values outside [0, 1] are clamped with a
                       warning; NaN is rejected.
        --seed:        optional integer RNG seed for a reproducible split; when
                       omitted the shuffle is non-deterministic.

    For each class the file list is shuffled. The first
    ``floor(count * train_ratio)`` files are copied to ``train`` and the rest
    to ``test``. Files are copied, not moved. Same-named destination files are
    overwritten, and nothing already in the output directory is deleted.

    Raises:
        OSError: if ``--input-dir`` is not an existing directory.
    """
    # fmt: off
    parser = ArgumentParser()
    parser.add_argument("--input-dir", type=str, required=True)
    parser.add_argument("--output-dir", type=str, required=True)
    parser.add_argument("--train-ratio", type=float, default=0.7)
    parser.add_argument("--seed", type=int, default=None, help="RNG seed for a reproducible split")
    args = parser.parse_args()
    print_args(args)
    # fmt: on

    if not os.path.isdir(args.input_dir):
        raise OSError(f"input-dir not found: '{args.input_dir}'")

    # NaN passes through clamp01() unchanged and would make int() raise later.
    if np.isnan(args.train_ratio):
        parser.error("--train-ratio must be a number between 0 and 1.")
    # Loop-invariant: compute the effective ratio once, and say so if it was clamped.
    train_ratio: float = clamp01(args.train_ratio)
    if train_ratio != args.train_ratio:
        logger.warning(
            f"--train-ratio={args.train_ratio} is outside [0, 1]; clamped to {train_ratio}."
        )

    # Sort both listings so the processing order (and therefore a seeded
    # shuffle) does not depend on the filesystem's os.listdir() order.
    class_names: list[str] = sorted(get_folder_names(args.input_dir))
    class_count: int = len(class_names)
    file_names_for_class: dict[str, list[str]] = {
        class_name: sorted(get_file_names(os.path.join(args.input_dir, class_name)))
        for class_name in class_names
    }

    train_dir: str = os.path.join(args.output_dir, "train")
    test_dir: str = os.path.join(args.output_dir, "test")
    create_target_and_child_dirs(train_dir, class_names)
    create_target_and_child_dirs(test_dir, class_names)

    # One generator for the whole run: a given --seed reproduces the split;
    # seed=None draws fresh OS entropy (non-deterministic, as before).
    rng = np.random.default_rng(args.seed)

    for ci, class_name in enumerate(class_names):
        file_names: list[str] = file_names_for_class[class_name]
        file_count: int = len(file_names)
        # Round away float representation error before truncating, so that e.g.
        # 100 * 0.29 == 28.999999999999996 yields 29 rather than 28.
        train_file_count: int = int(round(file_count * train_ratio, 9))
        logger.info(
            f"class_name='{class_name}', total={file_count}, train={train_file_count}, test={file_count - train_file_count}"
        )
        rng.shuffle(file_names)
        for fi, file_name in enumerate(file_names):
            dst_dir: str = train_dir if fi < train_file_count else test_dir
            src_file: str = os.path.join(args.input_dir, class_name, file_name)
            dst_file: str = os.path.join(dst_dir, class_name, file_name)
            # Progress is reported against all files of the class, not just the train split.
            logger.trace(
                f"copy(class={ci+1}/{class_count}, file={fi+1}/{file_count}): '{src_file}' -> '{dst_file}'"
            )
            shutil.copyfile(src_file, dst_file)
if __name__ == "__main__":
    # Only run the CLI when executed as a script, not when imported as a module.
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment