Skip to content

Instantly share code, notes, and snippets.

@richardliaw
Created July 28, 2025 14:17
Show Gist options
  • Save richardliaw/cff64a6a5551edab2a131fc98acf19d7 to your computer and use it in GitHub Desktop.
#!/usr/bin/env python3
"""
vLLM multi-node deployment script.
Automatically handles Ray cluster setup and vLLM server launch.
This script simplifies vLLM multi-node deployment by automatically handling the Ray cluster
setup and vLLM server launch based on the current node's role, eliminating the need for
multiple terminals and manual Ray cluster management.
Usage Examples:
Basic multi-node deployment:
# Head node
python3 run.py --head-ip 192.168.1.100 --tp 16
# Worker node
python3 run.py --head-ip 192.168.1.100 --tp 16
With environment variables (network interface configuration):
# Head node
NCCL_SOCKET_IFNAME=bond0 GLOO_SOCKET_IFNAME=bond0 python3 run.py --head-ip 192.168.1.100 --tp 16
# Worker node
NCCL_SOCKET_IFNAME=bond0 GLOO_SOCKET_IFNAME=bond0 python3 run.py --head-ip 192.168.1.100 --tp 16
With xpanes for parallel execution:
xpanes -I {} -c "NCCL_SOCKET_IFNAME=bond0 GLOO_SOCKET_IFNAME=bond0 python3 run.py --head-ip 192.168.1.100 --tp 16" 192.168.1.100 192.168.1.101
With model and additional vLLM arguments:
python3 run.py --head-ip 192.168.1.100 --tp 16 --model meta-llama/Llama-2-70b-hf --vllm-args --max-model-len 4096 --gpu-memory-utilization 0.9
Custom GPU count per node:
GPUS_PER_NODE=4 python3 run.py --head-ip 192.168.1.100 --tp 16
Key Features:
- Automatic role detection based on IP comparison
- Environment variable support for network interface configuration
- Ray cluster size validation before starting vLLM server
- Proper cleanup and signal handling
- Single command per node with consistent arguments
"""
import subprocess
import socket
import sys
import time
import argparse
import os
import signal
import json
from typing import Optional, List
def get_local_ip() -> str:
    """Return this machine's primary outbound IPv4 address.

    Falls back to "127.0.0.1" if the interface cannot be determined.
    """
    probe = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    try:
        # UDP "connect" sends no packets; it only selects the outbound
        # interface, whose address we then read back.
        probe.connect(("8.8.8.8", 80))
        return probe.getsockname()[0]
    except Exception:
        return "127.0.0.1"
    finally:
        probe.close()
def is_port_open(host: str, port: int, timeout: int = 5) -> bool:
    """Return True iff a TCP connection to host:port succeeds within timeout."""
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    sock.settimeout(timeout)
    try:
        # connect_ex returns 0 on success instead of raising.
        return sock.connect_ex((host, port)) == 0
    except Exception:
        return False
    finally:
        sock.close()
def wait_for_ray_cluster(head_ip: str, port: int = 6379, timeout: int = 60) -> bool:
    """Poll head_ip:port every 2 seconds until reachable.

    Returns True as soon as the port accepts a connection, False once
    *timeout* seconds have elapsed without success.
    """
    print(f"Waiting for Ray cluster at {head_ip}:{port}...")
    deadline = time.time() + timeout
    while time.time() < deadline:
        if is_port_open(head_ip, port):
            print("Ray cluster is ready!")
            return True
        time.sleep(2)
    print(f"Timeout: Ray cluster at {head_ip}:{port} not ready after {timeout}s")
    return False
def run_command(cmd: List[str], wait: bool = True) -> Optional[subprocess.Popen]:
    """Run *cmd* with the current environment.

    Args:
        cmd: Command and arguments, passed to subprocess without a shell.
        wait: When True, block until the command finishes and exit the
            whole program (status 1) on a non-zero return code.

    Returns:
        None when wait=True and the command succeeded; the started
        ``subprocess.Popen`` handle when wait=False.

    Note:
        The original annotation claimed a plain ``subprocess.Popen`` return,
        but the wait=True path returns None — fixed to ``Optional``.
    """
    print(f"Running: {' '.join(cmd)}")
    if not wait:
        # Fire-and-forget: caller owns the process handle.
        return subprocess.Popen(cmd)
    result = subprocess.run(cmd)
    if result.returncode != 0:
        print(f"Command failed with return code {result.returncode}")
        sys.exit(1)
    return None
