Skip to content

Instantly share code, notes, and snippets.

@jazir555
Last active October 24, 2025 06:23
Show Gist options
  • Save jazir555/9e72cced3d18743efc6dffd7919d66ba to your computer and use it in GitHub Desktop.
Save jazir555/9e72cced3d18743efc6dffd7919d66ba to your computer and use it in GitHub Desktop.
reap gui
#!/usr/bin/env python3
# moe_reap_gui.py -- Tkinter GUI driver for REAP (fixed to match the actual REAP CLI interface)
# ---------------------------------------------------------------------------
import json
import os
import queue
import re
import signal
import subprocess
import sys
import threading
import tkinter as tk
from pathlib import Path
from tkinter import filedialog, messagebox, ttk
# Local checkout of the REAP repository; ensure_reap() clones into this path.
REAP_DIR = Path.home() / ".cache" / "reap"
# ---------------------------------------------------------------------------
# 1. Helpers
# ---------------------------------------------------------------------------
def ensure_reap():
    """Ensure the REAP repo is cloned into ``REAP_DIR`` and built/installed.

    The repository is shallow-cloned on first use. The build step
    (``scripts/build.sh`` when the checkout has one, otherwise an editable
    ``pip install``) is repeated on every call until it has succeeded once;
    success is recorded with a marker file inside the checkout so that later
    calls return quickly, while a failed or interrupted build is retried.

    Raises:
        subprocess.CalledProcessError: if git, the build script, or pip exits
            with a non-zero status.
        FileNotFoundError: if the ``git`` or ``bash`` executable is missing.
    """
    if not REAP_DIR.exists():
        REAP_DIR.parent.mkdir(parents=True, exist_ok=True)
        subprocess.run(
            ["git", "clone", "--depth", "1",
             "https://github.com/CerebrasResearch/reap", str(REAP_DIR)],
            check=True,
        )

    # Written only after a successful build, so its presence means "nothing
    # left to do" and its absence means "build never finished".
    installed_marker = REAP_DIR / ".reap_gui_installed"
    if installed_marker.exists():
        return

    build_script = REAP_DIR / "scripts" / "build.sh"
    if build_script.exists():
        # Prefer the repository's own build script.
        subprocess.run(["bash", str(build_script)], cwd=REAP_DIR, check=True)
    else:
        # Fallback: editable install into the interpreter running this GUI.
        subprocess.run(
            [sys.executable, "-m", "pip", "install", "-e", str(REAP_DIR)],
            check=True,
        )
    installed_marker.touch()
# ---------------------------------------------------------------------------
# 2. Back-end
# ---------------------------------------------------------------------------
class ReapDriver:
    """Drive REAP's ``experiments/pruning-cli.sh`` script as a subprocess.

    Output lines from the script are classified and pushed to ``log_q`` as
    JSON strings of the form ``{"level": ..., "text": ...}``; coarse progress
    is reported through the callback passed to :meth:`run`. Setting
    ``cancel_event`` stops a run in progress.
    """

    # (start %, end %) slice of the overall progress bar owned by each stage.
    _STAGE_PROGRESS = {
        "loading": (0, 10),
        "calibrating": (10, 40),
        "scoring": (40, 60),
        "pruning": (60, 80),
        "evaluating": (80, 95),
        "saving": (95, 100),
    }
    # Slice used when the current stage name is not in _STAGE_PROGRESS.
    _DEFAULT_STAGE_SPAN = (50, 60)
    # A percentage such as "37%" or "99.5%" anywhere in an output line.
    _PCT_RE = re.compile(r"(\d+(?:\.\d+)?)%")
    # How often the cancel watcher re-checks whether the child has exited.
    _CANCEL_POLL_SECONDS = 0.2

    def __init__(self, model_id, pruning_method, kill_ratio, out_dir,
                 dataset_name="theblackcat102/evol-codealpaca-v1",
                 cuda_devices="0", seed=42,
                 run_lm_eval=False, run_evalplus=False,
                 run_live_code=False, run_math=False, run_wildbench=False,
                 cancel_event=None, log_q=None):
        """Store the run configuration.

        Args:
            model_id: Model identifier passed to the pruning script.
            pruning_method: Pruning method name passed to the script.
            kill_ratio: Compression ratio; coerced to ``float``.
            out_dir: Output directory; stored as a ``Path`` but not otherwise
                used by this class.
            dataset_name: Calibration dataset passed to the script.
            cuda_devices: Device list; passed to the script and exported as
                ``CUDA_VISIBLE_DEVICES``.
            seed: Random seed; coerced to ``int``.
            run_lm_eval, run_evalplus, run_live_code, run_math, run_wildbench:
                Flags forwarded to the script as ``"true"``/``"false"``.
            cancel_event: ``threading.Event`` that requests cancellation; a
                private one is created when omitted.
            log_q: Queue receiving JSON log payloads; a private one is
                created when omitted.
        """
        self.model_id = model_id
        self.pruning_method = pruning_method
        self.kill_ratio = float(kill_ratio)
        self.out_dir = Path(out_dir)
        self.dataset_name = dataset_name
        self.cuda_devices = str(cuda_devices)
        self.seed = int(seed)
        self.run_lm_eval = run_lm_eval
        self.run_evalplus = run_evalplus
        self.run_live_code = run_live_code
        self.run_math = run_math
        self.run_wildbench = run_wildbench
        self.cancel_event = cancel_event or threading.Event()
        self.log_q = log_q or queue.Queue()

    def run(self, progress_cb):
        """Run the REAP pruning pipeline via its bash script.

        Args:
            progress_cb: Callable invoked with the estimated overall progress
                (a number between 0 and 100) for every output line.

        Returns:
            ``{"status": "success", "returncode": 0}`` on success.

        Raises:
            FileNotFoundError: if the pruning script is missing.
            RuntimeError: if the run is cancelled or the script exits with a
                non-zero status.
        """
        ensure_reap()
        pruning_script = REAP_DIR / "experiments" / "pruning-cli.sh"
        if not pruning_script.exists():
            raise FileNotFoundError(f"Cannot find {pruning_script}")

        cmd = self._build_command(pruning_script)
        self._log(f"▶ Starting REAP pruning with method: {self.pruning_method}", level="info")
        self._log(f" Model: {self.model_id}", level="info")
        self._log(f" Compression: {self.kill_ratio*100:.0f}%", level="info")
        self._log(f" Dataset: {self.dataset_name}", level="info")

        proc = subprocess.Popen(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
            # Child output is only displayed, so never fail on odd bytes.
            errors="replace",
            bufsize=1,
            cwd=REAP_DIR,
            env={**os.environ, "CUDA_VISIBLE_DEVICES": self.cuda_devices},
            # Own session => own process group, so cancellation can signal
            # bash AND everything it launched (ignored on Windows).
            start_new_session=True,
        )
        # Reading stdout blocks while the child is silent, so cancellation
        # is enforced from a separate thread rather than between lines only.
        threading.Thread(
            target=self._watch_cancel, args=(proc,), daemon=True
        ).start()

        try:
            current_stage = "loading"
            for raw_line in proc.stdout:
                if self.cancel_event.is_set():
                    break
                line = raw_line.rstrip("\n")
                current_stage = self._stage_for_line(line, current_stage)
                overall_pct, stage_pct = self._overall_progress(line, current_stage)
                progress_cb(overall_pct)
                if stage_pct is not None:
                    self._log(f"[{current_stage}] {stage_pct:.0f}% -> overall {overall_pct:.1f}%",
                              level="progress")
                self._classify_and_log(line)

            if self.cancel_event.is_set():
                self._terminate(proc)
                self._log("Operation cancelled by user.", level="warn")
                raise RuntimeError("Operation cancelled.")

            proc.wait()
            if proc.returncode != 0:
                self._log(f"REAP pruning failed with exit code {proc.returncode}", level="error")
                raise RuntimeError(f"Pruning failed (exit {proc.returncode})")

            self._log("✓ Pruning complete!", level="info")
            self._find_output_model()
            return {"status": "success", "returncode": proc.returncode}
        finally:
            # No-op when the child already exited; otherwise (an exception
            # escaped the loop) make sure it is not left running.
            self._terminate(proc)
            try:
                proc.stdout.close()
            except OSError:
                pass

    def _build_command(self, pruning_script):
        """Return the argv list used to invoke ``pruning_script`` with bash."""

        def flag(enabled):
            return "true" if enabled else "false"

        # Positional order is what this driver assumes pruning-cli.sh
        # expects -- NOTE(review): confirm against the script when it changes.
        return [
            "bash",
            str(pruning_script),
            self.cuda_devices,
            self.model_id,
            self.pruning_method,
            str(self.seed),
            str(self.kill_ratio),
            self.dataset_name,
            flag(self.run_lm_eval),
            flag(self.run_evalplus),
            flag(self.run_live_code),
            flag(self.run_math),
            flag(self.run_wildbench),
        ]

    @staticmethod
    def _stage_for_line(line, current_stage):
        """Return the stage implied by keywords in ``line``.

        Falls back to ``current_stage`` when no keyword matches. The checks
        run in a fixed order, so the first matching rule wins.
        """
        lower = line.lower()
        if "loading model" in lower or "downloading" in lower:
            return "loading"
        if "calibrat" in lower or "collecting activation" in lower:
            return "calibrating"
        if "comput" in lower and "score" in lower:
            return "scoring"
        if "prun" in lower or "remov" in lower:
            return "pruning"
        if "evaluat" in lower or "benchmark" in lower:
            return "evaluating"
        if "saving" in lower or "writing" in lower:
            return "saving"
        return current_stage

    @classmethod
    def _overall_progress(cls, line, stage):
        """Map ``line`` to ``(overall_pct, stage_pct)`` for ``stage``.

        ``stage_pct`` is the first percentage found in the line, capped at
        100, or ``None`` when the line has none; in that case the overall
        value is the midpoint of the stage's slice.
        """
        start_pct, end_pct = cls._STAGE_PROGRESS.get(stage, cls._DEFAULT_STAGE_SPAN)
        match = cls._PCT_RE.search(line)
        if match is None:
            return (start_pct + end_pct) / 2, None
        # Cap so a stray "150%" cannot push the bar past the stage's slice.
        stage_pct = min(float(match.group(1)), 100.0)
        return start_pct + (end_pct - start_pct) * (stage_pct / 100.0), stage_pct

    def _watch_cancel(self, proc):
        """Terminate ``proc`` as soon as ``cancel_event`` is set.

        Runs on a helper thread and returns once the child has exited or
        has been terminated.
        """
        while proc.poll() is None:
            if self.cancel_event.wait(self._CANCEL_POLL_SECONDS):
                self._terminate(proc)
                return

    @staticmethod
    def _terminate(proc):
        """Stop ``proc`` (and its process group where supported).

        Sends a polite termination request, waits up to five seconds, then
        force-kills. Does nothing if the process has already exited.
        """
        if proc.poll() is not None:
            return
        use_group = hasattr(os, "killpg")
        try:
            if use_group:
                # The child leads its own session (see start_new_session in
                # run), so its pid doubles as its process-group id.
                os.killpg(proc.pid, signal.SIGTERM)
            else:
                proc.terminate()
            proc.wait(timeout=5)
        except subprocess.TimeoutExpired:
            try:
                if use_group:
                    os.killpg(proc.pid, signal.SIGKILL)
                else:
                    proc.kill()
            except OSError:
                pass
        except OSError:
            # The process (group) vanished between the poll and the signal.
            pass

    def _find_output_model(self):
        """Log and return the most recently modified results directory.

        Looks under ``REAP_DIR/experiments/results``. Returns the directory
        ``Path``, or ``None`` (after logging a hint) when nothing is found.
        """
        # NOTE(review): assumes REAP writes each run to its own subdirectory
        # of experiments/results/ -- confirm against the REAP scripts.
        results_dir = REAP_DIR / "experiments" / "results"
        if results_dir.is_dir():
            candidates = [p for p in results_dir.iterdir() if p.is_dir()]
            if candidates:
                latest = max(candidates, key=lambda p: p.stat().st_mtime)
                self._log(f"✓ Pruned model saved to: {latest}", level="info")
                return latest
        self._log("Note: Check experiments/results/ for output model", level="warn")
        return None

    def _classify_and_log(self, line: str):
        """Log ``line`` with a level guessed from keywords it contains."""
        lower = line.lower()
        if any(word in lower for word in ("error", "failed", "exception", "traceback")):
            level = "error"
        elif "warn" in lower:
            level = "warn"
        elif "✓" in line or "success" in lower or "complete" in lower:
            # Checked before the progress keywords so that e.g.
            # "complete 100%" is shown as info rather than progress.
            level = "info"
        elif any(word in lower for word in ("progress", "%", "step", "iter", "epoch")):
            level = "progress"
        else:
            level = "info"
        self._log(line, level=level)

    def _log(self, text, level="info"):
        """Queue a log entry as a JSON string ``{"level": ..., "text": ...}``.

        Logging is best-effort: an entry that cannot be serialised or
        queued is dropped rather than interrupting the run.
        """
        try:
            payload = json.dumps({"level": level, "text": text})
            self.log_q.put_nowait(payload)
        except (TypeError, ValueError, queue.Full):
            pass
# ---------------------------------------------------------------------------
# 3. Tkinter GUI
# ---------------------------------------------------------------------------
class App:
    """Tkinter front end that collects settings and runs ``ReapDriver``.

    The pruning run executes on a background thread. That thread never
    touches widgets or Tk variables directly: it receives a snapshot of the
    form, writes log entries to ``log_q`` (drained on the Tk thread by
    :meth:`_poll_logs`), and schedules UI updates with ``root.after``.
    """

    # Tags configured on the log widget; unknown levels fall back to "info".
    _LOG_LEVELS = ("info", "progress", "warn", "error")
    # Padding shared by most gridded widgets.
    _PAD = {"padx": 5, "pady": 5}

    def __init__(self, root):
        """Create the Tk variables, build the widgets, and start log polling.

        Args:
            root: The ``tk.Tk`` root window to populate.
        """
        self.root = root
        root.title("REAP – Expert Pruning for MoE Models")

        # Model configuration
        self.model = tk.StringVar(value="Qwen/Qwen3-30B-A3B")
        self.method = tk.StringVar(value="reap")
        self.ratio = tk.DoubleVar(value=0.25)
        self.dataset = tk.StringVar(value="theblackcat102/evol-codealpaca-v1")
        self.cuda_dev = tk.StringVar(value="0")
        self.seed = tk.IntVar(value=42)

        # Evaluation flags
        self.run_lm_eval = tk.BooleanVar(value=False)
        self.run_evalplus = tk.BooleanVar(value=False)
        self.run_live_code = tk.BooleanVar(value=False)
        self.run_math = tk.BooleanVar(value=False)
        self.run_wildbench = tk.BooleanVar(value=False)

        self.cancel_event = threading.Event()
        self.log_q = queue.Queue()

        self.build_ui()
        self.root.after(200, self._poll_logs)

    # ------------------------------------------------------------------ UI
    def build_ui(self):
        """Create and lay out every widget of the main window."""
        self._build_config_frame()
        self._build_eval_frame()
        self._build_progress_and_buttons()
        self._build_log_area()
        # Only the log area (row 4) absorbs extra vertical space.
        self.root.grid_rowconfigure(4, weight=1)
        self.root.grid_columnconfigure(0, weight=1)

    def _build_config_frame(self):
        """Build the "Configuration" frame (model, method, ratio, ...)."""
        pad = self._PAD
        frame = ttk.LabelFrame(self.root, text="Configuration", padding=10)
        frame.grid(row=0, column=0, sticky="ew", **pad)

        ttk.Label(frame, text="Model (HF id):").grid(row=0, column=0, sticky="w", **pad)
        ttk.Entry(frame, textvariable=self.model, width=50).grid(
            row=0, column=1, columnspan=3, sticky="ew", **pad)

        ttk.Label(frame, text="Pruning Method:").grid(row=1, column=0, sticky="w", **pad)
        ttk.Combobox(frame, textvariable=self.method,
                     values=["reap", "frequency", "ean"],
                     state="readonly", width=20).grid(row=1, column=1, sticky="w", **pad)

        ttk.Label(frame, text="Compression Ratio:").grid(row=2, column=0, sticky="w", **pad)
        ratio_frame = ttk.Frame(frame)
        ratio_frame.grid(row=2, column=1, columnspan=3, sticky="ew", **pad)
        ttk.Scale(ratio_frame, from_=0.05, to=0.95, variable=self.ratio,
                  orient="horizontal").pack(side="left", fill="x", expand=True)
        ratio_label = ttk.Label(ratio_frame, text=f"{self.ratio.get()*100:.0f}%")
        ratio_label.pack(side="right", padx=5)
        # Keep the percentage label in sync with the slider.
        self.ratio.trace_add("write", lambda *_: ratio_label.config(
            text=f"{self.ratio.get()*100:.0f}%"))

        ttk.Label(frame, text="Dataset:").grid(row=3, column=0, sticky="w", **pad)
        ttk.Combobox(frame, textvariable=self.dataset, width=50,
                     values=[
                         "theblackcat102/evol-codealpaca-v1",
                         "allenai/c4",
                         "allenai/tulu-3-sft-personas-math",
                         "euclaise/WritingPrompts_curated",
                     ]).grid(row=3, column=1, columnspan=3, sticky="ew", **pad)

        ttk.Label(frame, text="CUDA Device(s):").grid(row=4, column=0, sticky="w", **pad)
        ttk.Entry(frame, textvariable=self.cuda_dev, width=10).grid(
            row=4, column=1, sticky="w", **pad)
        # The seed label and entry get separate grid cells so they cannot
        # be drawn on top of each other.
        ttk.Label(frame, text="Random Seed:").grid(row=4, column=2, sticky="e", **pad)
        ttk.Entry(frame, textvariable=self.seed, width=10).grid(
            row=4, column=3, sticky="w", **pad)

        frame.grid_columnconfigure(1, weight=1)

    def _build_eval_frame(self):
        """Build the frame of optional evaluation check boxes."""
        pad = self._PAD
        frame = ttk.LabelFrame(self.root, text="Evaluation Options (optional)", padding=10)
        frame.grid(row=1, column=0, sticky="ew", **pad)
        options = (
            ("LM-Eval (MMLU, etc.)", self.run_lm_eval, 0, 0),
            ("EvalPlus (HumanEval+)", self.run_evalplus, 0, 1),
            ("LiveCodeBench", self.run_live_code, 0, 2),
            ("Math (GSM8K)", self.run_math, 1, 0),
            ("WildBench", self.run_wildbench, 1, 1),
        )
        for text, variable, row, column in options:
            ttk.Checkbutton(frame, text=text, variable=variable).grid(
                row=row, column=column, sticky="w", **pad)

    def _build_progress_and_buttons(self):
        """Build the progress bar and the Run / Cancel / Exit buttons."""
        progress_frame = ttk.Frame(self.root)
        progress_frame.grid(row=2, column=0, sticky="ew", **self._PAD)
        self.bar = ttk.Progressbar(progress_frame, orient="horizontal",
                                   mode="determinate", length=400)
        self.bar.pack(fill="x", expand=True, pady=5)

        btn_frame = ttk.Frame(self.root)
        btn_frame.grid(row=3, column=0, pady=5)
        self.run_btn = ttk.Button(btn_frame, text="▶ Run Pruning", command=self.start)
        self.run_btn.grid(row=0, column=0, padx=5)
        self.cancel_btn = ttk.Button(btn_frame, text="✕ Cancel",
                                     command=self.cancel, state="disabled")
        self.cancel_btn.grid(row=0, column=1, padx=5)
        ttk.Button(btn_frame, text="Exit", command=self.root.quit).grid(
            row=0, column=2, padx=5)

    def _build_log_area(self):
        """Build the colour-coded log view and its Clear / Save buttons."""
        log_frame = ttk.LabelFrame(self.root, text="Logs", padding=5)
        log_frame.grid(row=4, column=0, sticky="nsew", **self._PAD)
        # Kept read-only ("disabled"); _append_log enables it briefly.
        self.log_text = tk.Text(log_frame, height=20, width=100, wrap="word",
                                state="disabled", bg="#0b0b0b", fg="#d0ffd0")
        scrollbar = ttk.Scrollbar(log_frame, command=self.log_text.yview)
        self.log_text.configure(yscrollcommand=scrollbar.set)
        self.log_text.pack(side="left", fill="both", expand=True)
        scrollbar.pack(side="right", fill="y")

        # One text tag per log level, used for colouring.
        self.log_text.tag_configure("info", foreground="#a6ffa6")
        self.log_text.tag_configure("progress", foreground="#6fe3ff")
        self.log_text.tag_configure("warn", foreground="#ffb86b")
        self.log_text.tag_configure("error", foreground="#ff6b6b")

        log_btn_frame = ttk.Frame(self.root)
        log_btn_frame.grid(row=5, column=0, sticky="e", padx=5, pady=(0, 5))
        ttk.Button(log_btn_frame, text="Clear Logs",
                   command=self.clear_logs).grid(row=0, column=0, padx=2)
        ttk.Button(log_btn_frame, text="Save Logs...",
                   command=self.save_logs).grid(row=0, column=1, padx=2)

    # ------------------------------------------------------------- actions
    def _collect_config(self):
        """Snapshot the form as keyword arguments for ``ReapDriver``.

        Must be called on the Tk thread because it reads Tk variables.

        Raises:
            tk.TclError: if a numeric field (e.g. the seed) holds text that
                Tk cannot convert.
        """
        return {
            "model_id": self.model.get().strip(),
            "pruning_method": self.method.get(),
            # Rounded to whole percent so the run uses exactly the value
            # shown next to the slider.
            "kill_ratio": round(self.ratio.get(), 2),
            "out_dir": REAP_DIR / "experiments" / "results",
            "dataset_name": self.dataset.get().strip(),
            "cuda_devices": self.cuda_dev.get().strip(),
            "seed": self.seed.get(),
            "run_lm_eval": self.run_lm_eval.get(),
            "run_evalplus": self.run_evalplus.get(),
            "run_live_code": self.run_live_code.get(),
            "run_math": self.run_math.get(),
            "run_wildbench": self.run_wildbench.get(),
        }

    def start(self):
        """Validate the form and start the pruning run on a worker thread."""
        if not self.model.get().strip():
            messagebox.showerror("Error", "Please enter a model name")
            return
        try:
            config = self._collect_config()
        except (tk.TclError, ValueError) as exc:
            messagebox.showerror("Error", f"Invalid setting: {exc}")
            return

        self.run_btn.config(state="disabled")
        self.cancel_btn.config(state="normal")
        self.bar["value"] = 0
        self.clear_logs()
        self.cancel_event.clear()
        threading.Thread(target=self.worker, args=(config,), daemon=True).start()

    def cancel(self):
        """Request cancellation of the current operation."""
        self.cancel_event.set()
        self._append_log("⚠️ Cancel requested by user\n", level="warn")

    def worker(self, config=None):
        """Background thread body: run ``ReapDriver`` and report the outcome.

        Args:
            config: Keyword arguments for ``ReapDriver`` as produced by
                :meth:`_collect_config`. When omitted, the form is read
                here instead (kept for callers that pass no argument).
        """
        try:
            if config is None:
                config = self._collect_config()
            driver = ReapDriver(
                cancel_event=self.cancel_event,
                log_q=self.log_q,
                **config,
            )
            driver.run(lambda p: self.root.after(0, lambda: self.bar.config(value=p)))
            if not self.cancel_event.is_set():
                self.root.after(0, self._show_success)
        except Exception as e:
            msg = str(e)
            # Cancellation surfaces as an exception too; it is not an error.
            if "cancel" not in msg.lower():
                self._queue_log(f"ERROR: {msg}", level="error")
                self.root.after(0, lambda: messagebox.showerror("Error", msg))
        finally:
            self.root.after(0, self._reset_ui)

    def _show_success(self):
        """Fill the progress bar and tell the user the run finished."""
        self.bar.config(value=100)
        messagebox.showinfo(
            "Success",
            "Pruning complete! Check experiments/results/ for output model.")

    # ------------------------------------------------------------- logging
    def _queue_log(self, text, level="info"):
        """Queue a log entry from any thread; shown by :meth:`_poll_logs`."""
        self.log_q.put_nowait(json.dumps({"level": level, "text": text}))

    @staticmethod
    def _decode_log_payload(payload):
        """Return ``(text, level)`` for a payload taken from the log queue.

        Payloads are expected to be JSON objects with "text" and "level"
        keys; anything else is shown verbatim at level "info".
        """
        try:
            item = json.loads(payload)
            return str(item.get("text", "")), item.get("level", "info")
        except (TypeError, ValueError, AttributeError):
            return str(payload), "info"

    def _poll_logs(self):
        """Drain the log queue into the text widget, then re-schedule."""
        try:
            while True:
                text, level = self._decode_log_payload(self.log_q.get_nowait())
                # Already on the Tk thread, so append directly.
                self._append_log(text + "\n", level=level)
        except queue.Empty:
            pass
        self.root.after(150, self._poll_logs)

    def _append_log(self, text, level="info"):
        """Append colored text to the log widget (Tk thread only)."""
        tag = level if level in self._LOG_LEVELS else "info"
        self.log_text.configure(state="normal")
        self.log_text.insert("end", text, (tag,))
        self.log_text.see("end")
        self.log_text.configure(state="disabled")

    def clear_logs(self):
        """Clear all log text."""
        self.log_text.configure(state="normal")
        self.log_text.delete("1.0", "end")
        self.log_text.configure(state="disabled")

    def save_logs(self):
        """Ask for a file name and write the current log text to it."""
        f = filedialog.asksaveasfilename(
            defaultextension=".log",
            filetypes=[("Log files", "*.log"), ("All files", "*.*")])
        if not f:
            return
        try:
            with open(f, "w", encoding="utf-8") as fh:
                # "end-1c" skips the newline Tk always keeps at the end.
                fh.write(self.log_text.get("1.0", "end-1c"))
            messagebox.showinfo("Saved", f"Logs saved to {f}")
        except Exception as e:
            messagebox.showerror("Error", f"Failed to save logs: {e}")

    def _reset_ui(self):
        """Reset UI state after operation completes."""
        self.run_btn.config(state="normal")
        self.cancel_btn.config(state="disabled")
# ---------------------------------------------------------------------------
# 4. Entry point
# ---------------------------------------------------------------------------
def main():
    """Create the Tk root window, attach the App, and run the event loop."""
    window = tk.Tk()
    App(window)
    window.mainloop()


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment