Created
May 24, 2025 10:48
-
-
Save ehartford/5f39a278a4bd76a4d212a007c02bdbe7 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# make_multi_metric_head.py
# ------------------------------------------------------------
# Replace WorldPM-72B's 1-unit reward head with a multi-unit
# head (one regression output per metric below) and save the
# result so you can fine-tune from it later.
# ------------------------------------------------------------
import torch
from transformers import AutoConfig, AutoModelForSequenceClassification

# Metrics you want separate scores for (16 entries -> 16 regression outputs).
METRICS = [
    "structural_coherence", "pacing_rhythm", "focus_purpose", "depth_authenticity",
    "development_evolution", "language_precision", "voice_tone", "figurative_language",
    "dialogue_interaction", "content_depth", "argument_integration", "emotional_resonance",
    "sustained_interest", "originality_creativity", "mechanics", "formatting_conventions",
]

BASE_MODEL = "Qwen/WorldPM-72B"        # original 1-scalar checkpoint
OUTPUT_DIR = "WorldPM-72B-multihead"   # where the new base will live

# 1. Load config and tell HF we now want len(METRICS) regression labels.
cfg = AutoConfig.from_pretrained(
    BASE_MODEL,
    num_labels=len(METRICS),
    problem_type="regression",  # makes Trainer use MSELoss
    trust_remote_code=True,
)

# 2. Load model; `ignore_mismatched_sizes=True` drops the old 1-unit head
#    and re-initializes a fresh head sized for len(METRICS) outputs.
model = AutoModelForSequenceClassification.from_pretrained(
    BASE_MODEL,
    config=cfg,
    trust_remote_code=True,
    ignore_mismatched_sizes=True,
)

# (Optional) zero-init the new head so early training is stable.
# NOTE(review): Qwen-style classification heads are typically built as
# nn.Linear(..., bias=False), in which case `model.score.bias` is None and
# calling zeros_ on it raises TypeError — guard the bias initialization.
torch.nn.init.zeros_(model.score.weight)
if getattr(model.score, "bias", None) is not None:
    torch.nn.init.zeros_(model.score.bias)

# 3. Save the modified checkpoint (model weights + updated config).
model.save_pretrained(OUTPUT_DIR)
cfg.save_pretrained(OUTPUT_DIR)
print(f"✅ Multi-metric base model saved to “{OUTPUT_DIR}”")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment