Created
April 6, 2026 02:25
-
-
Save fzyzcjy/2334f0e0d02418fd1ca771ed686269ad to your computer and use it in GitHub Desktop.
Mechanical refactor transform: split miles/ray/rollout.py into miles/ray/rollout/ package
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| #!/usr/bin/env python3 | |
| """Reproducible transform for: split miles/ray/rollout.py into miles/ray/rollout/ package | |
| Run from the repo root: python3 transform.py | |
| """ | |
| import sys | |
| import textwrap | |
| from pathlib import Path | |
| sys.path.append(".claude/skills/mechanical-refactor-verify") | |
| from mechanical_refactor_verify_utils import exec_command, git_add_and_commit | |
# Commit that verify() checks out into a fresh worktree before running transform().
BASE_COMMIT = "4dd7770ed8caf59e45f387c5af7061e5c7e2cc41"
# Commit that the transformed worktree is diffed against in verify().
TARGET_COMMIT = "118423b7a"
# Paths handed to `git diff` in verify(); the comparison is restricted to these.
DIFF_PATHS = [
    "miles/ray/rollout.py",
    "miles/ray/rollout/",
    ".gitignore",
]
| def _lines(L: list[str], start: int, end: int) -> str: | |
| """Extract lines start..end (1-indexed, inclusive) from L.""" | |
| return "".join(L[start - 1 : end]) | |
| def _dedent4(text: str) -> str: | |
| """Remove exactly 4 leading spaces from each line.""" | |
| out = [] | |
| for line in text.splitlines(keepends=True): | |
| if line.startswith(" "): | |
| out.append(line[4:]) | |
| elif line.strip() == "": | |
| out.append(line) | |
| else: | |
| out.append(line) | |
| return "".join(out) | |
def transform(dir_root: Path) -> None:
    """Split ``miles/ray/rollout.py`` under ``dir_root`` into a ``miles/ray/rollout/`` package.

    The original module is read once and sliced by hard-coded 1-indexed line
    ranges (verify() runs this against a worktree checked out at BASE_COMMIT,
    so the ranges refer to the file at that commit). Each slice is renamed
    where needed and written, below a fixed import header, to one of:
    ``server_group.py``, ``addr_allocator.py``, ``router_manager.py``,
    ``metrics.py``, ``debug_data.py``, ``train_data_conversion.py``,
    ``rollout_server.py`` and ``rollout_manager.py``. Afterwards the original
    file is deleted, the ``.claude/`` line is removed from ``.gitignore`` and
    the result is committed via ``git_add_and_commit``.

    All files are read and written as UTF-8 so the output does not depend on
    the locale of the machine running the script.

    Args:
        dir_root: Root of the checkout to transform in place.
    """
    source = dir_root / "miles/ray/rollout.py"
    content = source.read_text(encoding="utf-8")
    L = content.splitlines(keepends=True)

    pkg = dir_root / "miles/ray/rollout"
    pkg.mkdir(parents=True, exist_ok=True)
    (pkg / "__init__.py").touch()

    # === server_group.py ===
    # ServerGroup class: lines 61-208
    body = _lines(L, 61, 208)
    # The allocator helpers become public functions of addr_allocator.py.
    body = body.replace("_allocate_rollout_engine_addr_and_ports_external", "allocate_rollout_engine_addr_and_ports_external")
    body = body.replace("_allocate_rollout_engine_addr_and_ports_normal", "allocate_rollout_engine_addr_and_ports_normal")
    (pkg / "server_group.py").write_text(
        "import dataclasses\n"
        "import os\n"
        "from typing import Any\n"
        "\n"
        "import ray\n"
        "from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy\n"
        "\n"
        "from miles.backends.sglang_utils.sglang_engine import SGLangEngine\n"
        "from miles.ray.rollout.addr_allocator import (\n"
        "    allocate_rollout_engine_addr_and_ports_external,\n"
        "    allocate_rollout_engine_addr_and_ports_normal,\n"
        ")\n"
        "from miles.ray.utils import NOSET_VISIBLE_DEVICES_ENV_VARS_LIST\n"
        "from miles.utils import dumper_utils\n"
        "\n"
        "\n"
        + body,
        encoding="utf-8",
    )

    # === addr_allocator.py ===
    # allocate_rollout_engine_addr_and_ports_normal: lines 810-897
    # allocate_rollout_engine_addr_and_ports_external: lines 796-807
    normal_body = _lines(L, 810, 897)
    normal_body = normal_body.replace("def _allocate_rollout_engine_addr_and_ports_normal", "def allocate_rollout_engine_addr_and_ports_normal")
    ext_body = _lines(L, 796, 807)
    ext_body = ext_body.replace("def _allocate_rollout_engine_addr_and_ports_external", "def allocate_rollout_engine_addr_and_ports_external")
    (pkg / "addr_allocator.py").write_text(
        "import logging\n"
        "\n"
        "import ray\n"
        "\n"
        "logger = logging.getLogger(__name__)\n"
        "\n"
        "\n"
        + normal_body + "\n\n"
        + ext_body,
        encoding="utf-8",
    )

    # === router_manager.py ===
    # start_router: lines 905-964
    # start_session_server: lines 1099-1133
    router_body = _lines(L, 905, 964)
    router_body = router_body.replace("def _start_router(", "def start_router(")
    session_body = _lines(L, 1099, 1133)
    session_body = session_body.replace("def _start_session_server(", "def start_session_server(")
    (pkg / "router_manager.py").write_text(
        "import logging\n"
        "import multiprocessing\n"
        "import random\n"
        "\n"
        "\n"
        "from miles.utils.http_utils import (\n"
        "    _wrap_ipv6,\n"
        "    find_available_port,\n"
        "    get_host_info,\n"
        "    is_port_available,\n"
        "    wait_for_server_ready,\n"
        ")\n"
        "\n"
        "\n"
        "logger = logging.getLogger(__name__)\n"
        "\n"
        "\n"
        + router_body + "\n\n"
        + session_body,
        encoding="utf-8",
    )

    # === metrics.py ===
    # The two logging entry points become public; the compute helpers they
    # call become private to the new module.
    log_eval = _lines(L, 1136, 1166)
    log_eval = log_eval.replace("def _log_eval_rollout_data(", "def log_eval_rollout_data(")
    log_eval = log_eval.replace("compute_metrics_from_samples(", "_compute_metrics_from_samples(")
    log_rollout = _lines(L, 1169, 1184)
    log_rollout = log_rollout.replace("def _log_rollout_data(", "def log_rollout_data(")
    log_rollout = log_rollout.replace("compute_metrics_from_samples(", "_compute_metrics_from_samples(")
    log_rollout = log_rollout.replace("compute_perf_metrics_from_samples(", "_compute_perf_metrics_from_samples(")
    compute_metrics = _lines(L, 1187, 1218)
    compute_metrics = compute_metrics.replace("def compute_metrics_from_samples(", "def _compute_metrics_from_samples(")
    perf_metrics = _lines(L, 1221, 1251)
    perf_metrics = perf_metrics.replace("def compute_perf_metrics_from_samples(", "def _compute_perf_metrics_from_samples(")
    zero_std = _lines(L, 1254, 1268)
    spec = _lines(L, 1271, 1278)
    prefix_cache = _lines(L, 1281, 1289)
    reward_cat = _lines(L, 1292, 1299)
    (pkg / "metrics.py").write_text(
        "import logging\n"
        "from typing import Any\n"
        "\n"
        "import numpy as np\n"
        "\n"
        "from miles.utils import tracking_utils\n"
        "from miles.utils.iter_utils import group_by\n"
        "from miles.utils.metric_utils import (\n"
        "    compute_pass_rate,\n"
        "    compute_rollout_step,\n"
        "    compute_statistics,\n"
        "    dict_add_prefix,\n"
        "    has_repetition,\n"
        ")\n"
        "from miles.utils.misc import load_function\n"
        "from miles.utils.types import Sample\n"
        "\n"
        "\n"
        "logger = logging.getLogger(__name__)\n"
        "\n"
        "\n"
        + log_eval + "\n\n"
        + log_rollout + "\n\n"
        + compute_metrics + "\n\n"
        + perf_metrics + "\n\n"
        + zero_std + "\n\n"
        + spec + "\n\n"
        + prefix_cache + "\n\n"
        + reward_cat,
        encoding="utf-8",
    )

    # === debug_data.py ===
    # _save_debug_rollout_data body: lines 621-637 (method body, needs dedent)
    debug_body = _dedent4(_lines(L, 621, 637))
    (pkg / "debug_data.py").write_text(
        "import logging\n"
        "from pathlib import Path\n"
        "\n"
        "import torch\n"
        "\n"
        "logger = logging.getLogger(__name__)\n"
        "\n"
        "\n"
        "# TODO extract `load_debug_rollout_data`\n"
        "\n"
        "\n"
        "# TODO: remove `self`\n"
        "def save_debug_rollout_data(self, data, rollout_id, evaluation: bool):\n"
        + debug_body,
        encoding="utf-8",
    )

    # === train_data_conversion.py ===
    # These are extracted class methods -> standalone functions, need dedent
    # convert_samples_to_train_data body: lines 667-734
    convert_body = _dedent4(_lines(L, 667, 734))
    convert_body = convert_body.replace("self._post_process_rewards(", "_post_process_rewards(self, ")
    # _post_process_rewards body: lines 640-664
    post_process_body = _dedent4(_lines(L, 640, 664))
    # split_train_data_by_dp body: lines 740-788
    split_body = _dedent4(_lines(L, 740, 788))
    (pkg / "train_data_conversion.py").write_text(
        "import ray\n"
        "import torch\n"
        "\n"
        "from miles.utils.ray_utils import Box\n"
        "from miles.utils.seqlen_balancing import get_seqlen_balanced_partitions\n"
        "from miles.utils.types import Sample\n"
        "\n"
        "\n"
        "# TODO: remove `self`\n"
        "def convert_samples_to_train_data(self, samples: list[Sample] | list[list[Sample]]):\n"
        + convert_body
        + "\n\n"
        "# TODO: remove `self`\n"
        "def _post_process_rewards(self, samples: list[Sample] | list[list[Sample]]):\n"
        + post_process_body
        + "\n\n"
        "# TODO: remove `self`\n"
        "def split_train_data_by_dp(self, data, dp_size):\n"
        + split_body,
        encoding="utf-8",
    )

    # === rollout_server.py ===
    # start_rollout_servers: lines 991-1069
    start_servers = _lines(L, 991, 1069)
    start_servers = start_servers.replace("_start_router(", "start_router(")
    # Forward reference: only in the return type annotation, not in local variables
    start_servers = start_servers.replace(") -> dict[str, RolloutServer]:", ') -> dict[str, "RolloutServer"]:')
    # _resolve_sglang_config: lines 1072-1091
    resolve_config = _lines(L, 1072, 1091)
    # _compute_rollout_offset: lines 967-976
    compute_offset = _lines(L, 967, 976)
    # _compute_megatron_num_gpus: lines 979-988
    compute_megatron = _lines(L, 979, 988)
    # RolloutServer class: lines 211-325
    rollout_server_class = _lines(L, 211, 325)
    (pkg / "rollout_server.py").write_text(
        "import dataclasses\n"
        "import logging\n"
        "\n"
        "import ray\n"
        "from sglang.srt.constants import GPU_MEMORY_TYPE_CUDA_GRAPH, GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_WEIGHTS\n"
        "\n"
        "from miles.backends.sglang_utils.sglang_config import ModelConfig, ServerGroupConfig, SglangConfig\n"
        "from miles.ray.rollout.router_manager import start_router\n"
        "from miles.ray.rollout.server_group import ServerGroup\n"
        "\n"
        "logger = logging.getLogger(__name__)\n"
        "\n"
        "\n"
        + start_servers + "\n\n"
        + resolve_config + "\n\n"
        + compute_offset + "\n\n"
        + compute_megatron + "\n\n"
        + rollout_server_class,
        encoding="utf-8",
    )

    # === rollout_manager.py ===
    # RolloutManager class: lines 333-618 (through _compute_dynamic_global_batch_size)
    # + set_train_parallel_config: lines 736-737
    # Excludes: _save_debug_rollout_data (620-637), _post_process_rewards (639-664),
    # _convert_samples_to_train_data (666-734), _split_train_data_by_dp (739-788)
    manager_body = _lines(L, 333, 618)
    set_train = _lines(L, 736, 737)
    manager_body += "\n" + set_train
    # Former methods are now module-level functions that take `self` explicitly.
    manager_body = manager_body.replace("self._save_debug_rollout_data(", "save_debug_rollout_data(self, ")
    manager_body = manager_body.replace("_log_rollout_data(", "log_rollout_data(")
    manager_body = manager_body.replace("_log_eval_rollout_data(", "log_eval_rollout_data(")
    manager_body = manager_body.replace("self._convert_samples_to_train_data(", "convert_samples_to_train_data(self, ")
    manager_body = manager_body.replace("self._split_train_data_by_dp(", "split_train_data_by_dp(self, ")
    manager_body = manager_body.replace("_start_session_server(", "start_session_server(")
    # NOTE(review): the indentation inside the patterns below assumes 4-space
    # class/method indentation (4 for the `def`, 8/12 inside the method body);
    # confirm against the file at BASE_COMMIT.
    manager_body = manager_body.replace(
        "    def _try_ci_fault_injection(self):",
        "    # TODO will be replaced by full ft\n    def _try_ci_fault_injection(self):",
    )
    # Add TODO comment before load_debug_rollout_data
    manager_body = manager_body.replace(
        "        if self.args.load_debug_rollout_data:\n            data = torch.load(",
        "        if self.args.load_debug_rollout_data:\n            # TODO extract to `load_debug_rollout_data`\n            data = torch.load(",
    )
    (pkg / "rollout_manager.py").write_text(
        "import itertools\n"
        "import logging\n"
        "import time\n"
        "\n"
        "import ray\n"
        "import torch\n"
        "\n"
        "from miles.ray.rollout.debug_data import save_debug_rollout_data\n"
        "from miles.ray.rollout.metrics import log_eval_rollout_data, log_rollout_data\n"
        "from miles.ray.rollout.rollout_server import RolloutServer, start_rollout_servers\n"
        "from miles.ray.rollout.router_manager import start_session_server\n"
        "from miles.ray.rollout.train_data_conversion import convert_samples_to_train_data, split_train_data_by_dp\n"
        "from miles.ray.utils import Lock\n"
        "from miles.rollout.base_types import (\n"
        "    RolloutFnConstructorInput,\n"
        "    RolloutFnEvalInput,\n"
        "    RolloutFnTrainInput,\n"
        "    call_rollout_fn,\n"
        ")\n"
        "from miles.rollout.inference_rollout.compatibility import call_rollout_function, load_rollout_function\n"
        "from miles.utils.environ import enable_experimental_rollout_refactor\n"
        "from miles.utils.health_monitor import RolloutHealthMonitor\n"
        "from miles.utils.http_utils import init_http_client\n"
        "from miles.utils.logging_utils import configure_logger\n"
        "from miles.utils.metric_checker import MetricChecker\n"
        "from miles.utils.misc import load_function\n"
        "from miles.utils.tracking_utils import init_tracking\n"
        "from miles.utils.types import Sample\n"
        "\n"
        "logging.getLogger(\"httpx\").setLevel(logging.WARNING)\n"
        "logging.getLogger(\"httpcore\").setLevel(logging.WARNING)\n"
        "\n"
        "\n"
        "logger = logging.getLogger(__name__)\n"
        "\n"
        "\n"
        + manager_body,
        encoding="utf-8",
    )

    # Remove the original file
    source.unlink()

    # Fix .gitignore: remove `.claude/` line
    gitignore = dir_root / ".gitignore"
    gi = gitignore.read_text(encoding="utf-8")
    gi = gi.replace(".claude/\n", "")
    gitignore.write_text(gi, encoding="utf-8")

    git_add_and_commit("split rollout.py into rollout/ package", cwd=str(dir_root))
def verify() -> None:
    """Replay transform() on a scratch worktree and diff it against TARGET_COMMIT.

    A new worktree (on a new branch) is created at BASE_COMMIT inside a
    temporary directory, transform() is run on it, and ``git diff`` against
    TARGET_COMMIT is taken for DIFF_PATHS only. A non-empty diff is printed
    and the process exits with status 1; otherwise a PASS line is printed.
    In both cases the worktree and branch are left in place and the manual
    clean-up command is printed.
    """
    import tempfile

    base_short = BASE_COMMIT[:8]
    top_level = exec_command("git rev-parse --show-toplevel")
    scratch = tempfile.mkdtemp(prefix="verify-mechanical-")
    branch = f"verify-mechanical-{base_short}"

    try:
        print(f"[1/3] Creating worktree at {base_short}...")
        add_cmd = f"git worktree add -b {branch} {scratch} {BASE_COMMIT}"
        exec_command(add_cmd, cwd=top_level)

        print("[2/3] Running transformation...")
        transform(Path(scratch))

        print(f"[3/3] Diffing against {TARGET_COMMIT[:8]}...")
        joined_paths = " ".join(DIFF_PATHS)
        diff_cmd = f"git diff {TARGET_COMMIT} -- {joined_paths}"
        # check=False is passed through unchanged; the outcome is judged from
        # the returned diff text below.
        diff = exec_command(diff_cmd, cwd=scratch, check=False)

        if diff:
            print(f"\nFAIL: diff is non-empty:\n{diff}")
            sys.exit(1)
        print("\nPASS: transform reproduces the commit exactly.")
    finally:
        # Nothing is removed automatically; tell the user how to clean up.
        for message in (
            f"\nWorktree left at: {scratch}",
            f"Branch: {branch}",
            "To clean up manually:",
            f" git worktree remove {scratch} && git branch -D {branch}",
        ):
            print(message)
# Running the script directly performs the verification run.
if __name__ == "__main__":
    verify()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment