@jeffmylife
Created January 2, 2025 21:13
Latency comparison command-line tool for LLMs
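# Usage sketch (the key-variable names and the script filename below are typical
# litellm conventions, not specified in the gist). litellm reads provider
# credentials from the environment, e.g.:
#   export OPENAI_API_KEY=...      # gpt-* models
#   export ANTHROPIC_API_KEY=...   # claude-* models
#   export GEMINI_API_KEY=...      # gemini/* models
# Then run, for example:
#   python llm_latency.py -n 10 -o latency_results.json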
import argparse
import json
import os
import time
from statistics import mean, stdev
from typing import Dict, List, Tuple
import plotext as plt
from litellm import completion
from rich.console import Console
from rich.panel import Panel
from rich.progress import Progress, SpinnerColumn, TextColumn
from rich.table import Table
console = Console()
# Configure models to test
MODELS = [
    # OpenAI
    "gpt-3.5-turbo",
    "gpt-4o-mini",
    # Anthropic
    "claude-3-haiku-20240307",
    "claude-3-5-haiku-20241022",
    # Gemini
    "gemini/gemini-1.5-flash-latest",
    "gemini/gemini-1.5-flash-8b",
    "gemini/gemini-2.0-flash-exp",
]
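# Model identifiers follow litellm's routing convention: bare OpenAI/Anthropic
# names are passed straight through, while the "gemini/" prefix routes the call
# to the Gemini (Google AI Studio) provider.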
def measure_latency(model: str, n_calls: int = 5) -> List[float]:
    """Measure response latency for a given model"""
    latencies = []
    messages = [{"role": "user", "content": "Respond with just one word: hello"}]
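    # The prompt asks for a single word so the reply stays tiny; the measured time
    # therefore mostly reflects network round trip and time-to-first-token overhead
    # rather than output-generation length.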
    # Suppress litellm provider list messages
    os.environ['LITELLM_SUPPRESS_PROVIDER_LIST'] = 'true'
    with Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        transient=True,
    ) as progress:
        task = progress.add_task(f"Testing {model}...", total=n_calls)
        for _ in range(n_calls):
            try:
                start = time.time()
                completion(model=model, messages=messages)
                end = time.time()
                latency = end - start
                latencies.append(latency)
            except Exception as e:
                console.print(f"[red]Error with model {model}: {str(e)}[/red]")
                continue
            finally:
                progress.advance(task)
    return latencies
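# Usage sketch (illustrative): measure_latency("gpt-4o-mini", n_calls=3) returns a
# list of wall-clock latencies in seconds, one entry per call that succeeded.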
def run_latency_test(n_replicates: int = 5) -> Tuple[Dict, Dict]:
    """Run latency tests across all models"""
    results = {}
    all_latencies = {}
    for model in MODELS:
        console.print(Panel(f"Testing {model} 🔄", style="blue"))
        latencies = measure_latency(model, n_replicates)
        all_latencies[model] = latencies
        if latencies:
            results[model] = {
                "mean_latency": round(mean(latencies), 3),
                "std_latency": round(stdev(latencies), 3) if len(latencies) > 1 else 0,
                "min_latency": round(min(latencies), 3),
                "max_latency": round(max(latencies), 3),
                "successful_calls": len(latencies),
                "total_calls": n_replicates
            }
        else:
            results[model] = {
                "error": "All calls failed"
            }
    return results, all_latencies
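# Each entry in `results` is shaped like:
#   {"mean_latency": ..., "std_latency": ..., "min_latency": ..., "max_latency": ...,
#    "successful_calls": ..., "total_calls": ...}
# or {"error": "All calls failed"} when no call for that model succeeded.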
def main():
    parser = argparse.ArgumentParser(description='Measure LLM API latencies')
    parser.add_argument('-n', '--num_replicates', type=int, default=5,
                        help='Number of replicate calls per model')
    parser.add_argument('-o', '--output', type=str, default='latency_results.json',
                        help='Output file for results')
    args = parser.parse_args()
    console.print(Panel.fit(
        f"[yellow]Starting latency test with {args.num_replicates} replicates per model...[/yellow]",
        border_style="green"
    ))
    results, all_latencies = run_latency_test(args.num_replicates)
    # Create results table
    table = Table(title="🚀 Latency Test Results")
    table.add_column("Model", style="cyan")
    table.add_column("Mean (s)", justify="right")
    table.add_column("Std Dev", justify="right")
    table.add_column("Min (s)", justify="right")
    table.add_column("Max (s)", justify="right")
    table.add_column("Success Rate", justify="right")
    # Find fastest model
    fastest_model = min(
        ((model, stats["mean_latency"]) for model, stats in results.items() if "error" not in stats),
        key=lambda x: x[1],
        default=(None, float('inf'))
    )[0]
    for model, stats in results.items():
        if "error" in stats:
            table.add_row(
                model,
                "[red]Failed[/red]",
                "[red]Failed[/red]",
                "[red]Failed[/red]",
                "[red]Failed[/red]",
                "[red]0%[/red]"
            )
        else:
            success_rate = f"{(stats['successful_calls']/stats['total_calls'])*100:.0f}%"
            model_name = f"{model} ⚡" if model == fastest_model else model
            table.add_row(
                model_name,
                f"{stats['mean_latency']:.3f}",
                f"{stats['std_latency']:.3f}",
                f"{stats['min_latency']:.3f}",
                f"{stats['max_latency']:.3f}",
                f"{'✅' if success_rate == '100%' else '⚠️'} {success_rate}"
            )
    console.print(table)
    # Create box plot
    console.print("\n[bold cyan]Latency Distribution Across Models:[/bold cyan]")
    plt.clear_data()
    plt.clear_figure()
    plt.clf()
    plt.plotsize(120, 15)
    # Collect all latencies
    model_latencies = []
    model_names = []
    for model, stats in results.items():
        if "error" not in stats:
            sorted_latencies = sorted(all_latencies[model])
            model_latencies.append(sorted_latencies)
            model_names.append(model)
    # Create the box plot
    plt.box(model_names, model_latencies, width=0.6)
    plt.title("Latency Distribution by Model")
    plt.ylabel("Time (s)")
    plt.theme("clear")
    plt.show()
    # Save results
    output_path = os.path.join(os.path.dirname(__file__), args.output)
    with open(output_path, 'w') as f:
        json.dump(results, f, indent=2)
    console.print(f"\n[green]Results saved to {output_path}[/green]")


if __name__ == "__main__":
    main()
@jeffmylife (Author)

[Screenshot: latency test results output]

With 10 replicates:

[Screenshot: latency test results with 10 replicates]

With 20 replicates:

[Screenshot: latency test results with 20 replicates]