#!/usr/bin/env python3
"""
run_hpo.py — Hyperparameter optimization for SmolLM3-3B fine-tuning.

Each Optuna trial runs in its own process to avoid MPS memory accumulation.
"""
|
|
|
import argparse
import gc
import os
import subprocess
import sys
import time

import optuna
import torch
from datasets import load_dataset
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from trl import SFTTrainer
|
|
|
|
|
# ===================================================== |
|
# ARGUMENTS |
|
# ===================================================== |
|
def parse_args():
    """Parse the command-line options shared by manager and child runs.

    ``--batch-size`` and ``--max-length`` are mandatory. ``--max-steps`` and
    ``--epochs`` are optional and default to ``None``. ``--run-one`` is the
    internal switch that selects single-trial (child) mode.

    Returns:
        argparse.Namespace: The parsed options, read from ``sys.argv``.
    """
    # (flag, add_argument keyword arguments), in the order shown by --help.
    option_table = (
        ("--batch-size", {"type": int, "required": True}),
        ("--max-length", {"type": int, "required": True}),
        ("--max-steps", {"type": int, "required": False}),
        ("--epochs", {"type": float, "required": False}),
        ("--model-name", {"type": str, "default": "HuggingFaceTB/SmolLM3-3B-Base"}),
        ("--dataset-name", {"type": str, "default": "HuggingFaceTB/smoltalk2"}),
        ("--dataset-split", {"type": str, "default": "OpenHermes_2.5_no_think"}),
        ("--trials", {"type": int, "default": 10}),
        ("--study-db", {"type": str, "default": "sqlite:///hpo_study.db"}),
        ("--output-dir", {"type": str, "default": "./hpo_outputs"}),
        (
            "--run-one",
            {"action": "store_true", "help": "internal flag: run a single trial"},
        ),
    )

    parser = argparse.ArgumentParser()
    for flag, spec in option_table:
        parser.add_argument(flag, **spec)
    return parser.parse_args()
|
|
|
|
|
# ===================================================== |
|
# DATA + TOKENIZER |
|
# ===================================================== |
|
def prepare_data_and_tokenizer(
    model_name,
    dataset_name,
    dataset_split,
    dataset_config="SFT",
    train_size=1000,
    eval_size=200,
    seed=42,
):
    """Load the tokenizer and a small train/eval subset used for HPO trials.

    Args:
        model_name: Identifier passed to ``AutoTokenizer.from_pretrained``.
        dataset_name: Identifier passed to ``load_dataset``.
        dataset_split: Key of the split to draw examples from.
        dataset_config: Dataset configuration name passed to ``load_dataset``.
        train_size: Maximum number of training examples to keep.
        eval_size: Maximum number of evaluation examples to keep.
        seed: Seed for the 80/20 train/test split.

    Returns:
        tuple: ``(tokenizer, train_dataset, eval_dataset)``.

    Raises:
        KeyError: If ``dataset_split`` is not present in the loaded dataset.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # The EOS token doubles as the padding token.
    tokenizer.pad_token = tokenizer.eos_token
    if tokenizer.chat_template is None:
        # Fallback template: renders only 'user' and 'assistant' messages.
        tokenizer.chat_template = (
            "{% for message in messages %}\n"
            "{% if message['role'] == 'user' %}\n"
            "{{ '<|im_start|>user\n' + message['content'] + '<|im_end|>' }}\n"
            "{% elif message['role'] == 'assistant' %}\n"
            "{{ '<|im_start|>assistant\n' + message['content'] + '<|im_end|>' }}\n"
            "{% endif %}\n"
            "{% endfor %}"
        )

    dataset_dict = load_dataset(dataset_name, dataset_config)
    if dataset_split not in dataset_dict:
        raise KeyError(
            f"Split {dataset_split!r} not found in {dataset_name!r}; "
            f"available splits: {sorted(dataset_dict)}"
        )

    split = dataset_dict[dataset_split].train_test_split(test_size=0.2, seed=seed)
    train_pool = split["train"]
    eval_pool = split["test"]

    # Clamp the subset sizes so splits smaller than the requested size do not
    # make ``select`` fail with an out-of-range index.
    train_dataset = train_pool.select(range(min(train_size, len(train_pool))))
    eval_dataset = eval_pool.select(range(min(eval_size, len(eval_pool))))

    return tokenizer, train_dataset, eval_dataset
|
|
|
|
|
# ===================================================== |
|
# OBJECTIVE FUNCTION |
|
# ===================================================== |
|
def _release_accelerator_memory() -> None:
    """Run the garbage collector and, when MPS is available, empty its cache."""
    gc.collect()
    if not torch.backends.mps.is_available():
        return
    try:
        torch.mps.empty_cache()
    except (AttributeError, RuntimeError) as exc:
        # Best-effort only: failing to release the cache must not fail a trial.
        print(f"Warning: could not empty the MPS cache: {exc}")


def objective_factory(args):
    """Build the Optuna objective bound to the parsed command-line options.

    Args:
        args: Parsed options. Reads ``model_name``, ``dataset_name``,
            ``dataset_split``, ``epochs``, ``max_steps``, ``batch_size`` and
            ``output_dir``.

    Returns:
        A callable ``objective(trial)`` that fine-tunes a LoRA adapter with
        the sampled hyperparameters and returns the final ``eval_loss``.
    """

    def objective(trial: "optuna.trial.Trial") -> float:
        tokenizer, train_dataset, eval_dataset = prepare_data_and_tokenizer(
            args.model_name, args.dataset_name, args.dataset_split
        )

        # Search space. The epoch count is only sampled when --epochs is unset.
        learning_rate = trial.suggest_float("learning_rate", 1e-5, 1e-4, log=True)
        num_train_epochs = (
            args.epochs if args.epochs else trial.suggest_int("num_train_epochs", 1, 3)
        )
        lora_alpha = trial.suggest_int("lora_alpha", 8, 32)

        # -1 leaves the step limit unset. NOTE(review): per the Trainer docs a
        # positive max_steps takes precedence over num_train_epochs.
        max_steps = args.max_steps if args.max_steps else -1

        _release_accelerator_memory()

        print(
            f"\n=== Trial {trial.number}: lr={learning_rate:.2e}, epochs={num_train_epochs}, lora_alpha={lora_alpha} ==="
        )

        model = None
        trainer = None
        try:
            model = AutoModelForCausalLM.from_pretrained(
                args.model_name,
                dtype=torch.bfloat16,
                device_map="auto",
                trust_remote_code=True,
            )

            peft_config = LoraConfig(
                r=8,
                lora_alpha=lora_alpha,
                lora_dropout=0.05,
                bias="none",
                task_type="CAUSAL_LM",
                target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
            )

            # NOTE(review): args.max_length is accepted on the command line but
            # is not forwarded to the trainer here; confirm whether it should
            # cap the training sequence length.
            training_args = TrainingArguments(
                output_dir=os.path.join(args.output_dir, f"hpo_trial_{trial.number}"),
                per_device_train_batch_size=args.batch_size,
                per_device_eval_batch_size=args.batch_size,
                num_train_epochs=num_train_epochs,
                max_steps=max_steps,
                learning_rate=learning_rate,
                # No evaluation during training: the early-step loss landscape
                # sufficiently reflects which parameters are better. A single
                # evaluate() call after training produces the trial score.
                eval_strategy="no",
                save_strategy="no",
                logging_steps=50,
                report_to="none",
                bf16=True,
                gradient_accumulation_steps=1,
                dataloader_num_workers=0,
            )

            trainer = SFTTrainer(
                model=model,
                args=training_args,
                train_dataset=train_dataset,
                eval_dataset=eval_dataset,
                peft_config=peft_config,
                processing_class=tokenizer,
            )

            trainer.train()
            return trainer.evaluate()["eval_loss"]
        finally:
            # Drop the local references and flush caches even when the trial
            # raises, so a failed trial does not keep the model alive here.
            del model, trainer
            _release_accelerator_memory()
            time.sleep(2)

    return objective
|
|
|
|
|
# ===================================================== |
|
# MODES |
|
# ===================================================== |
|
def run_single_trial_mode(args):
    """Child mode: execute exactly one trial against the existing study.

    The study named ``hpo_study`` is loaded from ``args.study_db``; a
    ``RuntimeError`` raised inside the trial is caught by Optuna instead of
    propagating out of ``optimize``.
    """
    trial_fn = objective_factory(args)
    shared_study = optuna.load_study(storage=args.study_db, study_name="hpo_study")
    shared_study.optimize(trial_fn, catch=(RuntimeError,), n_trials=1)
|
|
|
|
|
def _build_trial_command(args):
    """Return the argv list that re-invokes this script for one child trial."""
    command = [
        sys.executable,
        os.path.abspath(__file__),
        f"--batch-size={args.batch_size}",
        f"--max-length={args.max_length}",
        f"--model-name={args.model_name}",
        f"--dataset-name={args.dataset_name}",
        f"--dataset-split={args.dataset_split}",
        "--trials=1",
        f"--study-db={args.study_db}",
        f"--output-dir={args.output_dir}",
        "--run-one",
    ]
    # Forward the optional limits only when they were given; passing the
    # literal text "None" would be rejected by the child's argument parser.
    if args.epochs is not None:
        command.append(f"--epochs={args.epochs}")
    if args.max_steps is not None:
        command.append(f"--max-steps={args.max_steps}")
    return command


def run_manager_mode(args):
    """Manager mode: spawn one subprocess per trial, then report the best.

    Creates (or reuses) the study ``hpo_study`` in ``args.study_db``, launches
    ``args.trials`` child processes with ``--run-one``, and finally prints the
    best hyperparameters recorded in the study.

    Args:
        args: Parsed command-line options from ``parse_args``.
    """
    optuna.create_study(
        study_name="hpo_study",
        direction="minimize",
        storage=args.study_db,
        load_if_exists=True,
    )

    # An argv list (no shell) keeps paths and values with spaces or shell
    # metacharacters intact.
    command = _build_trial_command(args)

    for trial_index in range(args.trials):
        print(f"\n=== Starting HPO trial {trial_index} ===")
        completed = subprocess.run(command, check=False)
        if completed.returncode != 0:
            print(f"Trial {trial_index} failed (exit {completed.returncode}).")
        # Short pause between child processes.
        time.sleep(2)

    study = optuna.load_study(study_name="hpo_study", storage=args.study_db)
    print("\n--- Best Hyperparameters Found ---")
    try:
        print(study.best_params)
    except ValueError:
        # Optuna raises ValueError when the study has no completed trial.
        print("No trial completed successfully; no best parameters available.")
|
|
|
|
|
# ===================================================== |
|
# MAIN |
|
# ===================================================== |
|
def main():
    """Parse the command line and dispatch to child or manager mode.

    ``--run-one`` selects the single-trial child mode; otherwise the manager
    mode that spawns the trials is run.
    """
    args = parse_args()
    handler = run_single_trial_mode if args.run_one else run_manager_mode
    handler(args)
|
|
|
|
|
# Script entry point: nothing runs on import; main() is called only when the
# file is executed directly (which is also how child trials are launched).
if __name__ == "__main__":
    main()