Created
July 28, 2025 14:17
-
-
Save richardliaw/cff64a6a5551edab2a131fc98acf19d7 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3
"""
vLLM multi-node deployment script.

Automatically handles Ray cluster setup and vLLM server launch.

This script simplifies vLLM multi-node deployment by automatically handling the Ray cluster
setup and vLLM server launch based on the current node's role, eliminating the need for
multiple terminals and manual Ray cluster management.

Usage Examples:

    Basic multi-node deployment:
        # Head node
        python3 run.py --head-ip 192.168.1.100 --tp 16
        # Worker node
        python3 run.py --head-ip 192.168.1.100 --tp 16

    With environment variables (network interface configuration):
        # Head node
        NCCL_SOCKET_IFNAME=bond0 GLOO_SOCKET_IFNAME=bond0 python3 run.py --head-ip 192.168.1.100 --tp 16
        # Worker node
        NCCL_SOCKET_IFNAME=bond0 GLOO_SOCKET_IFNAME=bond0 python3 run.py --head-ip 192.168.1.100 --tp 16

    With xpanes for parallel execution:
        xpanes -I {} -c "NCCL_SOCKET_IFNAME=bond0 GLOO_SOCKET_IFNAME=bond0 python3 run.py --head-ip 192.168.1.100 --tp 16" 192.168.1.100 192.168.1.101

    With model and additional vLLM arguments:
        python3 run.py --head-ip 192.168.1.100 --tp 16 --model meta-llama/Llama-2-70b-hf --vllm-args --max-model-len 4096 --gpu-memory-utilization 0.9

    Custom GPU count per node:
        GPUS_PER_NODE=4 python3 run.py --head-ip 192.168.1.100 --tp 16

Key Features:
    - Automatic role detection based on IP comparison
    - Environment variable support for network interface configuration
    - Ray cluster size validation before starting vLLM server
    - Proper cleanup and signal handling
    - Single command per node with consistent arguments
"""
import subprocess | |
import socket | |
import sys | |
import time | |
import argparse | |
import os | |
import signal | |
import json | |
from typing import Optional, List | |
def get_local_ip() -> str:
    """Return this machine's primary IPv4 address, or "127.0.0.1" as a fallback.

    "Connecting" a UDP socket to a public address sends no packets, but makes
    the OS choose the outbound interface, whose address is then read back.
    """
    probe = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    try:
        probe.connect(("8.8.8.8", 80))
        return probe.getsockname()[0]
    except Exception:
        # No usable route (offline/sandboxed host) -- fall back to loopback.
        return "127.0.0.1"
    finally:
        probe.close()
def is_port_open(host: str, port: int, timeout: int = 5) -> bool:
    """Return True if a TCP connection to host:port succeeds within timeout seconds."""
    try:
        probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            probe.settimeout(timeout)
            # connect_ex returns an errno instead of raising; 0 means connected.
            return probe.connect_ex((host, port)) == 0
        finally:
            probe.close()
    except Exception:
        return False
def wait_for_ray_cluster(head_ip: str, port: int = 6379, timeout: int = 60) -> bool:
    """Poll head_ip:port every 2s until it accepts TCP connections.

    Returns True once the port opens, False if `timeout` seconds elapse first.
    """
    print(f"Waiting for Ray cluster at {head_ip}:{port}...")
    deadline = time.time() + timeout
    while time.time() < deadline:
        if is_port_open(head_ip, port):
            print("Ray cluster is ready!")
            return True
        time.sleep(2)
    print(f"Timeout: Ray cluster at {head_ip}:{port} not ready after {timeout}s")
    return False
def run_command(cmd: List[str], wait: bool = True) -> Optional[subprocess.Popen]:
    """Run a command, inheriting the current environment variables.

    Args:
        cmd: Command and arguments to execute (list form, no shell).
        wait: If True, block until the command finishes and terminate the
            whole script on a non-zero exit code; if False, start the
            command in the background.

    Returns:
        None when wait is True (the command succeeded, or the script exited),
        otherwise the subprocess.Popen handle of the background process.
        (Fix: the original annotated the return as always Popen, but the
        wait=True path returns None.)
    """
    print(f"Running: {' '.join(cmd)}")
    if not wait:
        return subprocess.Popen(cmd)
    result = subprocess.run(cmd)
    if result.returncode != 0:
        # Fail fast: a failed setup step makes the rest of the deployment moot.
        print(f"Command failed with return code {result.returncode}")
        sys.exit(1)
    return None
def start_ray_head(ray_port: int = 6379) -> subprocess.Popen:
    """Launch a blocking Ray head process on this node and return its handle."""
    print("Starting Ray head node...")
    head_cmd = ["ray", "start", "--head", f"--port={ray_port}", "--block"]
    # --block keeps the process in the foreground so we can manage its lifetime.
    return run_command(head_cmd, wait=False)
def start_ray_worker(head_ip: str, ray_port: int = 6379) -> subprocess.Popen:
    """Launch a blocking Ray worker process that joins the head's cluster."""
    print(f"Starting Ray worker node, connecting to {head_ip}:{ray_port}...")
    worker_cmd = ["ray", "start", f"--address={head_ip}:{ray_port}", "--block"]
    return run_command(worker_cmd, wait=False)
def get_ray_cluster_size() -> int:
    """Return the node count reported by `ray status`, or 0 on any failure."""
    try:
        status = subprocess.run(
            ["ray", "status", "--format", "json"],
            capture_output=True,
            text=True,
            timeout=10,
        )
        if status.returncode != 0:
            print("Warning: Could not get Ray cluster status")
            return 0
        parsed = json.loads(status.stdout)
        # NOTE(review): assumes the JSON layout cluster_state -> node_table;
        # verify against the installed Ray version's `ray status` output.
        return len(parsed.get("cluster_state", {}).get("node_table", {}))
    except (subprocess.TimeoutExpired, json.JSONDecodeError, KeyError) as exc:
        print(f"Warning: Error getting Ray cluster size: {exc}")
        return 0
    except Exception as exc:
        # Catch-all (e.g. `ray` binary missing) so callers can keep polling.
        print(f"Warning: Unexpected error getting Ray cluster size: {exc}")
        return 0
def wait_for_cluster_size(expected_nodes: int, timeout: int = 120) -> bool:
    """Poll every 5s until the Ray cluster reports at least expected_nodes nodes.

    Returns True once the size is reached, False if `timeout` seconds pass first.
    """
    print(f"Waiting for Ray cluster to reach {expected_nodes} nodes...")
    deadline = time.time() + timeout
    while time.time() < deadline:
        current_size = get_ray_cluster_size()
        print(f"Current cluster size: {current_size}/{expected_nodes} nodes")
        if current_size >= expected_nodes:
            print(f"Ray cluster ready with {current_size} nodes!")
            return True
        time.sleep(5)
    print(f"Timeout: Ray cluster did not reach {expected_nodes} nodes after {timeout}s")
    return False
def start_vllm_server(tp: int, model: Optional[str] = None, additional_args: Optional[List[str]] = None):
    """Launch `vllm serve` in the foreground with the given tensor-parallel size.

    Blocks until the server exits; a non-zero exit terminates this script
    (via run_command's wait=True path).
    """
    print(f"Starting vLLM server with TP={tp}...")
    cmd = ["vllm", "serve"]
    if model:
        cmd.append(model)
    cmd += ["-tp", str(tp)]
    if additional_args:
        cmd += additional_args
    run_command(cmd, wait=True)
def cleanup_ray():
    """Best-effort `ray stop`; never raises, even if ray is unavailable."""
    print("Cleaning up Ray...")
    try:
        # check=False: a failing/absent cluster is fine during teardown.
        subprocess.run(["ray", "stop"], check=False)
    except Exception as exc:
        print(f"Error during Ray cleanup: {exc}")
def signal_handler(signum, frame):
    """SIGINT/SIGTERM handler: stop Ray, then exit the process cleanly."""
    print("\nReceived interrupt signal, cleaning up...")
    cleanup_ray()
    sys.exit(0)
def parse_env_vars(env_list: List[str]) -> dict:
    """Parse KEY=VALUE strings into a dict, warning about malformed entries.

    Values may themselves contain '='; only the first '=' separates key
    from value. Entries without any '=' are skipped with a warning.
    """
    parsed = {}
    for item in env_list:
        key, sep, value = item.partition('=')
        if sep:
            parsed[key] = value
        else:
            print(f"Warning: Invalid environment variable format: {item}")
    return parsed
def calculate_expected_nodes(tp: int) -> int:
    """Return the node count needed for `tp` GPUs at GPUS_PER_NODE GPUs each.

    GPUS_PER_NODE is read from the environment and defaults to 8.
    """
    gpus_per_node = int(os.environ.get('GPUS_PER_NODE', '8'))
    # Ceiling division via negated floor division, clamped to at least one node.
    return max(1, -(-tp // gpus_per_node))
def main():
    """Entry point: decide head/worker role from the local IP, then bring up
    the Ray cluster and (on the head node only) the vLLM server.
    """
    parser = argparse.ArgumentParser(description="vLLM multi-node deployment")
    parser.add_argument("--head-ip", required=True, help="IP address of the head node")
    parser.add_argument("--tp", type=int, required=True, help="Tensor parallel size")
    parser.add_argument("--ray-port", type=int, default=6379, help="Ray port (default: 6379)")
    parser.add_argument("--model", help="Model name/path for vLLM")
    # argparse.REMAINDER: everything after --vllm-args is forwarded verbatim to `vllm serve`.
    parser.add_argument("--vllm-args", nargs=argparse.REMAINDER,
                        help="Additional arguments to pass to vLLM serve")
    parser.add_argument("--wait-timeout", type=int, default=60,
                        help="Timeout for waiting for Ray cluster (seconds)")
    parser.add_argument("--expected-nodes", type=int,
                        help="Expected number of nodes (auto-calculated if not provided)")
    parser.add_argument("--skip-cluster-check", action="store_true",
                        help="Skip Ray cluster size validation")
    args = parser.parse_args()

    # Set up signal handlers so Ctrl+C / SIGTERM stop Ray before exiting.
    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)

    # Role is decided by exact string match against --head-ip.
    # NOTE(review): on multi-homed hosts get_local_ip() may pick a different
    # interface than the one named by --head-ip -- confirm on such machines.
    local_ip = get_local_ip()
    is_head_node = local_ip == args.head_ip
    print(f"Local IP: {local_ip}")
    print(f"Head IP: {args.head_ip}")
    print(f"Node role: {'HEAD' if is_head_node else 'WORKER'}")

    # Show environment variables that might affect Ray/NCCL networking.
    env_vars_of_interest = ['NCCL_SOCKET_IFNAME', 'GLOO_SOCKET_IFNAME', 'CUDA_VISIBLE_DEVICES', 'GPUS_PER_NODE']
    active_env_vars = {k: v for k, v in os.environ.items() if k in env_vars_of_interest}
    if active_env_vars:
        print(f"Relevant environment variables: {active_env_vars}")

    ray_process = None
    try:
        if is_head_node:
            # Start Ray head node as a background child process.
            ray_process = start_ray_head(args.ray_port)
            # Wait a bit for Ray to start before querying cluster state.
            time.sleep(5)
            # Check cluster size before starting vLLM so TP ranks have GPUs.
            if not args.skip_cluster_check:
                expected_nodes = args.expected_nodes or calculate_expected_nodes(args.tp)
                print(f"Expected nodes for TP={args.tp}: {expected_nodes}")
                if not wait_for_cluster_size(expected_nodes, args.wait_timeout):
                    print("Warning: Cluster size check failed. Use --skip-cluster-check to bypass.")
                    sys.exit(1)
            # Start vLLM server on head node (blocks until the server exits).
            start_vllm_server(args.tp, args.model, args.vllm_args)
        else:
            # Wait for head node's Ray port to be reachable before joining.
            if not wait_for_ray_cluster(args.head_ip, args.ray_port, args.wait_timeout):
                print("Failed to connect to Ray head node")
                sys.exit(1)
            # Start Ray worker node as a background child process.
            ray_process = start_ray_worker(args.head_ip, args.ray_port)
            # Keep worker running until the ray process exits or Ctrl+C.
            print("Worker node started. Press Ctrl+C to stop.")
            try:
                ray_process.wait()
            except KeyboardInterrupt:
                pass
    except KeyboardInterrupt:
        print("\nInterrupted by user")
    except Exception as e:
        print(f"Error: {e}")
        sys.exit(1)
    finally:
        # Cleanup: terminate the ray child, escalating to kill after 10s,
        # then run `ray stop` for any remaining daemons.
        if ray_process:
            ray_process.terminate()
            try:
                ray_process.wait(timeout=10)
            except subprocess.TimeoutExpired:
                ray_process.kill()
        cleanup_ray()


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment