Created January 2, 2025 21:13
Latency comparison command-line tool for LLMs
import argparse
import json
import os
import time
from statistics import mean, stdev
from typing import Dict, List, Tuple

import plotext as plt
from litellm import completion
from rich.console import Console
from rich.panel import Panel
from rich.progress import Progress, SpinnerColumn, TextColumn
from rich.table import Table

console = Console()

# Configure models to test
MODELS = [
    # OpenAI
    "gpt-3.5-turbo",
    "gpt-4o-mini",
    # Anthropic
    "claude-3-haiku-20240307",
    "claude-3-5-haiku-20241022",
    # Gemini
    "gemini/gemini-1.5-flash-latest",
    "gemini/gemini-1.5-flash-8b",
    "gemini/gemini-2.0-flash-exp",
]


def measure_latency(model: str, n_calls: int = 5) -> List[float]:
    """Measure response latency for a given model."""
    latencies = []
    messages = [{"role": "user", "content": "Respond with just one word: hello"}]

    # Suppress litellm provider list messages
    os.environ['LITELLM_SUPPRESS_PROVIDER_LIST'] = 'true'

    with Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        transient=True,
    ) as progress:
        task = progress.add_task(f"Testing {model}...", total=n_calls)
        for _ in range(n_calls):
            try:
                start = time.time()
                completion(model=model, messages=messages)
                end = time.time()
                latency = end - start
                latencies.append(latency)
            except Exception as e:
                console.print(f"[red]Error with model {model}: {str(e)}[/red]")
                continue
            finally:
                progress.advance(task)

    return latencies


def run_latency_test(n_replicates: int = 5) -> Tuple[Dict, Dict]:
    """Run latency tests across all models."""
    results = {}
    all_latencies = {}

    for model in MODELS:
        console.print(Panel(f"Testing {model} 🔄", style="blue"))
        latencies = measure_latency(model, n_replicates)
        all_latencies[model] = latencies

        if latencies:
            results[model] = {
                "mean_latency": round(mean(latencies), 3),
                "std_latency": round(stdev(latencies), 3) if len(latencies) > 1 else 0,
                "min_latency": round(min(latencies), 3),
                "max_latency": round(max(latencies), 3),
                "successful_calls": len(latencies),
                "total_calls": n_replicates,
            }
        else:
            results[model] = {"error": "All calls failed"}

    return results, all_latencies


def main():
    parser = argparse.ArgumentParser(description='Measure LLM API latencies')
    parser.add_argument('-n', '--num_replicates', type=int, default=5,
                        help='Number of replicate calls per model')
    parser.add_argument('-o', '--output', type=str, default='latency_results.json',
                        help='Output file for results')
    args = parser.parse_args()

    console.print(Panel.fit(
        f"[yellow]Starting latency test with {args.num_replicates} replicates per model...[/yellow]",
        border_style="green",
    ))

    results, all_latencies = run_latency_test(args.num_replicates)

    # Create results table
    table = Table(title="🚀 Latency Test Results")
    table.add_column("Model", style="cyan")
    table.add_column("Mean (s)", justify="right")
    table.add_column("Std Dev (s)", justify="right")
    table.add_column("Min (s)", justify="right")
    table.add_column("Max (s)", justify="right")
    table.add_column("Success Rate", justify="right")

    # Find fastest model
    fastest_model = min(
        ((model, stats["mean_latency"]) for model, stats in results.items() if "error" not in stats),
        key=lambda x: x[1],
        default=(None, float('inf')),
    )[0]

    for model, stats in results.items():
        if "error" in stats:
            table.add_row(
                model,
                "[red]Failed[/red]",
                "[red]Failed[/red]",
                "[red]Failed[/red]",
                "[red]Failed[/red]",
                "[red]0%[/red]",
            )
        else:
            success_rate = f"{(stats['successful_calls']/stats['total_calls'])*100:.0f}%"
            model_name = f"{model} ⚡" if model == fastest_model else model
            table.add_row(
                model_name,
                f"{stats['mean_latency']:.3f}",
                f"{stats['std_latency']:.3f}",
                f"{stats['min_latency']:.3f}",
                f"{stats['max_latency']:.3f}",
                f"{'✅' if success_rate == '100%' else '⚠️'} {success_rate}",
            )

    console.print(table)

    # Create box plot of latency distributions
    console.print("\n[bold cyan]Latency Distribution Across Models:[/bold cyan]")
    plt.clear_data()
    plt.clear_figure()
    plt.plotsize(120, 15)

    # Collect latencies for models that returned results
    model_latencies = []
    model_names = []
    for model, stats in results.items():
        if "error" not in stats:
            model_latencies.append(sorted(all_latencies[model]))
            model_names.append(model)

    # Create the box plot
    plt.box(model_names, model_latencies, width=0.6)
    plt.title("Latency Distribution by Model")
    plt.ylabel("Time (s)")
    plt.theme("clear")
    plt.show()

    # Save results next to this script
    output_path = os.path.join(os.path.dirname(__file__), args.output)
    with open(output_path, 'w') as f:
        json.dump(results, f, indent=2)
    console.print(f"\n[green]Results saved to {output_path}[/green]")


if __name__ == "__main__":
    main()
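
For reference, here is a minimal usage sketch. The filename llm_latency.py is hypothetical (name the gist file whatever you like), and it assumes the relevant provider API keys (e.g. OPENAI_API_KEY, ANTHROPIC_API_KEY, GEMINI_API_KEY) are set in the environment, as litellm expects.

# Run the whole comparison from the shell (hypothetical filename):
#   python llm_latency.py -n 10 -o results.json
# Or call the functions directly from Python:
from llm_latency import measure_latency, run_latency_test

latencies = measure_latency("gpt-4o-mini", n_calls=3)  # per-call latencies in seconds
results, all_latencies = run_latency_test(n_replicates=3)
print(results["gpt-4o-mini"]["mean_latency"])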