def start_ray_head(ray_port: int = 6379) -> subprocess.Popen:
    """Launch a blocking `ray start --head` child and return its handle."""
    print("Starting Ray head node...")
    cmd = ["ray", "start", "--head", f"--port={ray_port}", "--block"]
    # --block keeps the child alive so we can manage its lifetime.
    return run_command(cmd, wait=False)
def start_ray_worker(head_ip: str, ray_port: int = 6379) -> subprocess.Popen:
    """Launch a blocking `ray start` worker joined to the head and return its handle."""
    print(f"Starting Ray worker node, connecting to {head_ip}:{ray_port}...")
    cmd = ["ray", "start", f"--address={head_ip}:{ray_port}", "--block"]
    return run_command(cmd, wait=False)
def get_ray_cluster_size() -> int:
    """Return the node count reported by `ray status`, or 0 on any error.

    NOTE(review): assumes `ray status --format json` is accepted by the
    installed Ray CLI and that the payload contains
    cluster_state.node_table — verify against the deployed Ray version.
    """
    try:
        proc = subprocess.run(
            ["ray", "status", "--format", "json"],
            capture_output=True,
            text=True,
            timeout=10,
        )
        if proc.returncode != 0:
            print("Warning: Could not get Ray cluster status")
            return 0
        payload = json.loads(proc.stdout)
        return len(payload.get("cluster_state", {}).get("node_table", {}))
    except (subprocess.TimeoutExpired, json.JSONDecodeError, KeyError) as e:
        print(f"Warning: Error getting Ray cluster size: {e}")
        return 0
    except Exception as e:
        # Covers e.g. FileNotFoundError when the `ray` CLI is absent.
        print(f"Warning: Unexpected error getting Ray cluster size: {e}")
        return 0
def wait_for_cluster_size(expected_nodes: int, timeout: int = 120) -> bool:
    """Poll the Ray cluster every 5 seconds until it has expected_nodes nodes.

    Returns True once the reported size reaches the target, False if the
    timeout expires first.
    """
    print(f"Waiting for Ray cluster to reach {expected_nodes} nodes...")
    deadline = time.time() + timeout
    while time.time() < deadline:
        size = get_ray_cluster_size()
        print(f"Current cluster size: {size}/{expected_nodes} nodes")
        if size >= expected_nodes:
            print(f"Ray cluster ready with {size} nodes!")
            return True
        time.sleep(5)
    print(f"Timeout: Ray cluster did not reach {expected_nodes} nodes after {timeout}s")
    return False
def start_vllm_server(tp: int, model: Optional[str] = None, additional_args: Optional[List[str]] = None):
    """Run `vllm serve` in the foreground with the given tensor-parallel size.

    Blocks until the server process exits; a non-zero exit terminates the
    whole script (via run_command).
    """
    print(f"Starting vLLM server with TP={tp}...")
    cmd = ["vllm", "serve"]
    if model:
        cmd.append(model)
    cmd += ["-tp", str(tp)]
    if additional_args:
        cmd += list(additional_args)
    run_command(cmd, wait=True)
def cleanup_ray():
    """Best-effort `ray stop`; swallows every error so shutdown never fails."""
    print("Cleaning up Ray...")
    try:
        # check=False: a non-zero exit from `ray stop` is acceptable here.
        subprocess.run(["ray", "stop"], check=False)
    except Exception as exc:
        print(f"Error during Ray cleanup: {exc}")
def signal_handler(signum, frame):
    """SIGINT/SIGTERM handler: stop Ray, then exit with status 0."""
    print("\nReceived interrupt signal, cleaning up...")
    cleanup_ray()
    sys.exit(0)
def parse_env_vars(env_list: List[str]) -> dict:
    """Convert "KEY=VALUE" strings into a dict, warning on malformed entries.

    Only the first '=' separates key from value, so values may themselves
    contain '='. Entries without '=' are skipped with a warning.
    """
    parsed = {}
    for entry in env_list:
        key, sep, value = entry.partition('=')
        if sep:
            parsed[key] = value
        else:
            print(f"Warning: Invalid environment variable format: {entry}")
    return parsed
def calculate_expected_nodes(tp: int) -> int:
    """Return ceil(tp / GPUs-per-node), never less than 1.

    GPUs per node comes from the GPUS_PER_NODE environment variable,
    defaulting to 8.
    """
    per_node = int(os.environ.get('GPUS_PER_NODE', '8'))
    # Negated floor division is ceiling division for positive operands.
    return max(1, -(-tp // per_node))
def main():
    """Entry point: detect this node's role (head vs worker) and act on it.

    Head node (local IP == --head-ip): start the Ray head, optionally wait
    until the cluster reaches the expected node count, then run the vLLM
    server in the foreground.
    Worker node: wait for the head's Ray port to open, join the cluster,
    and block until interrupted. Ray is stopped on every exit path.
    """
    parser = argparse.ArgumentParser(description="vLLM multi-node deployment")
    parser.add_argument("--head-ip", required=True, help="IP address of the head node")
    parser.add_argument("--tp", type=int, required=True, help="Tensor parallel size")
    parser.add_argument("--ray-port", type=int, default=6379, help="Ray port (default: 6379)")
    parser.add_argument("--model", help="Model name/path for vLLM")
    # REMAINDER: everything after --vllm-args is forwarded verbatim to `vllm serve`.
    parser.add_argument("--vllm-args", nargs=argparse.REMAINDER,
                        help="Additional arguments to pass to vLLM serve")
    parser.add_argument("--wait-timeout", type=int, default=60,
                        help="Timeout for waiting for Ray cluster (seconds)")
    parser.add_argument("--expected-nodes", type=int,
                        help="Expected number of nodes (auto-calculated if not provided)")
    parser.add_argument("--skip-cluster-check", action="store_true",
                        help="Skip Ray cluster size validation")
    args = parser.parse_args()

    # Set up signal handlers for cleanup
    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)

    # Role detection: this node is the head iff its local IP equals --head-ip.
    local_ip = get_local_ip()
    is_head_node = local_ip == args.head_ip
    print(f"Local IP: {local_ip}")
    print(f"Head IP: {args.head_ip}")
    print(f"Node role: {'HEAD' if is_head_node else 'WORKER'}")

    # Show environment variables that might affect Ray/NCCL networking.
    env_vars_of_interest = ['NCCL_SOCKET_IFNAME', 'GLOO_SOCKET_IFNAME', 'CUDA_VISIBLE_DEVICES', 'GPUS_PER_NODE']
    active_env_vars = {k: v for k, v in os.environ.items() if k in env_vars_of_interest}
    if active_env_vars:
        print(f"Relevant environment variables: {active_env_vars}")

    ray_process = None
    try:
        if is_head_node:
            # Start Ray head node (non-blocking child with --block).
            ray_process = start_ray_head(args.ray_port)
            # Wait a bit for Ray to start before querying it.
            time.sleep(5)
            # Check cluster size before starting vLLM.
            if not args.skip_cluster_check:
                expected_nodes = args.expected_nodes or calculate_expected_nodes(args.tp)
                print(f"Expected nodes for TP={args.tp}: {expected_nodes}")
                if not wait_for_cluster_size(expected_nodes, args.wait_timeout):
                    print("Warning: Cluster size check failed. Use --skip-cluster-check to bypass.")
                    sys.exit(1)
            # Start vLLM server on head node (blocks until the server exits).
            start_vllm_server(args.tp, args.model, args.vllm_args)
        else:
            # Wait for head node's Ray port to be reachable.
            if not wait_for_ray_cluster(args.head_ip, args.ray_port, args.wait_timeout):
                print("Failed to connect to Ray head node")
                sys.exit(1)
            # Start Ray worker node.
            ray_process = start_ray_worker(args.head_ip, args.ray_port)
            # Keep worker running until the child exits or the user interrupts.
            print("Worker node started. Press Ctrl+C to stop.")
            try:
                ray_process.wait()
            except KeyboardInterrupt:
                pass
    except KeyboardInterrupt:
        print("\nInterrupted by user")
    except Exception as e:
        print(f"Error: {e}")
        sys.exit(1)
    finally:
        # Cleanup: terminate the blocking `ray start` child (escalating to
        # kill after 10s), then run `ray stop` for any leftover processes.
        if ray_process:
            ray_process.terminate()
            try:
                ray_process.wait(timeout=10)
            except subprocess.TimeoutExpired:
                ray_process.kill()
        cleanup_ray()


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment