@robbiemu
Created October 17, 2025 12:33
HPO and batch size (for exercise 3 of https://huggingface.co/learn/smol-course unit 1)
import os
print("Finding absolute max_length from formatted_dataset...")
# Calculate the token length for each sample in your prepared dataset
token_lengths = [len(tokenizer(x["text"]).input_ids) for x in formatted_dataset]
# Find the length of the single longest sample
max_length = max(token_lengths)
print(f"Absolute max_length found: {max_length}")
# --- Save the dataset for the external test scripts to use ---
final_data_path = os.path.abspath("./final_training_data")
formatted_dataset.save_to_disk(final_data_path)
print(f"✅ Saved final formatted dataset to '{final_data_path}'")
import subprocess
import sys
import pandas as pd
import os
# Setting this silences the TOKENIZERS_PARALLELISM warning that appears when processes are forked
os.environ["TOKENIZERS_PARALLELISM"] = "false"
def find_fastest_batch_size(model_name, exp_batch_sizes, max_length, patience=2):
    """
    Uses a subprocess to find the batch size with the highest throughput,
    with an exponential search and early stopping.
    """
    results = []
    python_executable = sys.executable
    best_throughput = 0.0
    patience_counter = 0

    for exp in range(exp_batch_sizes[0], exp_batch_sizes[1] + 1):
        bs = 2 ** exp
        print(f"\n--- 🧪 Testing batch_size = {bs} ---")
        # The subprocess call passes the batch size, model name, max length, and dataset path
        result = subprocess.run(
            [python_executable, "test_throughput_worker.py", str(bs), model_name, str(max_length), final_data_path],
            capture_output=True,
            text=True,
        )
        if result.returncode == 0:
            try:
                output_lines = result.stdout.strip().splitlines()
                throughput = float(output_lines[-1])
                print(f"✅ Success! Throughput: {throughput:.2f} samples/sec")
                results.append({"batch_size": bs, "throughput": throughput})

                if throughput > best_throughput:
                    best_throughput = throughput
                    patience_counter = 0
                else:
                    patience_counter += 1
                    print(f"📉 Throughput did not improve. Patience: {patience_counter}/{patience}")

                if patience_counter >= patience:
                    print("\nStopping early due to diminishing returns.")
                    break
            except (ValueError, IndexError):
                print("❌ Failed to parse throughput from script output.")
                print(f"Full output: {result.stdout}")
                break
        else:
            print(f"❌ Failed! Batch size {bs} is too large. Stopping test.")
            print(f"Error log: {result.stderr}")
            break

    if not results:
        print("No batch size succeeded.")
        return 0

    results_df = pd.DataFrame(results)
    optimal_row = results_df.loc[results_df['throughput'].idxmax()]
    fastest_batch_size = int(optimal_row['batch_size'])

    print("\n--- Throughput Test Results ---")
    print(results_df)
    print("\n==============================================")
    print(f"🏆 Fastest batch size found: {fastest_batch_size}")
    print("==============================================")

    return fastest_batch_size

# --- Run the throughput test ---
exp_batch_sizes = (0, 9) # Test powers of 2 from 1 to 512
optimal_batch_size = find_fastest_batch_size(model_name, exp_batch_sizes, max_length=max_length, patience=2)
MODEL_NAME = "HuggingFaceTB/SmolLM3-3B-Base"
DATASET_NAME = "HuggingFaceTB/smoltalk2"
DATASET_SPLIT = "OpenHermes_2.5_no_think"
TRIALS = 10
STUDY_DB = "sqlite:///hpo_study.db"
OUT_DIR = "./hpo_outputs"
script_path = "./run_hpo.py"
cmd = (
    f"python {script_path} "
    f"--batch-size {optimal_batch_size} "
    f"--max-length {max_length} "
    f"--max-steps 20 "
    f"--epochs 0.2 "
    f"--model-name \"{MODEL_NAME}\" "
    f"--dataset-name \"{DATASET_NAME}\" "
    f"--dataset-split \"{DATASET_SPLIT}\" "
    f"--trials {TRIALS} "
    f"--study-db \"{STUDY_DB}\" "
    f"--output-dir \"{OUT_DIR}\""
)
print("Running:", cmd)
!{cmd}

side quest: HPO and batch size

To follow along, you'll need to install optuna.

Before we do real HPO, let's first look for an efficient batch size on the current machine:

batch_size: we take the largest power of 2 (an arbitrary but convenient grid, for now) that still shows improving samples/second throughput.
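
Stripped of the subprocess plumbing used in the notebook cell above, the selection rule boils down to something like this sketch; pick_batch_size and measure_throughput are hypothetical names, with measure_throughput standing in for the worker script:

def pick_batch_size(measure_throughput, max_exp=9, patience=2):
    """Sweep powers of 2 and keep the batch size with the best samples/sec."""
    best_bs, best_tp, misses = 0, 0.0, 0
    for exp in range(max_exp + 1):
        bs = 2 ** exp                    # 1, 2, 4, ..., 512
        try:
            tp = measure_throughput(bs)  # samples/second at this batch size
        except RuntimeError:             # e.g. out of memory: stop growing
            break
        if tp > best_tp:
            best_bs, best_tp, misses = bs, tp, 0
        else:
            misses += 1                  # throughput did not improve
            if misses >= patience:       # give up after `patience` flat results
                break
    return best_bs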

We need a max_length for this because of how batches are handled during training: the training process automatically pads the data on the fly for every single batch.

The samples are all of different lengths, but padding happens in memory during the training loop.


The Role of the Data Collator

This automatic padding is handled by a component called a Data Collator. The SFTTrainer uses one by default. Here’s how it works for every single step of training:

  1. The DataLoader grabs a small group of samples from your dataset (e.g., a batch of 16). These samples all have different lengths.

  2. This group is passed to the Data Collator.

  3. The Data Collator finds the longest sample in that specific group.

  4. It adds padding tokens (using your tokenizer.pad_token) to all the shorter samples until they all match the length of that longest sample.

  5. Finally, it stacks them into a single, rectangular tensor that can be efficiently processed by the hardware.

This all happens "in-memory" for each batch and does not alter your original dataset. By setting max_length, we are just giving this automatic process a "ceiling" to ensure the memory usage is predictable.
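
As a minimal sketch of that behavior, assuming a standard Hugging Face collator (DataCollatorForLanguageModeling with mlm=False, which is not necessarily the exact collator class SFTTrainer wires up internally), you can watch each batch get padded only to its own longest member:

from transformers import AutoTokenizer, DataCollatorForLanguageModeling

tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM3-3B-Base")
tokenizer.pad_token = tokenizer.eos_token

texts = ["Hi!", "A somewhat longer example sentence.", "Medium length text."]

# Tokenize WITHOUT padding: every sample keeps its own length.
# max_length=512 is an illustrative ceiling, not the value computed earlier.
features = [tokenizer(t, truncation=True, max_length=512) for t in texts]
print([len(f["input_ids"]) for f in features])  # three different lengths

# The collator pads each *batch* to the longest sample in that batch only,
# producing one rectangular tensor per batch without touching the dataset.
collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
batch = collator(features)
print(batch["input_ids"].shape)  # (3, longest_length_in_this_batch)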

#!/usr/bin/env python3
"""
run_hpo.py — Hyperparameter optimization for SmolLM3-3B fine-tuning.
Each Optuna trial runs in its own process to avoid MPS memory accumulation.
"""
import argparse
import os
import time
import gc
import torch
import optuna
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from trl import SFTTrainer
from peft import LoraConfig
from datasets import load_dataset
# =====================================================
# ARGUMENTS
# =====================================================
def parse_args():
    p = argparse.ArgumentParser()
    p.add_argument("--batch-size", type=int, required=True)
    p.add_argument("--max-length", type=int, required=True)
    p.add_argument("--max-steps", type=int, required=False)
    p.add_argument("--epochs", type=float, required=False)
    p.add_argument("--model-name", type=str, default="HuggingFaceTB/SmolLM3-3B-Base")
    p.add_argument("--dataset-name", type=str, default="HuggingFaceTB/smoltalk2")
    p.add_argument("--dataset-split", type=str, default="OpenHermes_2.5_no_think")
    p.add_argument("--trials", type=int, default=10)
    p.add_argument("--study-db", type=str, default="sqlite:///hpo_study.db")
    p.add_argument("--output-dir", type=str, default="./hpo_outputs")
    p.add_argument(
        "--run-one", action="store_true", help="internal flag: run a single trial"
    )
    return p.parse_args()
# =====================================================
# DATA + TOKENIZER
# =====================================================
def prepare_data_and_tokenizer(model_name, dataset_name, dataset_split):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenizer.pad_token = tokenizer.eos_token
    if tokenizer.chat_template is None:
        tokenizer.chat_template = (
            "{% for message in messages %}\n"
            "{% if message['role'] == 'user' %}\n"
            "{{ '<|im_start|>user\n' + message['content'] + '<|im_end|>' }}\n"
            "{% elif message['role'] == 'assistant' %}\n"
            "{{ '<|im_start|>assistant\n' + message['content'] + '<|im_end|>' }}\n"
            "{% endif %}\n"
            "{% endfor %}"
        )
    dataset_dict = load_dataset(dataset_name, "SFT")
    hpo_dataset = dataset_dict[dataset_split]
    split = hpo_dataset.train_test_split(test_size=0.2, seed=42)
    train_dataset = split["train"].select(range(1000))
    eval_dataset = split["test"].select(range(200))
    return tokenizer, train_dataset, eval_dataset
# =====================================================
# OBJECTIVE FUNCTION
# =====================================================
def objective_factory(args):
    def objective(trial: optuna.trial.Trial):
        tokenizer, train_dataset, eval_dataset = prepare_data_and_tokenizer(
            args.model_name, args.dataset_name, args.dataset_split
        )
        learning_rate = trial.suggest_float("learning_rate", 1e-5, 1e-4, log=True)
        num_train_epochs = (
            args.epochs if args.epochs else trial.suggest_int("num_train_epochs", 1, 3)
        )
        lora_alpha = trial.suggest_int("lora_alpha", 8, 32)
        max_steps = args.max_steps if args.max_steps else -1

        gc.collect()
        try:
            torch.mps.empty_cache()
        except Exception:
            pass

        print(
            f"\n=== Trial {trial.number}: lr={learning_rate:.2e}, epochs={num_train_epochs}, lora_alpha={lora_alpha} ==="
        )

        model = AutoModelForCausalLM.from_pretrained(
            args.model_name,
            dtype=torch.bfloat16,
            device_map="auto",
            trust_remote_code=True,
        )
        peft_config = LoraConfig(
            r=8,
            lora_alpha=lora_alpha,
            lora_dropout=0.05,
            bias="none",
            task_type="CAUSAL_LM",
            target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
        )
        training_args = TrainingArguments(
            output_dir=os.path.join(args.output_dir, f"hpo_trial_{trial.number}"),
            per_device_train_batch_size=args.batch_size,
            per_device_eval_batch_size=args.batch_size,
            num_train_epochs=num_train_epochs,
            max_steps=max_steps,
            learning_rate=learning_rate,
            eval_strategy="no",  # the early-step loss landscape sufficiently reflects which parameters are better
            save_strategy="no",
            logging_steps=50,
            report_to="none",
            bf16=True,
            gradient_accumulation_steps=1,
            dataloader_num_workers=0,
        )
        trainer = SFTTrainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            peft_config=peft_config,
            processing_class=tokenizer,
        )
        trainer.train()
        eval_loss = trainer.evaluate()["eval_loss"]

        del model, trainer
        gc.collect()
        try:
            torch.mps.empty_cache()
        except Exception:
            pass
        time.sleep(2)

        return eval_loss

    return objective
# =====================================================
# MODES
# =====================================================
def run_single_trial_mode(args):
    """Child mode: run one trial."""
    objective = objective_factory(args)
    study = optuna.load_study(study_name="hpo_study", storage=args.study_db)
    study.optimize(objective, n_trials=1, catch=(RuntimeError,))


def run_manager_mode(args):
    """Manager mode: spawn subprocess for each trial."""
    optuna.create_study(
        study_name="hpo_study",
        direction="minimize",
        storage=args.study_db,
        load_if_exists=True,
    )
    for i in range(args.trials):
        print(f"\n=== Starting HPO trial {i} ===")
        cmd = (
            f"{os.sys.executable} {os.path.abspath(__file__)} "
            f"--batch-size {args.batch_size} "
            f"--epochs {args.epochs} "
            f"--max-steps {args.max_steps} "
            f"--max-length {args.max_length} "
            f'--model-name "{args.model_name}" '
            f'--dataset-name "{args.dataset_name}" '
            f'--dataset-split "{args.dataset_split}" '
            f'--trials 1 --study-db "{args.study_db}" '
            f'--output-dir "{args.output_dir}" --run-one'
        )
        ret = os.system(cmd)
        if ret != 0:
            print(f"Trial {i} failed (exit {ret}).")
        time.sleep(2)

    study = optuna.load_study(study_name="hpo_study", storage=args.study_db)
    print("\n--- Best Hyperparameters Found ---")
    print(study.best_params)
# =====================================================
# MAIN
# =====================================================
def main():
    args = parse_args()
    if args.run_one:
        run_single_trial_mode(args)
    else:
        run_manager_mode(args)


if __name__ == "__main__":
    main()
import sys
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import SFTTrainer, SFTConfig
from datasets import load_from_disk
NUM_STEPS_FOR_TEST = 20

def run_test(batch_size: int, model_name: str, max_length: int, dataset_path: str):
    try:
        num_samples_needed = batch_size * NUM_STEPS_FOR_TEST
        full_test_dataset = load_from_disk(dataset_path)
        if len(full_test_dataset) < num_samples_needed:
            sample_dataset = full_test_dataset
        else:
            sample_dataset = full_test_dataset.select(range(num_samples_needed))

        config = SFTConfig(
            output_dir="./test_output",
            per_device_train_batch_size=batch_size,
            max_steps=NUM_STEPS_FOR_TEST,
            logging_steps=10,
            report_to="none",
            max_length=max_length,
        )
        model = AutoModelForCausalLM.from_pretrained(
            model_name, dtype=torch.bfloat16, device_map="auto"
        )
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        tokenizer.pad_token = tokenizer.eos_token

        trainer = SFTTrainer(
            model=model,
            args=config,
            train_dataset=sample_dataset,
            processing_class=tokenizer,
        )
        result = trainer.train()
        samples_per_second = result.metrics["train_samples_per_second"]
        print(f"{samples_per_second:.2f}")
    except Exception as e:
        print(f"Error during test: {e}", file=sys.stderr)
        sys.exit(1)


if __name__ == "__main__":
    bs, mn, ml, dp = int(sys.argv[1]), sys.argv[2], int(sys.argv[3]), sys.argv[4]
    run_test(bs, mn, ml, dp)