Skip to content

Instantly share code, notes, and snippets.

@PawaritL
Created February 21, 2026 00:58
Show Gist options
  • Select an option

  • Save PawaritL/c10ed525e11374fe351ead89efe64709 to your computer and use it in GitHub Desktop.

Select an option

Save PawaritL/c10ed525e11374fe351ead89efe64709 to your computer and use it in GitHub Desktop.
torchrun on Ray Core
#!/usr/bin/env python3
"""
Run a torchrun command exactly once per Ray worker node.
Key behaviors:
- Discovers all live Ray worker nodes.
- Schedules one Ray task per node using `ray.io/node-id` labels.
- Resets CUDA/GPU-related env vars that Ray may inject.
- Launches torchrun via subprocess with try/except handling.
"""
from __future__ import annotations

import argparse
import os
import shlex
import subprocess
import sys
from typing import Dict, List, Optional

import ray
# NOTE: this runs at import time, before any argument parsing, so even
# `--help` initializes Ray first. ignore_reinit_error=True suppresses the
# error Ray would otherwise raise if ray.init() has already been called.
ray.init(ignore_reinit_error=True)
# Variables removed from the environment handed to torchrun (see
# sanitized_env), so the launched processes are not restricted by any
# accelerator-visibility settings present in the Ray worker's environment.
GPU_ENV_VARS_TO_CLEAR = [
    "CUDA_VISIBLE_DEVICES",  # NVIDIA CUDA device filter
    "ROCR_VISIBLE_DEVICES",  # AMD ROCm runtime device filter
    "HIP_VISIBLE_DEVICES",  # AMD HIP device filter
    "NVIDIA_VISIBLE_DEVICES",  # NVIDIA container-runtime device selection
    "RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES",  # Ray opt-out switch for CUDA_VISIBLE_DEVICES
]
def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Parse the orchestrator's command line.

    Args:
        argv: Arguments to parse, excluding the program name. ``None`` (the
            default) parses ``sys.argv[1:]`` exactly as before; passing an
            explicit list lets callers and tests drive the parser without
            touching ``sys.argv``.

    Returns:
        Namespace with boolean ``include_head`` and ``dry_run`` flags and
        ``torchrun_args``, the list of strings to forward to torchrun.
        ``argparse.REMAINDER`` does not always drop the ``--`` separator, so
        ``torchrun_args`` may still start with a literal ``"--"`` that the
        caller must strip.
    """
    parser = argparse.ArgumentParser(
        description="Orchestrate torchrun once per Ray worker node."
    )
    parser.add_argument(
        "--include-head",
        action="store_true",
        help="Also run on the Ray head node. Default is worker nodes only.",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Print commands and scheduling plan without running subprocesses.",
    )
    # REMAINDER collects everything after the orchestrator's own flags so that
    # torchrun options (which also start with "--") are passed through verbatim.
    parser.add_argument(
        "torchrun_args",
        nargs=argparse.REMAINDER,
        help=(
            "Arguments passed to torchrun. Use '--' before these args, e.g. "
            "python torchrun_per_ray_node.py -- --nnodes=1 --nproc-per-node=8 train.py"
        ),
    )
    return parser.parse_args(argv)
def list_target_nodes(include_head: bool) -> List[Dict[str, str]]:
    """Describe the live Ray nodes that torchrun should be launched on.

    Args:
        include_head: When False, the node whose resources contain
            ``node:__internal_head__`` (the Ray head node) is skipped.

    Returns:
        One dict per selected node, in ``ray.nodes()`` order, with string
        values under ``node_id``, ``node_ip`` and ``is_head``
        (``"True"`` / ``"False"``).
    """
    # Liveness is checked before anything else about a node is inspected.
    live_nodes = (info for info in ray.nodes() if info.get("Alive", False))

    selected: List[Dict[str, str]] = []
    for info in live_nodes:
        # Ray advertises this synthetic resource only on the head node.
        on_head = "node:__internal_head__" in info.get("Resources", {})
        if on_head and not include_head:
            continue
        selected.append(
            {
                "node_id": info["NodeID"],
                "node_ip": info["NodeManagerAddress"],
                "is_head": str(on_head),
            }
        )
    return selected
def sanitized_env() -> Dict[str, str]:
    """Return a copy of this process's environment without GPU_ENV_VARS_TO_CLEAR.

    ``os.environ`` itself is left untouched; only the returned dict omits
    the listed variables.
    """
    blocked = set(GPU_ENV_VARS_TO_CLEAR)
    return {name: value for name, value in os.environ.items() if name not in blocked}
@ray.remote
def run_torchrun_on_node(
    node_id: str,
    node_ip: str,
    torchrun_args: List[str],
    dry_run: bool = False,
) -> Dict[str, str]:
    """Run ``torchrun`` with *torchrun_args* on the Ray node executing this task.

    Args:
        node_id: Ray node ID the caller targeted; echoed back in the result.
        node_ip: Address of that node; echoed back in the result.
        torchrun_args: Arguments appended after ``torchrun`` on the command line.
        dry_run: When True, only the command string is built; nothing runs.

    Returns:
        A dict of strings with keys ``node_id``, ``node_ip``, ``returncode``,
        ``stdout``, ``stderr``, ``cmd`` and ``status`` (one of ``"dry-run"``,
        ``"ok"``, ``"failed"``, ``"exception"``). Exceptions raised by
        ``subprocess.run`` are reported in this dict (status ``"exception"``,
        returncode ``"-1"``) instead of being propagated.
    """
    cmd = ["torchrun", *torchrun_args]
    cmd_str = shlex.join(cmd)

    def _report(
        status: str, returncode: int, stdout: str = "", stderr: str = ""
    ) -> Dict[str, str]:
        # Single place that shapes the result dict consumed by main().
        return {
            "node_id": node_id,
            "node_ip": node_ip,
            "returncode": str(returncode),
            "stdout": stdout,
            "stderr": stderr,
            "cmd": cmd_str,
            "status": status,
        }

    if dry_run:
        return _report("dry-run", 0)

    env = sanitized_env()
    try:
        # Captured output is decoded only after torchrun has exited. With the
        # default strict error handler a single undecodable byte would raise
        # here and discard the return code and all output of the run, so
        # undecodable bytes are replaced instead.
        result = subprocess.run(
            cmd,
            env=env,
            check=False,
            capture_output=True,
            text=True,
            errors="replace",
        )
    except Exception as exc:
        # Report rather than raise so the driver still gets a per-node result.
        return _report(
            "exception", -1, stderr=f"Exception while executing subprocess: {exc}"
        )
    status = "ok" if result.returncode == 0 else "failed"
    return _report(status, result.returncode, result.stdout, result.stderr)
def _result_or_error(ref, target: Dict[str, str], cmd_str: str) -> Dict[str, str]:
    """Fetch one task's result dict, or synthesize an "exception" result.

    Subprocess errors are already reported inside the task's return value.
    A ``ray.exceptions.RayError`` here means the task produced no result
    (e.g. its worker process crashed or its node died), so an equivalent
    failure dict is built from what the driver knows about the target.
    """
    try:
        return ray.get(ref)
    except ray.exceptions.RayError as exc:
        return {
            "node_id": target["node_id"],
            "node_ip": target["node_ip"],
            "returncode": "-1",
            "stdout": "",
            "stderr": f"Ray task failed: {exc}",
            "cmd": cmd_str,
            "status": "exception",
        }


def main() -> int:
    """Launch torchrun once per selected Ray node and print per-node results.

    Returns:
        Process exit code: 2 if no torchrun args were given, 1 if no live
        node matched or any node reported "failed"/"exception", else 0.
    """
    args = parse_args()
    # argparse.REMAINDER may keep the "--" separator; never forward it.
    if args.torchrun_args and args.torchrun_args[0] == "--":
        args.torchrun_args = args.torchrun_args[1:]
    if not args.torchrun_args:
        print("No torchrun args were provided. Pass them after '--'.", file=sys.stderr)
        return 2

    targets = list_target_nodes(include_head=args.include_head)
    if not targets:
        print("No matching live Ray nodes found.", file=sys.stderr)
        return 1

    refs = []
    for target in targets:
        node_id = target["node_id"]
        node_ip = target["node_ip"]
        # Pin the task to this exact node via Ray's built-in node-id label.
        label_selector = {"ray.io/node-id": node_id}
        ref = run_torchrun_on_node.options(
            num_cpus=0,
            label_selector=label_selector,
        ).remote(
            node_id=node_id,
            node_ip=node_ip,
            torchrun_args=args.torchrun_args,
            dry_run=args.dry_run,
        )
        refs.append(ref)

    # Fetch results one at a time (all tasks are already submitted) so that a
    # Ray-level failure on one node does not discard every other node's output.
    cmd_str = shlex.join(["torchrun", *args.torchrun_args])
    results = [
        _result_or_error(ref, target, cmd_str) for ref, target in zip(refs, targets)
    ]

    failed = False
    for r in results:
        print(f"[{r['status']}] node_id={r['node_id']} ip={r['node_ip']}")
        print(f" cmd: {r['cmd']}")
        if r["stdout"]:
            print(" --- stdout ---")
            print(r["stdout"].rstrip())
        if r["stderr"]:
            print(" --- stderr ---")
            print(r["stderr"].rstrip())
        if r["status"] in ("failed", "exception"):
            failed = True
    return 1 if failed else 0
if __name__ == "__main__":
    # main() returns the integer process exit status; sys.exit raises
    # SystemExit with it.
    sys.exit(main())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment