Skip to content

Instantly share code, notes, and snippets.

@willprice
Created January 8, 2021 17:19
Show Gist options
  • Save willprice/babe1aa923c889857cbfc9eb65fb2f3d to your computer and use it in GitHub Desktop.
Save willprice/babe1aa923c889857cbfc9eb65fb2f3d to your computer and use it in GitHub Desktop.
# To launch the jobs run
# python multigpu_run.py example_args.py

# Source of the demo child program. It is passed verbatim to `python -c`, so
# the indentation INSIDE this string is significant: the body of the `if`
# must be indented or the child interpreter fails with IndentationError.
prog = """
import sys
import os
print(sys.argv[1], os.environ['CUDA_VISIBLE_DEVICES'])
print(sys.argv)
if (sys.argv[1] == '2'):
    print("Failed!")
    print("BLAH was", os.environ['BLAH'])
    sys.exit(1)
"""

# Command prefix shared by every run; each run's own 'args' are appended to it.
PROGRAM = ["python", "-c", prog]

# One entry per run: 'args' are extra command-line arguments, 'env' holds
# extra environment variables for that run. Run '2' exits with status 1 on
# purpose so that the launcher's failure report can be seen.
RUN_CONFIGS = [
    {
        'args': [str(i)],
        'env': {'BLAH': 'FOO'},
    }
    for i in range(10)
]
"""Run a program that uses a single GPU over multiple input arguments"""
import subprocess
from dataclasses import dataclass
from glob import glob
from multiprocessing import Pool
import os.path
import torch
from typing import Any, Dict, Iterable, List, Union
import argparse
from pathlib import Path
# Command-line interface: one positional path to the args file plus an
# optional verbosity switch. Defaults are shown in --help output.
parser = argparse.ArgumentParser(
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    description=__doc__,
)
parser.add_argument(
    "args_file",
    help=(
        "Path to python file that when executed will have the following global vars: "
        "PROGRAM, and RUN_CONFIGS"
    ),
    type=Path,
)
parser.add_argument(
    "-v",
    "--verbose",
    help="Report STDOUT and STDERR even on successful completion",
    action="store_true",
)
@dataclass
class RunConfig:
    """Per-run settings applied on top of the shared program command."""

    # Extra command-line arguments appended to the program for this run.
    args: Iterable[str]
    # Extra environment variables merged into the child's environment.
    env: Dict[str, str]
@dataclass
class ProgramConfigs:
    """A program command together with every configuration it should be run with."""

    # Command prefix (executable plus fixed arguments) shared by all runs.
    program: List[str]
    # One entry per run to launch.
    run_configs: List[RunConfig]
def main(args):
    """Entry point: discover GPUs, load the args file and launch every run.

    Args:
        args: Parsed command-line namespace providing ``args_file`` and
            ``verbose``.
    """
    available_gpus: List[str] = get_gpu_ids()
    configs = load_program_and_args_file(args.args_file)
    run(configs, available_gpus, verbose=args.verbose)
def run(program_config: "ProgramConfigs", gpus: List[str], verbose: bool = False):
    """Run every configuration of the program, spreading runs across GPUs.

    Runs are assigned to GPUs round-robin by their position in
    ``program_config.run_configs`` and executed by a process pool with one
    worker per GPU. When all runs have finished, a report is printed.

    NOTE(review): the GPU of each run is fixed up front, while the pool hands
    tasks to whichever worker becomes free. If runs take unequal time, two
    runs assigned to the same GPU may therefore overlap.

    Args:
        program_config: The program command and its per-run configurations.
        gpus: GPU ids to distribute the runs over; must not be empty.
        verbose: Also report STDOUT/STDERR of successful runs.

    Raises:
        ValueError: If ``gpus`` is empty.
    """
    if not gpus:
        # Fail with a clear message rather than a ZeroDivisionError from the
        # modulo below or an opaque error from Pool(0).
        raise ValueError("No GPU ids available: cannot schedule any runs")

    n_gpus = len(gpus)
    run_program_args = []
    for i, run_config in enumerate(program_config.run_configs):
        assigned = gpus[i % n_gpus]
        if isinstance(assigned, str):
            # run_program joins its gpu_ids with ","; a bare string such as
            # "10" would be joined character by character into "1,0".
            assigned = [assigned]
        run_program_args.append((program_config.program, run_config, assigned))

    print(
        f"Running {len(program_config.run_configs)} configurations with {n_gpus} instances"
    )
    with Pool(n_gpus) as pool:
        # chunksize=1 so that tasks are handed out one at a time.
        completed_processes = pool.map(
            run_program_wrapper, run_program_args, chunksize=1
        )
    report_errors(run_program_args, completed_processes, verbose=verbose)
def run_program_wrapper(all_args):
    """Call ``run_program`` with a single packed argument sequence.

    ``Pool.map`` passes exactly one argument to its callable, so the
    positional arguments of ``run_program`` arrive packed together and are
    unpacked here.
    """
    completed = run_program(*all_args)
    return completed
def run_program(
    program: List[str],
    run_config: "RunConfig",
    gpu_ids: Union[str, Iterable[str]],
):
    """Run the program once with the given configuration and visible GPUs.

    The child inherits the current environment, overlaid with
    ``run_config.env`` and finally ``CUDA_VISIBLE_DEVICES``.

    Args:
        program: Command prefix (executable and fixed arguments).
        run_config: Supplies the extra ``args`` and ``env`` for this run.
        gpu_ids: A single GPU id, or an iterable of ids, to expose to the
            child through ``CUDA_VISIBLE_DEVICES``.

    Returns:
        The ``subprocess.CompletedProcess`` with captured stdout and stderr
        (as bytes). A non-zero exit status does not raise.
    """
    if isinstance(gpu_ids, str):
        # A bare string is one id. Joining it directly would split it into
        # characters, e.g. "10" -> "1,0".
        gpu_ids = [gpu_ids]

    env = os.environ.copy()
    env.update(run_config.env)
    # Set last so that the assigned GPUs win over anything in run_config.env.
    env["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, gpu_ids))

    return subprocess.run(
        program + list(run_config.args), env=env, check=False, capture_output=True
    )
def report_errors(
    run_program_args: List[Iterable[Any]],
    completed_processes: List[subprocess.CompletedProcess],
    verbose: bool = False,
) -> None:
    """Print a report for failed runs (or all runs) followed by a summary.

    Args:
        run_program_args: For each run, the ``(program, run_config, gpu_ids)``
            sequence it was launched with.
        completed_processes: The finished process of each run, in the same
            order as ``run_program_args``.
        verbose: Also print the details of runs that exited with status 0.
    """

    def _as_text(stream) -> str:
        # Streams are None when they were not captured and already str when
        # the process ran in text mode. Undecodable bytes are replaced so one
        # run with non-UTF-8 output cannot abort the whole report.
        if stream is None:
            return ""
        if isinstance(stream, str):
            return stream
        return stream.decode("utf8", errors="replace")

    n_runs = len(completed_processes)
    n_failed = sum(1 for p in completed_processes if p.returncode != 0)

    for args, completed_process in zip(run_program_args, completed_processes):
        failed = completed_process.returncode != 0
        if not (failed or verbose):
            continue
        print("-" * 120)
        if failed:
            print("Run failed with params")
        print(f"program: {args[0]}")
        print(f"run_config: {args[1]}")
        print(f"gpu_ids: {args[2]}")
        print("STDOUT")
        print("=" * 80)
        print(_as_text(completed_process.stdout))
        print("=" * 80)
        print("STDERR")
        print("=" * 80)
        print(_as_text(completed_process.stderr))
        print("=" * 80)
        print("-" * 120 + "\n\n\n")

    if n_failed > 0:
        print(f"{n_failed}/{n_runs} runs failed")
    else:
        print(f"{n_runs}/{n_runs} runs completed without error.")
def load_program_and_args_file(args_file: Union[Path, str]) -> "ProgramConfigs":
    """Execute an args file and collect its PROGRAM and RUN_CONFIGS globals.

    Args:
        args_file: Path to a Python file that defines the globals ``PROGRAM``
            (the command prefix) and ``RUN_CONFIGS`` (a list of dicts with
            optional ``'args'`` and ``'env'`` keys).

    Returns:
        The program and its run configurations. A missing ``'args'`` key
        becomes ``[]`` and a missing ``'env'`` key becomes ``{}``.

    Raises:
        KeyError: If the file does not define ``PROGRAM`` or ``RUN_CONFIGS``.
    """
    program_args_ns: Dict[str, Any] = {}
    with open(args_file, "r") as f:
        source = f.read()

    # NOTE(security): this executes arbitrary code from args_file with the
    # privileges of the launcher; only pass files you trust.
    # Compiling with the real filename makes tracebacks raised by the args
    # file point at that file instead of "<string>".
    exec(compile(source, str(args_file), "exec"), program_args_ns)

    missing = [
        name for name in ("PROGRAM", "RUN_CONFIGS") if name not in program_args_ns
    ]
    if missing:
        raise KeyError(
            f"{args_file} does not define required global(s): {', '.join(missing)}"
        )

    run_configs: List[Dict[str, Any]] = program_args_ns["RUN_CONFIGS"]
    return ProgramConfigs(
        program=program_args_ns["PROGRAM"],
        run_configs=[
            RunConfig(args=run_config.get("args", []), env=run_config.get("env", {}))
            for run_config in run_configs
        ],
    )
def get_gpu_ids():
    """Return the ids of the GPUs that runs may be scheduled on.

    The ids come from the comma-separated ``CUDA_VISIBLE_DEVICES`` environment
    variable when it is set; otherwise they are ``0 .. n-1`` for the ``n``
    devices reported by ``torch.cuda.device_count()``.

    Returns:
        A list of GPU id strings with surrounding whitespace removed. An empty
        ``CUDA_VISIBLE_DEVICES`` (or no devices) yields ``[""]``.
    """
    visible = os.environ.get("CUDA_VISIBLE_DEVICES")
    if visible is None:
        # Only ask torch when the environment does not already say which
        # devices to use.
        visible = ",".join(map(str, range(torch.cuda.device_count())))
    # Strip so that values such as "0, 1" do not produce ids like " 1".
    return [gpu_id.strip() for gpu_id in visible.split(",")]
if __name__ == "__main__":
    # Parse the command line and launch the jobs only when run as a script.
    cli_args = parser.parse_args()
    main(cli_args)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment