Benchmark script for reward model performance
Strategy | Relative Throughput | Time (s) | Cost ($/M tokens)
----------------------------------------------------------------------------------------
Unsloth | 2.17 | 3.83 | $0.0188
Unsloth+PEFT | 1.58 | 5.27 | $0.0259
Transformers+Liger | 1.14 | 7.28 | $0.0358
vLLM | 1.00 | 8.31 | $0.0409
Transformers | 0.97 | 8.54 | $0.0420
Transformers+Liger+PEFT | 0.84 | 9.85 | $0.0484
Transformers+PEFT | 0.74 | 11.26 | $0.0554
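The numbers above come from the script below. It scores 100 random 6,000-character samples after a short warmup; the Transformers, Liger, and Unsloth variants run one sample per forward pass, while vLLM scores the whole batch in a single encode() call. Throughput is normalized to vLLM, and cost assumes a $4/hour GPU at 50% utilization. Set the MODEL_NAME environment variable to the reward model you want to benchmark before running the script.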
import time
import random
import string
import torch
import numpy as np
from tqdm import tqdm
from typing import List, Tuple
import os
# For VLLM
from vllm import LLM
# For Hugging Face Transformers
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from peft import LoraConfig, get_peft_model
# For Liger Kernel - import directly, will raise ImportError if not found
from liger_kernel.transformers import AutoLigerKernelForCausalLM
# For Unsloth
from unsloth import FastLanguageModel # type: ignore
# Set seeds for reproducibility
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
# Constants
MODEL_NAME = os.environ.get("MODEL_NAME")
NUM_SAMPLES = 100
CHARS_PER_SAMPLE = 6000
WARMUP_ITERATIONS = 5
# Cost Calculation Constants
GPU_COST_PER_HOUR = 4.0 # Dollars per hour
GPU_UTILIZATION = 0.50 # 50% utilization assumption
EFFECTIVE_GPU_COST_PER_HOUR = GPU_COST_PER_HOUR / GPU_UTILIZATION
EFFECTIVE_GPU_COST_PER_SECOND = EFFECTIVE_GPU_COST_PER_HOUR / 3600
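# Worked example: at $4/hr and 50% utilization the effective rate is $8/hr,
# i.e. roughly $0.00222 per second. The 8.31 s vLLM run in the table above
# therefore costs about $0.0185; dividing by the total token count (and
# multiplying by 1e6) yields the $/M-token figures in the summary.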
def generate_random_samples(num_samples: int, chars_per_sample: int) -> List[str]:
"""Generate random text samples with a fixed seed for reproducibility."""
random.seed(SEED)
samples = []
for _ in range(num_samples):
# Generate random string of specified length
text = "".join(
random.choice(
string.ascii_letters + string.digits + string.punctuation + " "
)
for _ in range(chars_per_sample)
)
samples.append(text)
return samples
def count_tokens(tokenizer, samples: List[str]) -> int:
"""Count the total number of tokens in all samples."""
total_tokens = 0
for sample in samples:
total_tokens += len(tokenizer.encode(sample))
return total_tokens
class BenchmarkVLLM:
"""Benchmark using vLLM to serve the reward model."""
def __init__(self, model_name: str):
self.name = "vLLM Reward Model"
# Initialize vLLM with the reward modeling task
self.model = LLM(model=model_name, task="reward", dtype="half")
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
# Ensure padding token is set
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
def run(self, samples: List[str]) -> Tuple[float, float]:
"""Run inference on samples and measure throughput and total time."""
start_time = time.time()
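# Unlike the per-sample loops in the Transformers-based benchmarks below,
# vLLM scores the entire sample list in one encode() call and can batch
# requests internally.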
outputs = self.model.encode(samples)
end_time = time.time()
total_time = end_time - start_time
throughput = len(samples) / total_time if total_time > 0 else 0
return throughput, total_time
class BenchmarkTransformers:
"""Benchmark using Transformers to serve the reward model."""
def __init__(self, model_name: str):
self.name = "Transformers Reward Model"
# Initialize standard transformers model
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
# Ensure padding token is set
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
self.model = AutoModelForSequenceClassification.from_pretrained(
model_name, num_labels=1, torch_dtype=torch.float16
).cuda()
self.model.config.pad_token_id = self.tokenizer.eos_token_id
def run(self, samples: List[str]) -> Tuple[float, float]:
"""Run inference on samples and measure throughput and total time."""
start_time = time.time()
with torch.no_grad():
for sample in tqdm(samples, desc="Transformers"):
inputs = self.tokenizer(
[sample], return_tensors="pt", padding=True, truncation=True
).to("cuda")
outputs = self.model(**inputs).logits
end_time = time.time()
total_time = end_time - start_time
throughput = len(samples) / total_time if total_time > 0 else 0
return throughput, total_time
class BenchmarkTransformersPeft:
"""Benchmark using Transformers with PEFT adapter to serve the reward model."""
def __init__(self, model_name: str):
self.name = "Transformers + PEFT Reward Model"
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
# Ensure padding token is set
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
# Initialize base model
base_model = AutoModelForSequenceClassification.from_pretrained(
model_name, num_labels=1, torch_dtype=torch.float16
).cuda()
base_model.config.pad_token_id = self.tokenizer.eos_token_id
# Configure and add LoRA adapter
lora_config = LoraConfig(
r=16,
lora_alpha=32,
target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
lora_dropout=0.05,
bias="none",
task_type="SEQ_CLS",
)
# Apply LoRA adapter to model
self.model = get_peft_model(base_model, lora_config)
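# The LoRA adapter is freshly initialized rather than trained, so this
# measures only the runtime overhead of routing through adapter layers.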
def run(self, samples: List[str]) -> Tuple[float, float]:
"""Run inference on samples and measure throughput and total time."""
start_time = time.time()
with torch.no_grad():
for sample in tqdm(samples, desc=f"Transformers+PEFT"):
inputs = self.tokenizer(
[sample], return_tensors="pt", padding=True, truncation=True
).to("cuda")
outputs = self.model(**inputs).logits
end_time = time.time()
total_time = end_time - start_time
throughput = len(samples) / total_time if total_time > 0 else 0
return throughput, total_time
class BenchmarkTransformersLiger:
"""Benchmark using Transformers with Liger Kernel optimizations via AutoLigerKernelForCausalLM."""
def __init__(self, model_name: str):
self.name = "Liger Reward Model"
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
# Ensure padding token is set
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
# Use AutoLigerKernelForCausalLM to automatically apply the appropriate Liger kernels
self.model = AutoLigerKernelForCausalLM.from_pretrained(
model_name, num_labels=1, torch_dtype=torch.float16
).cuda()
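# Note (assumption): AutoLigerKernelForCausalLM keeps the causal-LM head, so
# num_labels likely does not attach a 1-logit classification head here; read
# this path as benchmarking the Liger-patched forward pass rather than an
# identical reward head.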
self.model.config.pad_token_id = self.tokenizer.eos_token_id
def run(self, samples: List[str]) -> Tuple[float, float]:
"""Run inference on samples and measure throughput and total time."""
start_time = time.time()
with torch.no_grad():
for sample in tqdm(samples, desc="Liger"):
inputs = self.tokenizer(
[sample], return_tensors="pt", padding=True, truncation=True
).to("cuda")
outputs = self.model(**inputs).logits
end_time = time.time()
total_time = end_time - start_time
throughput = len(samples) / total_time if total_time > 0 else 0
return throughput, total_time
class BenchmarkTransformersLigerPeft:
"""Benchmark using Transformers with both Liger Kernel optimizations and PEFT adapter."""
def __init__(self, model_name: str):
self.name = "Liger + PEFT Reward Model"
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
# Ensure padding token is set
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
# First apply Liger kernel optimizations via AutoLigerKernelForCausalLM
base_model = AutoLigerKernelForCausalLM.from_pretrained(
model_name, num_labels=1, torch_dtype=torch.float16
).cuda()
base_model.config.pad_token_id = self.tokenizer.eos_token_id
# Then apply PEFT/LoRA adapter
lora_config = LoraConfig(
r=16,
lora_alpha=32,
target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
lora_dropout=0.05,
bias="none",
task_type="SEQ_CLS",
)
# Apply LoRA adapter to model
self.model = get_peft_model(base_model, lora_config)
def run(self, samples: List[str]) -> Tuple[float, float]:
"""Run inference on samples and measure throughput and total time."""
start_time = time.time()
with torch.no_grad():
for sample in tqdm(samples, desc="Liger+PEFT"):
inputs = self.tokenizer(
[sample], return_tensors="pt", padding=True, truncation=True
).to("cuda")
outputs = self.model(**inputs).logits
end_time = time.time()
total_time = end_time - start_time
throughput = len(samples) / total_time if total_time > 0 else 0
return throughput, total_time
# Benchmark using Unsloth FastLanguageModel
class BenchmarkUnsloth:
"""Benchmark using Unsloth FastLanguageModel to serve the reward model."""
def __init__(self, model_name: str):
self.name = "Unsloth Reward Model"
# Load model & tokenizer with Unsloth optimizations (16-bit, no quantization)
self.model, self.tokenizer = FastLanguageModel.from_pretrained(
model_name=model_name,
max_seq_length=2048,
load_in_4bit=False,
load_in_8bit=False,
full_finetuning=False,
num_labels=1,
)
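# num_labels=1 asks Unsloth for a single-logit classification head (assumes an
# Unsloth version with sequence-classification support).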
# Ensure padding token is set
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
# Move the model to GPU (FastLanguageModel may already have done this)
self.model = self.model.cuda()
self.model.eval()
def run(self, samples: List[str]) -> Tuple[float, float]:
"""Run inference on samples and measure throughput and total time."""
start_time = time.time()
with torch.no_grad():
for sample in tqdm(samples, desc="Unsloth"):
inputs = self.tokenizer(
[sample],
return_tensors="pt",
padding=True,
truncation=True,
).to("cuda")
# Forward pass (output discarded; only the timing matters here)
_ = self.model(**inputs)
end_time = time.time()
total_time = end_time - start_time
throughput = len(samples) / total_time if total_time > 0 else 0
return throughput, total_time
# Benchmark using Unsloth FastLanguageModel with a LoRA adapter
class BenchmarkUnslothPeft:
"""Benchmark using Unsloth FastLanguageModel *plus* a PEFT/LoRA adapter."""
def __init__(self, model_name: str):
self.name = "Unsloth + PEFT Reward Model"
# Load base model with Unsloth optimizations
self.model, self.tokenizer = FastLanguageModel.from_pretrained(
model_name=model_name,
max_seq_length=2048,
load_in_4bit=False,
load_in_8bit=False,
full_finetuning=False,
num_labels=1,
)
# Ensure padding token
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
# Apply LoRA/PEFT adapter via Unsloth helper
self.model = FastLanguageModel.get_peft_model(
self.model,
r=16,
target_modules=[
"q_proj",
"k_proj",
"v_proj",
"o_proj",
"gate_proj",
"up_proj",
"down_proj",
],
lora_alpha=16,
lora_dropout=0.0,
bias="none",
)
self.model = self.model.cuda()
self.model.eval()
def run(self, samples: List[str]) -> Tuple[float, float]:
"""Run inference on samples and measure throughput and total time."""
start_time = time.time()
with torch.no_grad():
for sample in tqdm(samples, desc="Unsloth+PEFT"):
inputs = self.tokenizer(
[sample],
return_tensors="pt",
padding=True,
truncation=True,
).to("cuda")
_ = self.model(**inputs)
end_time = time.time()
total_time = end_time - start_time
throughput = len(samples) / total_time if total_time > 0 else 0
return throughput, total_time
def run_benchmark(benchmark_class, samples, model_name: str) -> Tuple[float, float]:
"""Run benchmark for a given implementation, returning throughput and total time."""
benchmark = benchmark_class(model_name=model_name)
print(f"\n{'-' * 20} Running {benchmark.name} {'-' * 20}")
# Warmup
print(f"Warming up {benchmark.name}...")
warmup_samples = samples[:WARMUP_ITERATIONS]
benchmark.run(warmup_samples) # Ignore warmup results
# Actual benchmarking
throughput, total_time = benchmark.run(samples)
print(f"{benchmark.name}: {throughput:.2f} samples/second")
print(f"{benchmark.name}: Total time: {total_time:.2f} seconds")
return throughput, total_time
def main():
# Add check for MODEL_NAME
if MODEL_NAME is None:
raise ValueError(
"MODEL_NAME environment variable not set. Please set it before running."
)
# Generate samples
print(
f"Generating {NUM_SAMPLES} random samples with {CHARS_PER_SAMPLE} characters each..."
)
samples = generate_random_samples(NUM_SAMPLES, CHARS_PER_SAMPLE)
# Tokenize and count tokens
print("Counting tokens...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Ensure padding token is set for token counting too
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
total_tokens = count_tokens(tokenizer, samples)
avg_tokens = total_tokens / NUM_SAMPLES
print(f"Average tokens per sample: {avg_tokens:.2f}")
print(f"Total tokens: {total_tokens}")
# Run all benchmarks
all_results = {} # Stores { name: (throughput, total_time) }
# 1. vLLM (run first before any global patches)
vllm_throughput, vllm_time = run_benchmark(BenchmarkVLLM, samples, MODEL_NAME)
all_results["vLLM"] = (vllm_throughput, vllm_time)
# 2. Transformers
transformers_throughput, transformers_time = run_benchmark(
BenchmarkTransformers, samples, MODEL_NAME
)
all_results["Transformers"] = (transformers_throughput, transformers_time)
# 3. Transformers + PEFT
peft_throughput, peft_time = run_benchmark(
BenchmarkTransformersPeft, samples, MODEL_NAME
)
all_results["Transformers+PEFT"] = (peft_throughput, peft_time)
# 4. Liger
liger_throughput, liger_time = run_benchmark(
BenchmarkTransformersLiger, samples, MODEL_NAME
)
all_results["Liger"] = (liger_throughput, liger_time)
# 5. Liger + PEFT
liger_peft_throughput, liger_peft_time = run_benchmark(
BenchmarkTransformersLigerPeft, samples, MODEL_NAME
)
all_results["Liger+PEFT"] = (
liger_peft_throughput,
liger_peft_time,
)
# 6. Unsloth (run after all non-Unsloth benchmarks; only the Unsloth+PEFT variant follows)
unsloth_throughput, unsloth_time = run_benchmark(
BenchmarkUnsloth, samples, MODEL_NAME
)
all_results["Unsloth"] = (unsloth_throughput, unsloth_time)
# 7. Unsloth + PEFT (also relies on Unsloth's global patches)
unsloth_peft_throughput, unsloth_peft_time = run_benchmark(
BenchmarkUnslothPeft, samples, MODEL_NAME
)
all_results["Unsloth+PEFT"] = (unsloth_peft_throughput, unsloth_peft_time)
# --- Final Performance Summary (single concise table) ---
print("\n" + "=" * 50)
print("FINAL PERFORMANCE SUMMARY (normalized to vLLM)")
print("=" * 50)
# Fetch baseline (vLLM) throughput for normalization
vllm_throughput_baseline = all_results.get("vLLM", (None, None))[0]
if vllm_throughput_baseline is None or vllm_throughput_baseline == 0:
raise ValueError("vLLM results missing or invalid — cannot normalize to vLLM.")
# Column order: Strategy, Relative Throughput, Time, Cost
print(
f"{'Strategy':<30} | {'Relative Throughput':<22} | {'Time (s)':<12} | {'Cost ($/M tokens)':<18}"
)
print("-" * 110)
# Sort rows by total time ascending (fastest first for readability)
for name, (throughput, total_time) in sorted(
all_results.items(), key=lambda x: x[1][1]
):
relative_throughput = (
throughput / vllm_throughput_baseline if vllm_throughput_baseline > 0 else 0
)
# Calculate cost per million tokens for this run
total_cost_for_run = total_time * EFFECTIVE_GPU_COST_PER_SECOND
cost_per_token = (
total_cost_for_run / total_tokens if total_tokens > 0 else float("inf")
)
cost_per_million_tokens = cost_per_token * 1_000_000
# Rows follow the same column order as the header
print(
f"{name:<30} | {relative_throughput:<30.2f} | {total_time:<12.2f} | ${cost_per_million_tokens:<17.4f}"
)
if __name__ == "__main__":
main()