Last active
October 24, 2025 06:23
-
-
Save jazir555/9e72cced3d18743efc6dffd7919d66ba to your computer and use it in GitHub Desktop.
reap gui
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| #!/usr/bin/env python3 | |
| # moe_reap_gui.py –– REAP GUI driver (FIXED for actual REAP interface) | |
| # --------------------------------------------------------------------------- | |
| import os | |
| import sys | |
| import json | |
| import threading | |
| import subprocess | |
| import tkinter as tk | |
| from tkinter import ttk, filedialog, messagebox | |
| from pathlib import Path | |
| import queue | |
| import re | |
# Location of the REAP checkout that this GUI clones, installs and runs.
# Defaults to ~/.cache/reap.  Set the REAP_GUI_DIR environment variable to
# relocate it, e.g. onto a larger disk, because the GUI's results directory
# lives underneath it.
REAP_DIR = Path(os.environ.get("REAP_GUI_DIR") or Path.home() / ".cache" / "reap").expanduser()
| # --------------------------------------------------------------------------- | |
| # 1. Helpers | |
| # --------------------------------------------------------------------------- | |
def ensure_reap():
    """Ensure the REAP repo is cloned into REAP_DIR and installed.

    The clone goes into a temporary sibling directory that is renamed into
    place only after ``git clone`` succeeds. An interrupted clone therefore
    never leaves a half-populated REAP_DIR that later calls would mistake
    for a complete checkout.

    A successful build/install is recorded with a marker file inside
    REAP_DIR. A failed install is retried on the next call; a successful
    one is not repeated.

    Raises:
        RuntimeError: if REAP must be cloned but ``git`` is not on PATH.
        subprocess.CalledProcessError: if cloning, building or installing fails.
    """
    import shutil

    marker = REAP_DIR / ".reap_gui_installed"
    if marker.exists():
        return

    if not REAP_DIR.exists():
        if shutil.which("git") is None:
            raise RuntimeError("git is required to clone REAP but was not found on PATH")
        REAP_DIR.parent.mkdir(parents=True, exist_ok=True)
        staging = REAP_DIR.with_name(REAP_DIR.name + ".partial")
        if staging.exists():
            # Leftover from an earlier interrupted clone.
            shutil.rmtree(staging)
        subprocess.run(
            ["git", "clone", "--depth", "1",
             "https://github.com/CerebrasResearch/reap", str(staging)],
            check=True)
        staging.rename(REAP_DIR)

    # Build using their script when it exists, otherwise fall back to pip.
    build_script = REAP_DIR / "scripts" / "build.sh"
    if build_script.exists():
        subprocess.run(["bash", str(build_script)], cwd=REAP_DIR, check=True)
    else:
        subprocess.run(
            [sys.executable, "-m", "pip", "install", "-e", str(REAP_DIR)],
            check=True)

    # Written only after the build/install succeeded (check=True raises otherwise).
    marker.touch()
| # --------------------------------------------------------------------------- | |
| # 2. Back-end | |
| # --------------------------------------------------------------------------- | |
class ReapDriver:
    """Run REAP's ``experiments/pruning-cli.sh`` and stream its output.

    Every output line is classified (error / warn / progress / info) and
    pushed onto ``log_q`` as a JSON string ``{"level": ..., "text": ...}``.
    Estimated overall progress, in percent, is reported through the callback
    passed to :meth:`run`. It never decreases during a run.
    """

    # Overall-progress window (start %, end %) for each pipeline stage that
    # _detect_stage can recognise from keywords in the script's output.
    STAGE_PROGRESS = {
        "loading": (0, 10),
        "calibrating": (10, 40),
        "scoring": (40, 60),
        "pruning": (60, 80),
        "evaluating": (80, 95),
        "saving": (95, 100),
    }
    # First "<digits>%" in a line, e.g. the percentage of a tqdm bar.
    _PCT_RE = re.compile(r"(\d+)%")
    _ERROR_MARKERS = ("error", "failed", "exception", "traceback")
    _SUCCESS_MARKERS = ("success", "complete")
    _PROGRESS_MARKERS = ("progress", "%", "step", "iter", "epoch")

    def __init__(self, model_id, pruning_method, kill_ratio, out_dir,
                 dataset_name="theblackcat102/evol-codealpaca-v1",
                 cuda_devices="0", seed=42,
                 run_lm_eval=False, run_evalplus=False,
                 run_live_code=False, run_math=False, run_wildbench=False,
                 cancel_event=None, log_q=None):
        """Store the run configuration.

        Args:
            model_id: Hugging Face model id passed to the script.
            pruning_method: REAP pruning method name (e.g. "reap").
            kill_ratio: fraction of experts to prune, strictly between 0 and 1.
            out_dir: directory searched for the newest result after the run.
            dataset_name: calibration dataset passed to the script.
            cuda_devices: value for CUDA_VISIBLE_DEVICES and the script's
                first argument.
            seed: random seed passed to the script.
            run_lm_eval, run_evalplus, run_live_code, run_math, run_wildbench:
                evaluation toggles, passed to the script as "true"/"false".
            cancel_event: threading.Event; setting it cancels :meth:`run`.
            log_q: queue receiving JSON log records.

        Raises:
            ValueError: if ``kill_ratio`` or ``seed`` is not numeric, or
                ``kill_ratio`` is outside (0, 1).
        """
        self.model_id = model_id
        self.pruning_method = pruning_method
        self.kill_ratio = float(kill_ratio)
        if not 0.0 < self.kill_ratio < 1.0:
            raise ValueError(
                f"kill_ratio must be between 0 and 1 (exclusive), got {self.kill_ratio}")
        self.out_dir = Path(out_dir)
        self.dataset_name = dataset_name
        self.cuda_devices = str(cuda_devices)
        self.seed = int(seed)
        self.run_lm_eval = run_lm_eval
        self.run_evalplus = run_evalplus
        self.run_live_code = run_live_code
        self.run_math = run_math
        self.run_wildbench = run_wildbench
        self.cancel_event = cancel_event or threading.Event()
        self.log_q = log_q or queue.Queue()

    def run(self, progress_cb):
        """Run the REAP pruning pipeline via bash script and block until done.

        Args:
            progress_cb: called on this thread with a float in [0, 100]
                whenever the estimated overall progress increases.

        Returns:
            dict with ``status``, ``returncode`` and ``output_dir``. The
            ``output_dir`` value is the newest directory in ``out_dir`` as a
            string, or None.

        Raises:
            FileNotFoundError: if ``experiments/pruning-cli.sh`` is missing.
            RuntimeError: if the run is cancelled or the script exits non-zero.
        """
        ensure_reap()
        pruning_script = REAP_DIR / "experiments" / "pruning-cli.sh"
        if not pruning_script.exists():
            raise FileNotFoundError(f"Cannot find {pruning_script}")
        cmd = self._build_command(pruning_script)

        self._log(f"▶ Starting REAP pruning with method: {self.pruning_method}", level="info")
        self._log(f" Model: {self.model_id}", level="info")
        self._log(f" Compression: {self.kill_ratio*100:.0f}%", level="info")
        self._log(f" Dataset: {self.dataset_name}", level="info")

        env = {
            **os.environ,
            "CUDA_VISIBLE_DEVICES": self.cuda_devices,
            # Python children (presumably REAP's steps) block-buffer stdout
            # when it is a pipe; without this the log arrives in bursts.
            "PYTHONUNBUFFERED": "1",
        }
        # On POSIX give the script its own process group, so cancelling can
        # signal everything it spawned rather than only bash.
        session_kwargs = {"start_new_session": True} if os.name == "posix" else {}
        proc = subprocess.Popen(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
            # Progress bars emit non-ASCII glyphs; never die on decoding.
            encoding="utf-8",
            errors="replace",
            bufsize=1,
            cwd=REAP_DIR,
            env=env,
            **session_kwargs,
        )
        # The read loop below only wakes on new output. The watcher makes
        # cancel effective even while the script is silent.
        threading.Thread(target=self._watch_cancel, args=(proc,), daemon=True).start()
        try:
            current_stage = "loading"
            shown_pct = 0.0
            for raw_line in proc.stdout:
                if self.cancel_event.is_set():
                    break
                line = raw_line.rstrip("\n")
                current_stage = self._detect_stage(line, current_stage)
                overall_pct, stage_pct = self._progress_for_line(line, current_stage)
                # Keyword-based stage detection can jump backwards; keep the bar monotonic.
                if overall_pct > shown_pct:
                    shown_pct = overall_pct
                    progress_cb(overall_pct)
                if stage_pct is not None:
                    self._log(f"[{current_stage}] {stage_pct:.0f}% -> overall {overall_pct:.1f}%",
                              level="progress")
                self._classify_and_log(line)

            if self.cancel_event.is_set():
                self._terminate(proc)
                self._log("Operation cancelled by user.", level="warn")
                raise RuntimeError("Operation cancelled.")

            proc.wait()
            if proc.returncode != 0:
                self._log(f"REAP pruning failed with exit code {proc.returncode}", level="error")
                raise RuntimeError(f"Pruning failed (exit {proc.returncode})")

            self._log("✓ Pruning complete!", level="info")
            output_dir = self._find_output_model()
            return {
                "status": "success",
                "returncode": proc.returncode,
                "output_dir": str(output_dir) if output_dir is not None else None,
            }
        finally:
            # Never leave the script running when we exit early (e.g. progress_cb raised).
            self._terminate(proc)
            try:
                proc.stdout.close()
            except OSError:
                pass

    def _build_command(self, pruning_script):
        """Return the argv for pruning-cli.sh; the script's arguments are positional."""
        eval_flags = (self.run_lm_eval, self.run_evalplus, self.run_live_code,
                      self.run_math, self.run_wildbench)
        return [
            "bash",
            str(pruning_script),
            self.cuda_devices,
            self.model_id,
            self.pruning_method,
            str(self.seed),
            str(self.kill_ratio),
            self.dataset_name,
            *("true" if flag else "false" for flag in eval_flags),
        ]

    @staticmethod
    def _detect_stage(line, current_stage):
        """Return the pipeline stage suggested by keywords in *line*, else *current_stage*."""
        lower = line.lower()
        if "loading model" in lower or "downloading" in lower:
            return "loading"
        if "calibrat" in lower or "collecting activation" in lower:
            return "calibrating"
        if "comput" in lower and "score" in lower:
            return "scoring"
        if "prun" in lower or "remov" in lower:
            return "pruning"
        if "evaluat" in lower or "benchmark" in lower:
            return "evaluating"
        if "saving" in lower or "writing" in lower:
            return "saving"
        return current_stage

    def _progress_for_line(self, line, stage):
        """Map *line* to ``(overall_pct, stage_pct)`` within *stage*'s window.

        ``stage_pct`` is the first "N%" found in the line (clamped to 100),
        or None when the line has none. In that case ``overall_pct`` is the
        start of the stage's window.
        """
        start_pct, end_pct = self.STAGE_PROGRESS.get(stage, (50, 60))
        match = self._PCT_RE.search(line)
        if match is None:
            return float(start_pct), None
        stage_pct = min(float(match.group(1)), 100.0)
        return start_pct + (end_pct - start_pct) * (stage_pct / 100.0), stage_pct

    def _watch_cancel(self, proc, poll_interval=0.5):
        """Terminate *proc* as soon as ``cancel_event`` is set; return when *proc* exits."""
        while proc.poll() is None:
            if self.cancel_event.wait(poll_interval):
                self._terminate(proc)
                return

    @staticmethod
    def _terminate(proc, timeout=5):
        """Stop *proc*, escalating to a hard kill after *timeout* seconds.

        On POSIX the whole process group is signalled. This assumes *proc*
        was started with ``start_new_session=True``, as :meth:`run` does.
        Does nothing if *proc* has already exited.
        """
        import signal

        if proc.poll() is not None:
            return
        posix = os.name == "posix"

        def _send(hard):
            try:
                if posix:
                    os.killpg(proc.pid, signal.SIGKILL if hard else signal.SIGTERM)
                elif hard:
                    proc.kill()
                else:
                    proc.terminate()
            except ProcessLookupError:
                # Exited between poll() and the signal.
                pass

        _send(hard=False)
        try:
            proc.wait(timeout=timeout)
        except subprocess.TimeoutExpired:
            _send(hard=True)
            proc.wait()

    def _find_output_model(self):
        """Return the most recently modified directory in ``out_dir``, or None.

        Logs the result, or a hint to check ``out_dir`` when nothing is found.
        """
        results_dir = self.out_dir
        if results_dir.is_dir():
            # Filter explicitly: glob("*/") also yields files on older Pythons.
            subdirs = [p for p in results_dir.iterdir() if p.is_dir()]
            if subdirs:
                latest = max(subdirs, key=lambda p: p.stat().st_mtime)
                self._log(f"✓ Pruned model saved to: {latest}", level="info")
                return latest
        self._log(f"Note: Check {results_dir} for output model", level="warn")
        return None

    def _classify_and_log(self, line: str):
        """Log *line* at a level guessed from its wording (first match wins)."""
        lower = line.lower()
        if any(marker in lower for marker in self._ERROR_MARKERS):
            level = "error"
        elif "warn" in lower:
            level = "warn"
        elif "✓" in line or any(marker in lower for marker in self._SUCCESS_MARKERS):
            level = "info"
        elif any(marker in lower for marker in self._PROGRESS_MARKERS):
            level = "progress"
        else:
            level = "info"
        self._log(line, level=level)

    def _log(self, text, level="info"):
        """Push a JSON ``{"level", "text"}`` record onto the GUI log queue (best effort)."""
        try:
            self.log_q.put_nowait(json.dumps({"level": level, "text": text}))
        except (TypeError, ValueError, queue.Full):
            # Logging must never abort the pruning run; drop unloggable records.
            pass
| # --------------------------------------------------------------------------- | |
| # 3. Tkinter GUI | |
| # --------------------------------------------------------------------------- | |
class App:
    """Tkinter front-end that collects REAP settings and runs ReapDriver.

    Threading model: widgets and Tk variables are only touched on the Tk
    (main) thread. The background worker talks back solely through two
    queues drained by ``_poll_logs``:
      * ``log_q``: JSON log records (the format ReapDriver emits);
      * ``_ui_q``: zero-argument callables such as progress updates and dialogs.
    """

    # Cap on log records appended per poll tick, so a burst of output
    # (e.g. progress bars) cannot freeze the window.
    _MAX_LOG_LINES_PER_POLL = 500

    def __init__(self, root):
        self.root = root
        root.title("REAP – Expert Pruning for MoE Models")
        # Model configuration
        self.model = tk.StringVar(value="Qwen/Qwen3-30B-A3B")
        self.method = tk.StringVar(value="reap")
        self.ratio = tk.DoubleVar(value=0.25)
        self.dataset = tk.StringVar(value="theblackcat102/evol-codealpaca-v1")
        self.cuda_dev = tk.StringVar(value="0")
        self.seed = tk.IntVar(value=42)
        # Evaluation flags
        self.run_lm_eval = tk.BooleanVar(value=False)
        self.run_evalplus = tk.BooleanVar(value=False)
        self.run_live_code = tk.BooleanVar(value=False)
        self.run_math = tk.BooleanVar(value=False)
        self.run_wildbench = tk.BooleanVar(value=False)
        self.cancel_event = threading.Event()
        self.log_q = queue.Queue()
        # Callables posted by the worker thread and executed on the Tk thread.
        self._ui_q = queue.Queue()
        self._worker_thread = None
        self.build_ui()
        # The window's close button goes through the same path as "Exit".
        self.root.protocol("WM_DELETE_WINDOW", self.on_exit)
        self.root.after(200, self._poll_logs)

    def build_ui(self):
        """Create and lay out all widgets."""
        pad = dict(padx=5, pady=5)

        # Main configuration section
        config_frame = ttk.LabelFrame(self.root, text="Configuration", padding=10)
        config_frame.grid(row=0, column=0, sticky="ew", **pad)

        ttk.Label(config_frame, text="Model (HF id):").grid(row=0, column=0, sticky="w", **pad)
        ttk.Entry(config_frame, textvariable=self.model, width=50).grid(
            row=0, column=1, columnspan=2, sticky="ew", **pad)

        ttk.Label(config_frame, text="Pruning Method:").grid(row=1, column=0, sticky="w", **pad)
        method_combo = ttk.Combobox(config_frame, textvariable=self.method,
                                    values=["reap", "frequency", "ean"], state="readonly", width=20)
        method_combo.grid(row=1, column=1, sticky="w", **pad)

        ttk.Label(config_frame, text="Compression Ratio:").grid(row=2, column=0, sticky="w", **pad)
        ratio_frame = ttk.Frame(config_frame)
        ratio_frame.grid(row=2, column=1, columnspan=2, sticky="ew", **pad)
        ttk.Scale(ratio_frame, from_=0.05, to=0.95, variable=self.ratio,
                  orient="horizontal").pack(side="left", fill="x", expand=True)
        # Initialised from the variable so the label always matches its default.
        ratio_label = ttk.Label(ratio_frame, text=f"{self.ratio.get()*100:.0f}%")
        ratio_label.pack(side="right", padx=5)
        self.ratio.trace_add("write", lambda *_: ratio_label.config(
            text=f"{self.ratio.get()*100:.0f}%"))

        ttk.Label(config_frame, text="Dataset:").grid(row=3, column=0, sticky="w", **pad)
        dataset_combo = ttk.Combobox(config_frame, textvariable=self.dataset, width=50,
                                     values=[
                                         "theblackcat102/evol-codealpaca-v1",
                                         "allenai/c4",
                                         "allenai/tulu-3-sft-personas-math",
                                         "euclaise/WritingPrompts_curated"
                                     ])
        dataset_combo.grid(row=3, column=1, columnspan=2, sticky="ew", **pad)

        ttk.Label(config_frame, text="CUDA Device(s):").grid(row=4, column=0, sticky="w", **pad)
        ttk.Entry(config_frame, textvariable=self.cuda_dev, width=10).grid(row=4, column=1, sticky="w", **pad)
        # Label and entry share one grid cell through a sub-frame. Gridding
        # both straight into column 2 made them overlap.
        seed_frame = ttk.Frame(config_frame)
        seed_frame.grid(row=4, column=2, sticky="e", **pad)
        ttk.Label(seed_frame, text="Random Seed:").pack(side="left")
        ttk.Entry(seed_frame, textvariable=self.seed, width=10).pack(side="left", padx=(5, 0))
        config_frame.grid_columnconfigure(1, weight=1)

        # Evaluation options
        eval_frame = ttk.LabelFrame(self.root, text="Evaluation Options (optional)", padding=10)
        eval_frame.grid(row=1, column=0, sticky="ew", **pad)
        ttk.Checkbutton(eval_frame, text="LM-Eval (MMLU, etc.)",
                        variable=self.run_lm_eval).grid(row=0, column=0, sticky="w", **pad)
        ttk.Checkbutton(eval_frame, text="EvalPlus (HumanEval+)",
                        variable=self.run_evalplus).grid(row=0, column=1, sticky="w", **pad)
        ttk.Checkbutton(eval_frame, text="LiveCodeBench",
                        variable=self.run_live_code).grid(row=0, column=2, sticky="w", **pad)
        ttk.Checkbutton(eval_frame, text="Math (GSM8K)",
                        variable=self.run_math).grid(row=1, column=0, sticky="w", **pad)
        ttk.Checkbutton(eval_frame, text="WildBench",
                        variable=self.run_wildbench).grid(row=1, column=1, sticky="w", **pad)

        # Progress section
        progress_frame = ttk.Frame(self.root)
        progress_frame.grid(row=2, column=0, sticky="ew", **pad)
        self.bar = ttk.Progressbar(progress_frame, orient="horizontal",
                                   mode="determinate", length=400)
        self.bar.pack(fill="x", expand=True, pady=5)

        # Control buttons
        btn_frame = ttk.Frame(self.root)
        btn_frame.grid(row=3, column=0, pady=5)
        self.run_btn = ttk.Button(btn_frame, text="▶ Run Pruning", command=self.start)
        self.run_btn.grid(row=0, column=0, padx=5)
        self.cancel_btn = ttk.Button(btn_frame, text="✕ Cancel",
                                     command=self.cancel, state="disabled")
        self.cancel_btn.grid(row=0, column=1, padx=5)
        ttk.Button(btn_frame, text="Exit", command=self.on_exit).grid(row=0, column=2, padx=5)

        # Logging section
        log_frame = ttk.LabelFrame(self.root, text="Logs", padding=5)
        log_frame.grid(row=4, column=0, sticky="nsew", **pad)
        self.log_text = tk.Text(log_frame, height=20, width=100, wrap="word",
                                state="disabled", bg="#0b0b0b", fg="#d0ffd0")
        scrollbar = ttk.Scrollbar(log_frame, command=self.log_text.yview)
        self.log_text.configure(yscrollcommand=scrollbar.set)
        self.log_text.pack(side="left", fill="both", expand=True)
        scrollbar.pack(side="right", fill="y")
        # Tag configurations for color-coded levels
        self.log_text.tag_configure("info", foreground="#a6ffa6")
        self.log_text.tag_configure("progress", foreground="#6fe3ff")
        self.log_text.tag_configure("warn", foreground="#ffb86b")
        self.log_text.tag_configure("error", foreground="#ff6b6b")

        # Log control buttons
        log_btn_frame = ttk.Frame(self.root)
        log_btn_frame.grid(row=5, column=0, sticky="e", padx=5, pady=(0, 5))
        ttk.Button(log_btn_frame, text="Clear Logs",
                   command=self.clear_logs).grid(row=0, column=0, padx=2)
        ttk.Button(log_btn_frame, text="Save Logs...",
                   command=self.save_logs).grid(row=0, column=1, padx=2)

        # Configure grid weights for resizing
        self.root.grid_rowconfigure(4, weight=1)
        self.root.grid_columnconfigure(0, weight=1)

    def _collect_params(self):
        """Snapshot the form into ReapDriver keyword arguments (Tk thread only).

        Raises:
            tk.TclError: if the seed field does not contain an integer.
        """
        return dict(
            model_id=self.model.get().strip(),
            pruning_method=self.method.get(),
            # Round to the precision the percentage label shows, so REAP
            # receives exactly the ratio the user saw.
            kill_ratio=round(self.ratio.get(), 2),
            dataset_name=self.dataset.get().strip(),
            cuda_devices=self.cuda_dev.get().strip(),
            seed=self.seed.get(),
            run_lm_eval=self.run_lm_eval.get(),
            run_evalplus=self.run_evalplus.get(),
            run_live_code=self.run_live_code.get(),
            run_math=self.run_math.get(),
            run_wildbench=self.run_wildbench.get(),
        )

    def start(self):
        """Validate the form on the Tk thread and start the pruning worker."""
        if not self.model.get().strip():
            messagebox.showerror("Error", "Please enter a model name")
            return
        try:
            params = self._collect_params()
        except tk.TclError:
            messagebox.showerror("Error", "Random seed must be an integer")
            return
        self.run_btn.config(state="disabled")
        self.cancel_btn.config(state="normal")
        self.bar["value"] = 0
        self.clear_logs()
        self.cancel_event.clear()
        self._worker_thread = threading.Thread(target=self.worker, args=(params,), daemon=True)
        self._worker_thread.start()

    def cancel(self):
        """Request cancellation of the current operation."""
        self.cancel_event.set()
        self.cancel_btn.config(state="disabled")
        self._append_log("⚠️ Cancel requested by user\n", level="warn")

    def _post_progress(self, pct):
        """Ask the Tk thread to show *pct* on the progress bar (safe from any thread)."""
        self._ui_q.put(lambda: self.bar.config(value=pct))

    def worker(self, params=None):
        """Background thread body: run ReapDriver and report back via queues.

        Args:
            params: ReapDriver keyword arguments captured by :meth:`start`.
                When None (legacy callers), the form is read here, which is
                not thread-safe.
        """
        try:
            if params is None:
                params = self._collect_params()
            driver = ReapDriver(
                out_dir=REAP_DIR / "experiments" / "results",
                cancel_event=self.cancel_event,
                log_q=self.log_q,
                **params,
            )
            result = driver.run(self._post_progress)
            if not self.cancel_event.is_set():
                output_dir = result.get("output_dir") if isinstance(result, dict) else None
                message = "Pruning complete! Check experiments/results/ for output model."
                if output_dir:
                    message = f"Pruning complete! Output model: {output_dir}"
                self._ui_q.put(lambda: self.bar.config(value=100))
                self._ui_q.put(lambda: messagebox.showinfo("Success", message))
        except Exception as exc:
            # A cancelled run surfaces as an exception; that is not an error
            # to report. This checks the event, not the exception text.
            if not self.cancel_event.is_set():
                msg = str(exc) or type(exc).__name__
                self.log_q.put(json.dumps({"level": "error", "text": f"ERROR: {msg}"}))
                self._ui_q.put(lambda: messagebox.showerror("Error", msg))
        finally:
            self._ui_q.put(self._reset_ui)

    def _poll_logs(self):
        """On the Tk thread: apply queued log records and worker UI requests, then reschedule."""
        try:
            for _ in range(self._MAX_LOG_LINES_PER_POLL):
                try:
                    payload = self.log_q.get_nowait()
                except queue.Empty:
                    break
                try:
                    item = json.loads(payload)
                    level = item.get("level", "info")
                    text = str(item.get("text", ""))
                except (TypeError, ValueError, AttributeError):
                    # Not a JSON log record; show it verbatim.
                    level = "info"
                    text = str(payload)
                self._append_log(text + "\n", level=level)
            while True:
                try:
                    action = self._ui_q.get_nowait()
                except queue.Empty:
                    break
                action()
        finally:
            # Keep polling even if one record or action failed.
            self.root.after(150, self._poll_logs)

    def _append_log(self, text, level="info"):
        """Append colored text to the log widget (Tk thread only)."""
        self.log_text.configure(state="normal")
        tag = level if level in ("info", "progress", "warn", "error") else "info"
        self.log_text.insert("end", text, (tag,))
        self.log_text.see("end")
        self.log_text.configure(state="disabled")

    def clear_logs(self):
        """Clear all log text."""
        self.log_text.configure(state="normal")
        self.log_text.delete("1.0", "end")
        self.log_text.configure(state="disabled")

    def save_logs(self):
        """Save logs to a file chosen by the user."""
        f = filedialog.asksaveasfilename(
            defaultextension=".log",
            filetypes=[("Log files", "*.log"), ("All files", "*.*")])
        if not f:
            return
        try:
            with open(f, "w", encoding="utf-8") as fh:
                # "end-1c" skips the newline Tk always keeps after the last line.
                fh.write(self.log_text.get("1.0", "end-1c"))
            messagebox.showinfo("Saved", f"Logs saved to {f}")
        except Exception as e:
            messagebox.showerror("Error", f"Failed to save logs: {e}")

    def _reset_ui(self):
        """Reset UI state after operation completes."""
        self.run_btn.config(state="normal")
        self.cancel_btn.config(state="disabled")

    def on_exit(self):
        """Quit the GUI, first cancelling a running job so its processes are not orphaned."""
        worker = self._worker_thread
        if worker is None or not worker.is_alive():
            self.root.quit()
            return
        if not messagebox.askyesno("Exit", "A pruning job is still running. Cancel it and exit?"):
            return
        self.cancel()
        self._quit_when_idle()

    def _quit_when_idle(self):
        """Quit once the worker thread has finished; re-check every 200 ms."""
        if self._worker_thread is not None and self._worker_thread.is_alive():
            self.root.after(200, self._quit_when_idle)
        else:
            self.root.quit()
| # --------------------------------------------------------------------------- | |
| # 4. Entry point | |
| # --------------------------------------------------------------------------- | |
if __name__ == "__main__":
    # Script entry: build the main window, attach the GUI, run the event loop.
    main_window = tk.Tk()
    App(main_window)  # the App keeps itself alive through callbacks registered on the window
    main_window.mainloop()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment