Created
February 21, 2026 00:58
-
-
Save PawaritL/c10ed525e11374fe351ead89efe64709 to your computer and use it in GitHub Desktop.
torchrun on Ray Core
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| #!/usr/bin/env python3 | |
| """ | |
| Run a torchrun command exactly once per Ray worker node. | |
| Key behaviors: | |
| - Discovers all live Ray worker nodes. | |
| - Schedules one Ray task per node using `ray.io/node-id` labels. | |
| - Resets CUDA/GPU-related env vars that Ray may inject. | |
| - Launches torchrun via subprocess with try/except handling. | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import os | |
| import shlex | |
| import subprocess | |
| import sys | |
| from typing import Dict, List | |
| import ray | |
# Attach to the running Ray cluster (or start a local one); ignore_reinit_error
# makes this safe if ray.init() was already called in this process.
ray.init(ignore_reinit_error=True)
# Ray and accelerator env vars that can constrain visibility and should be reset.
# Per the module docstring, Ray may inject these into task environments; clearing
# them lets torchrun manage device visibility itself.
GPU_ENV_VARS_TO_CLEAR = [
    "CUDA_VISIBLE_DEVICES",
    "ROCR_VISIBLE_DEVICES",
    "HIP_VISIBLE_DEVICES",
    "NVIDIA_VISIBLE_DEVICES",
    "RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES",
]
def parse_args() -> argparse.Namespace:
    """Parse CLI flags plus the trailing torchrun argument list.

    Returns:
        Namespace with ``include_head`` (bool), ``dry_run`` (bool), and
        ``torchrun_args`` (list of strings captured verbatim after the flags).
    """
    cli = argparse.ArgumentParser(
        description="Orchestrate torchrun once per Ray worker node."
    )
    cli.add_argument(
        "--include-head",
        action="store_true",
        help="Also run on the Ray head node. Default is worker nodes only.",
    )
    cli.add_argument(
        "--dry-run",
        action="store_true",
        help="Print commands and scheduling plan without running subprocesses.",
    )
    # REMAINDER swallows everything from the first positional onward, so
    # option-like torchrun flags survive untouched when preceded by '--'.
    cli.add_argument(
        "torchrun_args",
        nargs=argparse.REMAINDER,
        help=(
            "Arguments passed to torchrun. Use '--' before these args, e.g. "
            "python torchrun_per_ray_node.py -- --nnodes=1 --nproc-per-node=8 train.py"
        ),
    )
    return cli.parse_args()
def list_target_nodes(include_head: bool) -> List[Dict[str, str]]:
    """Describe every live Ray node that should receive a torchrun launch.

    Args:
        include_head: When True the head node is included; otherwise only
            worker nodes are returned.

    Returns:
        One dict per node with ``node_id``, ``node_ip`` and ``is_head``
        (stringified bool) keys.
    """
    selected: List[Dict[str, str]] = []
    for info in ray.nodes():
        # Skip nodes the cluster no longer considers alive.
        if not info.get("Alive", False):
            continue
        # 'node:__internal_head__' is present on head node resources in Ray.
        head = "node:__internal_head__" in info.get("Resources", {})
        if head and not include_head:
            continue
        selected.append(
            {
                "node_id": info["NodeID"],
                "node_ip": info["NodeManagerAddress"],
                "is_head": str(head),
            }
        )
    return selected
def sanitized_env() -> Dict[str, str]:
    """Return a copy of the current environment minus GPU-visibility vars.

    Dropping the keys in GPU_ENV_VARS_TO_CLEAR prevents Ray-injected device
    masks from constraining the torchrun subprocess.
    """
    return {
        name: value
        for name, value in os.environ.items()
        if name not in GPU_ENV_VARS_TO_CLEAR
    }
@ray.remote
def run_torchrun_on_node(
    node_id: str,
    node_ip: str,
    torchrun_args: List[str],
    dry_run: bool = False,
) -> Dict[str, str]:
    """Run torchrun once on the node Ray scheduled this task onto.

    Args:
        node_id: Ray node ID this task was pinned to (reported back verbatim).
        node_ip: Node manager address of that node.
        torchrun_args: Arguments appended after the ``torchrun`` executable.
        dry_run: When True, report the command without executing anything.

    Returns:
        A string-keyed report dict: node identity, return code, captured
        stdout/stderr, the rendered command, and a status tag.
    """
    command = ["torchrun", *torchrun_args]
    rendered = shlex.join(command)

    def report(code: str, out: str, err: str, state: str) -> Dict[str, str]:
        # Single place that shapes the result payload so all exit paths agree.
        return {
            "node_id": node_id,
            "node_ip": node_ip,
            "returncode": code,
            "stdout": out,
            "stderr": err,
            "cmd": rendered,
            "status": state,
        }

    if dry_run:
        return report("0", "", "", "dry-run")

    env = sanitized_env()
    try:
        proc = subprocess.run(
            command,
            env=env,
            check=False,
            capture_output=True,
            text=True,
        )
    except Exception as exc:
        # Broad catch is deliberate: a remote task should report failures
        # rather than crash the driver with a raised exception.
        return report(
            "-1", "", f"Exception while executing subprocess: {exc}", "exception"
        )
    state = "ok" if proc.returncode == 0 else "failed"
    return report(str(proc.returncode), proc.stdout, proc.stderr, state)
def main() -> int:
    """Fan one torchrun launch out to each target Ray node.

    Returns:
        0 on success everywhere, 1 when any node fails (or none matched),
        2 when no torchrun arguments were supplied.
    """
    opts = parse_args()
    torchrun_args = opts.torchrun_args
    # argparse.REMAINDER keeps a leading '--' separator; drop it if present.
    if torchrun_args[:1] == ["--"]:
        torchrun_args = torchrun_args[1:]
    if not torchrun_args:
        print("No torchrun args were provided. Pass them after '--'.", file=sys.stderr)
        return 2

    targets = list_target_nodes(include_head=opts.include_head)
    if not targets:
        print("No matching live Ray nodes found.", file=sys.stderr)
        return 1

    # Pin exactly one zero-CPU task to each node via its ray.io/node-id label.
    pending = [
        run_torchrun_on_node.options(
            num_cpus=0,
            label_selector={"ray.io/node-id": target["node_id"]},
        ).remote(
            node_id=target["node_id"],
            node_ip=target["node_ip"],
            torchrun_args=torchrun_args,
            dry_run=opts.dry_run,
        )
        for target in targets
    ]

    any_failed = False
    for report in ray.get(pending):
        print(f"[{report['status']}] node_id={report['node_id']} ip={report['node_ip']}")
        print(f" cmd: {report['cmd']}")
        if report["stdout"]:
            print(" --- stdout ---")
            print(report["stdout"].rstrip())
        if report["stderr"]:
            print(" --- stderr ---")
            print(report["stderr"].rstrip())
        if report["status"] in ("failed", "exception"):
            any_failed = True
    return 1 if any_failed else 0
if __name__ == "__main__":
    # Propagate main()'s integer return value as the process exit status.
    raise SystemExit(main())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment