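"""GRPO fine-tuning of Qwen2.5-Coder-7B-Instruct on an assembly-to-C++
decompilation task, run on Modal with the verl trainer.

Pipeline: compile C++ sources to assembly with g++, split the assembly into
per-function snippets, build an RL dataset of (asm prompt, asm ground truth)
pairs, then reward model completions by compiling them and scoring normalized
Levenshtein similarity against the reference assembly. This module also serves
as the custom reward function file handed to verl (see PATH_TO_REWARD_FUNCTION).
"""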
import os
import re
import subprocess
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from typing import Dict, Iterable, Optional, Tuple

import modal
app = modal.App("grpo-verl")
modal.enable_output()
PATH_TO_REWARD_FUNCTION: Path = Path("/root/grpo_verl.py")
REWARD_FUNCTION_NAME: str = "reward_diff"
MODELS_DIR = Path("/models")
checkpoints_volume: modal.Volume = modal.Volume.from_name("checkpoints", create_if_missing=True)
CACHE_DIR = Path("/cache")
cache_volume: modal.Volume = modal.Volume.from_name("cache", create_if_missing=True)
DATA_PATH = Path("/data")
data_volume: modal.Volume = modal.Volume.from_name("data", create_if_missing=True)
# Upload std_includes.txt into the data volume if it is not there yet.
# Note: Volume.batch_upload() takes no destination argument; paths are given per file.
if not os.path.exists(DATA_PATH / "std_includes.txt"):
    with data_volume.batch_upload() as batch:
        batch.put_file("std_includes.txt", "/std_includes.txt")
DATA_SOURCE = "asm_decompile" # <-- name you’ll also use to route your reward fn
ABILITY = "code/decompile" # free-form task tag
compile_image = (
    modal.Image.from_registry("debian:stable-20250721-slim", add_python="3.11")
    .apt_install(["g++-10"])
    .pip_install("datasets")
    .pip_install("cxxfilt", "levenshtein", "tqdm", "pandas")
)
VERL_REPO_PATH: Path = Path("/root/verl")
image = (
    modal.Image.from_registry("verlai/verl:app-verl0.5-vllm0.10.0-mcore0.13.0-te2.2")
    .apt_install("git")
    .run_commands(f"git clone https://github.com/volcengine/verl {VERL_REPO_PATH}")
    .run_commands(f"cd {VERL_REPO_PATH} && pip install --no-deps -e .")
    .apt_install("g++-10")
    .pip_install("protobuf==4.25.3")
    .pip_install("cxxfilt", "levenshtein", "tqdm", "datasets", "pandas")
    # .pip_install("flash_attn==2.5.8")
    # .run_commands("pip install https://github.com/Dao-AILab/flash-attention/releases/download/v2.8.2/flash_attn-2.8.2+cu12torch2.6cxx11abiFALSE-cp310-cp310-linux_x86_64.whl")
)
with image.imports():
    import os
    import re
    import subprocess
    from concurrent.futures import ThreadPoolExecutor, as_completed
    from typing import Dict, Optional

    import cxxfilt
    import Levenshtein
    import pandas as pd
    from datasets import Dataset, load_dataset
    from tqdm.auto import tqdm

    from verl import DataProto
    from verl.utils.reward_score import default_compute_score
    from verl.workers.reward_manager import register
    from verl.workers.reward_manager.abstract import AbstractRewardManager
with compile_image.imports():
    import os
    import re
    import subprocess
    from concurrent.futures import ThreadPoolExecutor, as_completed
    from typing import Dict, Optional

    import cxxfilt
    import Levenshtein
    import pandas as pd
    from datasets import Dataset, load_dataset
    from tqdm.auto import tqdm
# ╔═══ 1. Compile-and-split helper ════════════════════════════════════════════╗
PAT_OBJDUMP = re.compile(r'^[0-9a-f]+\s+<([\w.$@]+)>:')
PAT_TYPE = re.compile(r'^\s*\.type\s+([\w.$@]+),\s*@function')
PAT_LABEL = re.compile(r'^\s*([\w.$@]+):')
_FENCE_RE = re.compile(r'```(?:cpp|c\+\+)?\s*(.*?)```', re.S)
# Assembler directive lines ('.foo' but not '.L' local labels) or '@'-comment lines.
_DIRECTIVE_RE = re.compile(r'^\s*@|^\s*\.(?!L)')
def _safe_demangle(name: str) -> str:
    """Demangle a C++ symbol, falling back to the raw name on failure."""
    try:
        return cxxfilt.demangle(name, external_only=False)
    except cxxfilt.InvalidName:
        return name


def _want(demangled: str) -> bool:
    """Keep user code; drop standard-library and GNU-internal symbols."""
    return not (
        demangled.startswith(("std::", "__gnu_cxx::"))
        or demangled.startswith("_ZSt")
    )
def strip_directives(asm: str) -> str:
    """Remove assembler directive ('.foo') and '@'-comment lines, keeping '.L' labels."""
    return '\n'.join(
        line for line in asm.splitlines()
        if not _DIRECTIVE_RE.match(line)
    )
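# For example, strip_directives(".text\nmain:\n\tret\n") drops the ".text"
# directive line but keeps the "main:" label and the "ret" instruction.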
def compile_and_split(sample: dict | str, *, sample_id: int) -> Optional[Dict[str, str]]:
    """Compile a single TU to assembly and split it into per-function snippets."""
    with open("/data/std_includes.txt", "r") as f:
        includes = f.read()
    if not isinstance(sample, str):
        src = sample["text"].replace("#include <bits/stdc++.h>", "")
    else:
        src = sample.replace("#include <bits/stdc++.h>", "")
    tu = f"{includes}\n{src}"
    flags = [
        "-O2", "-std=c++17",
        "-fno-verbose-asm", "-fno-asynchronous-unwind-tables",
        "-fno-stack-protector", "-fno-ident", "-g0",
        "-fno-inline-functions", "-fno-inline-functions-called-once",
        "-fno-implicit-templates", "-fno-rtti", "-fno-exceptions",
    ]
    # Both images install the g++-10 package, which does not provide a plain
    # `g++` alias, so invoke the versioned binary directly.
    res = subprocess.run(
        ["g++-10", *flags, "-x", "c++", "-", "-S", "-o", "-"],
        input=tu.encode(), stdout=subprocess.PIPE, stderr=subprocess.PIPE
    )
    if res.returncode:
        return None
    asm_text = res.stdout.decode("utf-8", "replace")
    asm_text = strip_directives(asm_text)
    funcs, current, buf = {}, None, []
    for line in asm_text.splitlines(keepends=True):
        # Strip any directive lines that survived (defensive; see strip_directives).
        if _DIRECTIVE_RE.match(line):
            continue
        m = PAT_OBJDUMP.match(line) or PAT_LABEL.match(line) or PAT_TYPE.match(line)
        if m:
            name = m.group(1)
            if name.strip().startswith('.L'):
                if current:  # local label: keep the line, don't start a new function
                    buf.append(line)
                continue
            if current:
                funcs[current] = ''.join(buf)
            current, buf = name, [line]
        elif current:
            buf.append(line)
    if current:
        funcs[current] = ''.join(buf)
    # Drop label-only entries with no real body.
    funcs = {k: v for k, v in funcs.items() if len(v.splitlines()) > 2}
    out = {}
    for raw, body in funcs.items():
        demangled = _safe_demangle(raw)
        if _want(demangled):
            out[f"{sample_id}:{demangled}"] = body
    return out
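# A minimal usage sketch (hypothetical input; assumes g++-10 is on PATH and
# /data/std_includes.txt exists, as in the Modal images above):
#
#   funcs = compile_and_split("int add(int a, int b) { return a + b; }", sample_id=0)
#   # -> {"0:add(int, int)": "<assembly for add>"} on success, or None if the
#   #    translation unit failed to compile.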
# ╔═══ 4. Reward utilities ════════════════════════════════════════════════════╗
# ── 4.a Helpers ──────────────────────────────────────────────────────────────
def extract_code(txt: str) -> Optional[str]:
    """
    Return only the C++ that should be compiled: the first fenced ``` block
    (with or without a cpp/c++ tag), or None if there is no fenced block.
    """
    m = _FENCE_RE.search(txt)
    if m:
        return m.group(1)
    return None
def dispatch_compile_and_split(
    jobs: Iterable[Tuple[int, str]],
    *,
    max_workers: Optional[int] = None,
    progress: bool = False,
) -> Dict[int, Optional[Dict[str, str]]]:
    """
    Run compile_and_split(src, sample_id=sid) over many (sid, src) jobs in parallel.
    Returns: { sid: { "sid:func_name": asm, ... } | None }
    (None means the compile/split failed or raised.)
    """
    max_workers = max_workers or (os.cpu_count() or 1)
    results: Dict[int, Optional[Dict[str, str]]] = {}
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        fut_to_sid = {
            pool.submit(compile_and_split, src, sample_id=sid): sid
            for sid, src in jobs
        }
        iterator = as_completed(fut_to_sid)
        if progress:
            iterator = tqdm(iterator, total=len(fut_to_sid), desc="Compiling")
        for fut in iterator:
            sid = fut_to_sid[fut]
            try:
                results[sid] = fut.result()
            except Exception as e:
                # Optional: log the exception for debugging
                # print(f"[dispatch] sample_id={sid} failed: {e}")
                results[sid] = None
    return results
# --- Reward function (parallelized via the dispatch helper) -------------------
def reward_diff(solution_strs, ground_truths=None, data_sources=None, abilities=None, reward_models=None, extra_infos=None):
    """
    Return one reward per model completion in `solution_strs`.
    """
    n_comp = len(solution_strs)
    rewards = []
    no_code = 0
    comp_fail = 0
    empty_funcs = 0
    # Build compile jobs from the model completions.
    jobs = []
    for idx, completion in enumerate(solution_strs):
        src = extract_code(completion)  # completions should contain a ```...``` block
        if src is None:
            no_code += 1
        else:
            jobs.append((idx, src))
    compiled = dispatch_compile_and_split(jobs, max_workers=os.cpu_count(), progress=False)
    for idx in range(n_comp):
        # 1) Reference ASM: prefer reward_models[idx]['ground_truth'].
        ref_asm = None
        if reward_models and idx < len(reward_models):
            rm = reward_models[idx]
            if isinstance(rm, dict):
                ref_asm = rm.get("ground_truth")
        # 2) Fall back to ground_truths if needed.
        if ref_asm is None and ground_truths:
            # If the dataloader repeats the GT per completion, this still works.
            ref_asm = ground_truths[idx if idx < len(ground_truths) else -1]
        if ref_asm is None:
            # Last resort: slice the completion text after its first line (brittle).
            completion = solution_strs[idx]
            nl = completion.find('\n')
            ref_asm = completion[nl + 1:] if nl != -1 else completion
        # 3) If the completion had no extractable code, reward = 0.
        if idx not in compiled:
            rewards.append(0.0)
            continue
        gen_funcs = compiled[idx]
        if gen_funcs is None:
            comp_fail += 1
            rewards.append(0.0)
            continue
        # 4) Compare each generated function's asm to the reference asm.
        scores = []
        for _, gen_asm in gen_funcs.items():
            dist = Levenshtein.distance(gen_asm, ref_asm)
            norm = max(len(ref_asm), len(gen_asm)) or 1
            scores.append(1.0 - dist / norm)
        if not scores:
            empty_funcs += 1
            rewards.append(0.0)
        else:
            rewards.append(max(scores))
    # Quick telemetry (shows up in the training logs).
    print(f"[reward_diff] n={n_comp} -> no_code={no_code}, comp_fail={comp_fail}, empty_funcs={empty_funcs}")
    assert len(rewards) == n_comp
    return rewards
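# Worked example of the normalization above: if the reference asm has length
# 100, a generated function has length 120, and their Levenshtein distance is
# 30, the score is 1.0 - 30/120 = 0.75. Identical strings score 1.0; strings
# with no characters in common score 0.0.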
# --- Dataset builder (reusing the same helper) --------------------------------
def build_rl_split(dataset):
    rows, errors = [], 0
    INSTR = (
        "Decompile the following assembly code into C++. "
        "Only output syntactically correct C++ code. "
        "Use // comments to think out loud.\n"
    )
    # Make (sample_id, src) jobs for the dataset.
    jobs = [(i, s) for i, s in enumerate(dataset)]
    # Dispatch with a progress bar.
    compiled = dispatch_compile_and_split(jobs, max_workers=os.cpu_count(), progress=True)
    for sid, res in compiled.items():
        if res is None:
            errors += 1
            continue
        for full, asm in res.items():
            # `full` looks like "{sample_id}:{func_name}".
            _, func = full.split(":", 1)
            rows.append({
                "sample_id": int(sid),
                "func": func,
                "data_source": DATA_SOURCE,
                "prompt": [{"role": "user", "content": INSTR + asm}],
                "ability": ABILITY,
                "reward_model": {
                    "style": "rule",
                    "ground_truth": asm,
                },
            })
    if not rows:
        raise RuntimeError("No functions were extracted; check the splitter regex.")
    df_funcs = pd.DataFrame(rows)
    return df_funcs
@app.function(image=compile_image, volumes={DATA_PATH: data_volume}, timeout=60 * 60 * 24)
def build_rl_dataset():
    ds_stream = load_dataset("hytopot/code_contests_cpp", streaming=True)
    raw_train = list(ds_stream["train"].take(20_000))  # adjust .take(N) as desired
    # Skip the training prefix so validation samples do not overlap with training.
    raw_val = list(ds_stream["train"].skip(20_000).take(50))
    if not os.path.exists(DATA_PATH / "train.parquet"):
        train_funcs = build_rl_split(raw_train)
        train_funcs.to_parquet(DATA_PATH / "train.parquet")
        print(f"Saved {len(train_funcs)} functions")
    else:
        print("Train dataset already exists")
    if not os.path.exists(DATA_PATH / "val.parquet"):
        val_funcs = build_rl_split(raw_val)
        val_funcs.to_parquet(DATA_PATH / "val.parquet")
    else:
        print("Val dataset already exists")
@app.function(image=compile_image, volumes={MODELS_DIR: checkpoints_volume, DATA_PATH: data_volume})
def prefetch_model():
    from huggingface_hub import snapshot_download
    local_dir = "/models/Qwen2.5-Coder-7B-Instruct"
    snapshot_download(
        repo_id="Qwen/Qwen2.5-Coder-7B-Instruct",
        local_dir=local_dir,
    )
@app.function(
    image=image,
    gpu="A100-80GB",
    volumes={
        MODELS_DIR: checkpoints_volume,
        DATA_PATH: data_volume,
    },
    secrets=[modal.Secret.from_name("wandb-secret")],
    timeout=24 * 60 * 60,
)
def train(*arglist) -> None:
    data_volume.reload()
    cmd = [
        "python", "-m", "verl.trainer.main_ppo",
        "algorithm.adv_estimator=grpo",
        # --- Data ---
        f"data.train_files={DATA_PATH / 'train.parquet'}",
        f"data.val_files={DATA_PATH / 'val.parquet'}",
        "data.train_batch_size=128",
        "data.max_prompt_length=1024",
        "data.max_response_length=1024",
        "data.filter_overlong_prompts=True",
        "data.truncation=error",
        # --- LoRA + model ---
        "actor_rollout_ref.model.path=/models/Qwen2.5-Coder-7B-Instruct",
        "actor_rollout_ref.model.lora_rank=32",
        "actor_rollout_ref.model.lora_alpha=64",
        "actor_rollout_ref.model.target_modules=all-linear",  # LoRA requirement
        "actor_rollout_ref.model.use_shm=True",
        "actor_rollout_ref.model.enable_gradient_checkpointing=True",
        "actor_rollout_ref.model.use_remove_padding=True",  # Qwen supports it
        "+actor_rollout_ref.actor.fsdp_config.model_dtype=bfloat16",
        "actor_rollout_ref.rollout.dtype=bfloat16",
        # --- Optim & PPO ---
        "actor_rollout_ref.actor.optim.lr=3e-5",
        "actor_rollout_ref.actor.ppo_mini_batch_size=32",  # ↑ A100-80GB headroom
        "actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4",  # ↑
        "actor_rollout_ref.actor.entropy_coeff=0",
        "actor_rollout_ref.actor.loss_agg_mode=token-mean",
        "actor_rollout_ref.actor.use_kl_loss=False",
        "algorithm.use_kl_in_reward=False",
        # --- FSDP (keep everything on GPU for speed on 80 GB) ---
        "actor_rollout_ref.actor.fsdp_config.param_offload=False",
        "actor_rollout_ref.actor.fsdp_config.optimizer_offload=False",
        "actor_rollout_ref.actor.fsdp_config.forward_prefetch=True",
        # --- vLLM rollout on A100-80GB ---
        "actor_rollout_ref.rollout.name=vllm",
        "actor_rollout_ref.rollout.gpu_memory_utilization=0.50",  # more VRAM than L40S
        "actor_rollout_ref.rollout.tensor_model_parallel_size=1",
        "actor_rollout_ref.rollout.n=8",
        "actor_rollout_ref.rollout.max_model_len=1152",  # note: below max_prompt+max_response (1024+1024)
        "actor_rollout_ref.rollout.max_num_seqs=32",  # ↑ concurrency
        "actor_rollout_ref.rollout.max_num_batched_tokens=2048",  # start here; see note
        "actor_rollout_ref.rollout.load_format=safetensors",  # LoRA+vLLM requirement
        "actor_rollout_ref.rollout.layered_summon=True",
        "actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16",
        # --- Reference model: disabled (no KL) ---
        # --- Trainer / infra ---
        "trainer.critic_warmup=0",
        "trainer.logger=['console','wandb']",
        "trainer.project_name=verl_grpo_qwen2.5-coder-7b-instruct",
        "trainer.experiment_name=qwen2.5-coder-7b-instruct",
        "trainer.n_gpus_per_node=1",
        "trainer.nnodes=1",
        "trainer.test_freq=10",
        f"trainer.default_local_dir={str(MODELS_DIR)}",
        "trainer.resume_mode=auto",
        "trainer.save_freq=10",
        "trainer.total_training_steps=100",
        "trainer.total_epochs=1",
        # --- Custom reward ---
        f"custom_reward_function.path={str(PATH_TO_REWARD_FUNCTION)}",
        f"custom_reward_function.name={REWARD_FUNCTION_NAME}",
        "reward_model.reward_manager='batch'",
    ]
    if arglist:
        cmd.extend(arglist)
    subprocess.run(cmd, check=True)
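# A minimal driver sketch; this is an assumption, not part of the original
# gist, which presumably invokes each stage separately via `modal run`.
@app.local_entrypoint()
def main():
    build_rl_dataset.remote()  # compile sources and write train/val parquet
    prefetch_model.remote()    # download Qwen2.5-Coder-7B-Instruct weights
    train.remote()             # launch verl GRPO training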