@mooreniemi
Last active July 3, 2025 03:31
basic perf
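# Benchmarks the tomaarsen/reranker-modernbert-base-msmarco-bce cross-encoder on a small
# set of query-document pairs: PyTorch on GPU/CPU (with and without batching), ONNX Runtime
# (float32, dynamic int8, static int8), and different thread counts. Per-stage results are
# collected and printed as a pandas summary table at the end.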
from sentence_transformers import CrossEncoder
import torch
import time
import numpy as np
import copy
import os
import multiprocessing
import subprocess
import sys
os.environ["ORT_DISABLE_MEMORY_ARENA"] = "1"
from pathlib import Path
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import onnxruntime as ort
from optimum.exporters.onnx import main_export
import json
from onnxruntime.quantization import quantize_dynamic, QuantType, QuantFormat
from onnxruntime.quantization import quantize_static, CalibrationDataReader, CalibrationMethod
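# Feeds pre-tokenized examples one at a time to onnxruntime's static-quantization calibrator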
class BertCalibrationDataReader(CalibrationDataReader):
    def __init__(self, encodings):
        self.inputs = []
        for i in range(len(encodings["input_ids"])):
            self.inputs.append({
                "input_ids": encodings["input_ids"][i:i+1],           # add batch dimension
                "attention_mask": encodings["attention_mask"][i:i+1]  # add batch dimension
            })
        self.iterator = iter(self.inputs)

    def get_next(self):
        return next(self.iterator, None)
# Load query-document pairs from JSON file
with open('query_doc_pairs.json', 'r') as f:
    pairs_data = json.load(f)
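# query_doc_pairs.json is expected to hold a list of objects with "query" and
# "document" keys, e.g. (illustrative): [{"query": "...", "document": "..."}]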
# Convert to list of tuples for the benchmark
base_queries = [(pair['query'], pair['document']) for pair in pairs_data]
# Duplicate pairs to reach at least target_pairs total
target_pairs = 10  # originally 100
repetitions = (target_pairs + len(base_queries) - 1) // len(base_queries) # Ceiling division
queries = base_queries * repetitions
print(f"Loaded {len(base_queries)} unique query-document pairs")
print(f"Duplicated to {len(queries)} total pairs for benchmarking")
results = []
def print_header(title):
    print("=" * len(title))
    print(title)
    print("=" * len(title))
def create_model_with_threads(device, num_threads):
    """Create a new model instance with a specific threading configuration."""
    # Set environment variables for PyTorch threading.
    # (Note: OMP/MKL typically read these at library initialization, so setting them
    # after torch has been imported may have limited effect.)
    os.environ["OMP_NUM_THREADS"] = str(num_threads)
    os.environ["MKL_NUM_THREADS"] = str(num_threads)
    # Create a fresh model instance
    return CrossEncoder("tomaarsen/reranker-modernbert-base-msmarco-bce", device=device)
def benchmark(model, queries, batch_size=None, label=None, num_threads=None):
    total_tokens = 0
    for query, document in queries:
        tokens = model.tokenizer.encode(query, document, add_special_tokens=True)
        total_tokens += len(tokens)
    start_time = time.time()
    if batch_size is None:
        scores = [model.predict([pair])[0] for pair in queries]
    else:
        scores = model.predict(copy.deepcopy(queries), batch_size=batch_size)
    elapsed = time.time() - start_time
    num_pairs = len(queries)
    pairs_per_sec = num_pairs / elapsed
    print_header(f"{label}")
    print(f"Total tokens: {total_tokens}")
    print(f"Query-document pairs: {num_pairs}")
    print(f"Inference time: {elapsed:.4f} seconds")
    print(f"Tokens per second: {total_tokens / elapsed:.2f}")
    print(f"Pairs per second: {pairs_per_sec:.2f}")
    print(f"Scores: {np.round(scores, 4)}")
    print()
    results.append({
        "label": label.split('[')[0].strip(),
        "device": label.split('[')[1].strip(']') if '[' in label else "unknown",
        "tokens": total_tokens,
        "time_sec": round(elapsed, 4),
        "tokens_per_sec": round(total_tokens / elapsed, 2),
        "pairs_per_sec": round(pairs_per_sec, 2),
        "num_threads": num_threads if num_threads is not None else multiprocessing.cpu_count(),
    })
def run_onnx_inference(queries, model_path, label, provider, threads=None):
    so = ort.SessionOptions()
    so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
    so.log_severity_level = 1
    if threads is not None:
        so.intra_op_num_threads = threads
        so.inter_op_num_threads = 1
    session = ort.InferenceSession(model_path, sess_options=so, providers=[provider])
    tokenizer = AutoTokenizer.from_pretrained("tomaarsen/reranker-modernbert-base-msmarco-bce")
    # Tokenize all pairs as one padded batch
    encoded = tokenizer([q for q, d in queries], [d for q, d in queries], return_tensors="np", padding=True, truncation=True)
    inputs = {k: v for k, v in encoded.items() if k in ['input_ids', 'attention_mask']}
    # Note: padding tokens are counted here, so totals run higher than the PyTorch path
    total_tokens = int(np.sum([len(x) for x in encoded["input_ids"]]))
    start = time.time()
    outputs = session.run(None, inputs)
    elapsed = time.time() - start
    # Extract scores from ONNX output (logits)
    logits = outputs[0]  # shape: (batch_size, num_labels)
    scores = 1 / (1 + np.exp(-logits))  # sigmoid for binary classification
    scores = scores.flatten()  # flatten to 1D array
    num_pairs = len(queries)
    pairs_per_sec = num_pairs / elapsed
    print_header(f"{label}")
    print(f"Total tokens: {total_tokens}")
    print(f"Query-document pairs: {num_pairs}")
    print(f"Inference time: {elapsed:.4f} seconds")
    print(f"Tokens per second: {total_tokens / elapsed:.2f}")
    print(f"Pairs per second: {pairs_per_sec:.2f}")
    print(f"Scores: {np.round(scores, 4)}")
    print()
    results.append({
        "label": label.split('[')[0].strip(),
        "device": f"ONNX {provider.replace('ExecutionProvider', '')}",
        "tokens": total_tokens,
        "time_sec": round(elapsed, 4),
        "tokens_per_sec": round(total_tokens / elapsed, 2),
        "pairs_per_sec": round(pairs_per_sec, 2),
        "num_threads": threads if threads is not None else multiprocessing.cpu_count(),
    })
def run_benchmark_with_threads(label, batch_size=None, num_threads=None):
    """Run benchmark with a specific threading configuration."""
    if torch.cuda.is_available():
        try:
            model_gpu = create_model_with_threads("cuda", num_threads) if num_threads else CrossEncoder("tomaarsen/reranker-modernbert-base-msmarco-bce", device="cuda")
            benchmark(model_gpu, queries, batch_size=batch_size, label=f"{label} [GPU]", num_threads=num_threads)
        except Exception as e:
            print_header(f"{label} [GPU]")
            print(f"Skipped due to error: {e}\n")
    else:
        print_header(f"{label} [GPU]")
        print("Skipped: CUDA not available\n")
    try:
        model_cpu = create_model_with_threads("cpu", num_threads) if num_threads else CrossEncoder("tomaarsen/reranker-modernbert-base-msmarco-bce", device="cpu")
        benchmark(model_cpu, queries, batch_size=batch_size, label=f"{label} [CPU]", num_threads=num_threads)
    except Exception as e:
        print_header(f"{label} [CPU]")
        print(f"Skipped due to error: {e}\n")
# Run benchmarks with different threading configurations
run_benchmark_with_threads("Stage 0: Inference (no batching)", batch_size=None)
run_benchmark_with_threads("Stage 1: Inference (manual batching)", batch_size=2)
onnx_path = Path("model.onnx")
if not onnx_path.exists():
    print("Exporting ONNX model...")
    tokenizer = AutoTokenizer.from_pretrained("tomaarsen/reranker-modernbert-base-msmarco-bce")
    model = AutoModelForSequenceClassification.from_pretrained("tomaarsen/reranker-modernbert-base-msmarco-bce")
    main_export(
        model_name_or_path="tomaarsen/reranker-modernbert-base-msmarco-bce",
        output=".",
        opset=14,
        task="text-classification",
        do_validation=False,
    )
int8_path = Path("model.int8.onnx")
if not int8_path.exists():
    print("Quantizing ONNX model to int8...")
    quantize_dynamic(str(onnx_path), str(int8_path), weight_type=QuantType.QInt8)
static_int8_path = Path("model.static.int8.onnx")
if not static_int8_path.exists():
    print("Static quantizing ONNX model to int8...")
    tokenizer = AutoTokenizer.from_pretrained("tomaarsen/reranker-modernbert-base-msmarco-bce")
    # Calibration set: the first 20 query-document pairs
    encoded = tokenizer(
        [q for q, d in queries[:20]],
        [d for q, d in queries[:20]],
        padding=True,
        truncation=True,
        return_tensors="np"
    )
    calibration_data_reader = BertCalibrationDataReader(encoded)
    quantize_static(
        model_input=str(onnx_path),
        model_output=str(static_int8_path),
        calibration_data_reader=calibration_data_reader,
        quant_format=QuantFormat.QDQ  # QuantFormat (not QuantType) selects the QDQ format
    )
try:
    run_onnx_inference(queries, str(onnx_path), label="Stage 2: ONNX Inference (float32)", provider="CUDAExecutionProvider")
except Exception as e:
    print_header("Stage 2: ONNX Inference (float32) [GPU]")
    print(f"Skipped due to error: {e}\n")
run_onnx_inference(queries, str(onnx_path), label="Stage 2: ONNX Inference (float32)", provider="CPUExecutionProvider")
print_header("Stage 3: Inference with Threading = 1")
run_benchmark_with_threads("Stage 3: Threads = 1", batch_size=2, num_threads=1)
try:
    run_onnx_inference(queries, str(onnx_path), label="Stage 3: ONNX Inference [GPU] (1 thread)", provider="CUDAExecutionProvider", threads=1)
except Exception as e:
    print_header("Stage 3: ONNX Inference [GPU] (1 thread)")
    print(f"Skipped due to error: {e}\n")
run_onnx_inference(queries, str(onnx_path), label="Stage 3: ONNX Inference [CPU] (1 thread)", provider="CPUExecutionProvider", threads=1)
print_header("Stage 4: Inference with Max Threads")
max_threads = multiprocessing.cpu_count()
run_benchmark_with_threads("Stage 4: Threads = Max", batch_size=2, num_threads=max_threads)
try:
    run_onnx_inference(queries, str(onnx_path), label="Stage 4: ONNX Inference [GPU] (max threads)", provider="CUDAExecutionProvider", threads=max_threads)
except Exception as e:
    print_header("Stage 4: ONNX Inference [GPU] (max threads)")
    print(f"Skipped due to error: {e}\n")
run_onnx_inference(queries, str(onnx_path), label="Stage 4: ONNX Inference [CPU] (max threads)", provider="CPUExecutionProvider", threads=max_threads)
print_header("Stage 5: ONNX Inference (int8) with Threading = 1")
run_onnx_inference(queries, str(int8_path), label="Stage 5: ONNX Inference (int8)", provider="CPUExecutionProvider", threads=1)
print_header("Stage 6: ONNX Inference (int8) with Max Threads")
run_onnx_inference(queries, str(int8_path), label="Stage 6: ONNX Inference (int8)", provider="CPUExecutionProvider", threads=max_threads)
print_header("Stage 7: ONNX Inference (static int8) with Max Threads")
run_onnx_inference(queries, str(static_int8_path), label="Stage 7: ONNX Inference (static int8)", provider="CPUExecutionProvider", threads=max_threads)
import pandas as pd
df = pd.DataFrame(results)
df_sorted = df.sort_values('tokens_per_sec', ascending=False)
print("\n===== Summary Table (Sorted by Tokens/Second) =====")
print(df_sorted)
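Example output from one run (machine with 16 logical CPUs). The ONNX rows report 4557 tokens because the batched tokenizer pads every pair to the longest sequence in the batch and those padding tokens are counted; the PyTorch rows tokenize each pair individually, giving 4040 tokens.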
===== Summary Table (Sorted by Tokens/Second) =====
    label                                   device     tokens  time_sec  tokens_per_sec  pairs_per_sec  num_threads
6   Stage 3: Threads = 1                    GPU          4040    0.3483        11600.09          60.30            1
10  Stage 4: Threads = Max                  GPU          4040    0.3693        10940.52          56.87           16
2   Stage 1: Inference (manual batching)    GPU          4040    0.5393         7491.84          38.94           16
0   Stage 0: Inference (no batching)        GPU          4040    1.1290         3578.29          18.60           16
15  Stage 6: ONNX Inference (int8)          ONNX CPU     4557    2.4315         1874.14           8.64           16
11  Stage 4: Threads = Max                  CPU          4040    2.3970         1685.46           8.76           16
7   Stage 3: Threads = 1                    CPU          4040    2.4063         1678.89           8.73            1
3   Stage 1: Inference (manual batching)    CPU          4040    2.4304         1662.29           8.64           16
1   Stage 0: Inference (no batching)        CPU          4040    2.5697         1572.19           8.17           16
5   Stage 2: ONNX Inference (float32)       ONNX CPU     4557    3.1394         1451.57           6.69           16
12  Stage 4: ONNX Inference                 ONNX CUDA    4557    3.1803         1432.86           6.60           16
13  Stage 4: ONNX Inference                 ONNX CPU     4557    3.1828         1431.75           6.60           16
4   Stage 2: ONNX Inference (float32)       ONNX CUDA    4557    3.2055         1421.61           6.55           16
14  Stage 5: ONNX Inference (int8)          ONNX CPU     4557    4.2853         1063.39           4.90            1
16  Stage 7: ONNX Inference (static int8)   ONNX CPU     4557    4.8013          949.12           4.37           16
9   Stage 3: ONNX Inference                 ONNX CPU     4557    9.9935          456.00           2.10            1
8   Stage 3: ONNX Inference                 ONNX CUDA    4557   10.0227          454.67           2.10            1