Created April 26, 2026 10:57
-
-
Save yangsheng6810/3d81fec651c93df1293bc92e21244e55 to your computer and use it in GitHub Desktop.
Benchmark local LLM
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| #!/usr/bin/env python3 | |
| """ | |
| Benchmark tool for local LLM (OpenAI-compatible API, e.g., Kimi-K2.6) | |
| Measures: | |
| - Time to First Token (TTFT) | |
| - End-to-end latency | |
| - Completion tokens per second (generation speed) | |
| - Throughput (requests per second) | |
| - Success rate | |
| - Percentiles (p50, p95, p99) | |
| Supports chat and code generation tasks. | |
| """ | |
| import asyncio | |
| import aiohttp | |
| import time | |
| import argparse | |
| import json | |
| import sys | |
| from typing import List, Dict, Any, Optional | |
| from statistics import mean, median, stdev | |
| from collections import defaultdict | |
| import tqdm.asyncio | |
| # ---------- Default prompts for different tasks ---------- | |
# Built-in prompts selectable with --mode; the keys must match the --mode choices.
DEFAULT_PROMPTS = {
    "chat_short": "Hello! How are you today?",
    "chat_long": "Explain the theory of relativity in simple terms, including both special and general relativity, with examples.",
    "vibe_coding": "Write a Python function that takes a list of integers and returns a new list containing only the prime numbers, along with a short explanation of the algorithm.",
    # The embedded snippet must be valid, properly indented Python so that the
    # only defect for the model to find is the intended one: factorial(n-2).
    "code_debug": (
        "What's wrong with this Python code? It's supposed to compute factorial:\n\n"
        "def factorial(n):\n"
        "    if n == 0:\n"
        "        return 1\n"
        "    else:\n"
        "        return n * factorial(n-2)\n\n"
        "Please explain the bug and provide the corrected version."
    ),
}
| # ---------- Helper: parse arguments ---------- | |
def parse_args(argv=None):
    """Parse and validate the benchmark's command-line options.

    Args:
        argv: Argument list to parse; ``None`` (the default) lets argparse read
            ``sys.argv[1:]``.

    Returns:
        argparse.Namespace holding the benchmark configuration.

    Exits with status 2 (via ``parser.error``) when a numeric option is out of
    range -- e.g. ``--concurrency 0`` would otherwise create
    ``asyncio.Semaphore(0)`` and hang forever.
    """
    parser = argparse.ArgumentParser(description="Benchmark local LLM performance")
    parser.add_argument("--url", type=str, default="http://localhost:8000/v1/chat/completions",
                        help="OpenAI-compatible chat completions endpoint")
    parser.add_argument("--model", type=str, default="Kimi-K2.6", help="Model name to use")
    parser.add_argument("--api-key", type=str, default="", help="API key if required")
    parser.add_argument("--concurrency", type=int, default=1, help="Number of concurrent requests")
    parser.add_argument("--num-requests", type=int, default=10, help="Total number of requests to send")
    parser.add_argument("--timeout", type=int, default=120, help="Timeout per request in seconds")
    # Streaming is on by default; --stream is kept for explicitness and
    # --no-stream turns it off (both write to the same ``stream`` dest).
    parser.add_argument("--stream", action="store_true", default=True,
                        help="Use streaming to measure TTFT and generation speed (recommended)")
    parser.add_argument("--no-stream", dest="stream", action="store_false", help="Disable streaming")
    parser.add_argument("--no-stream-usage", dest="stream_usage", action="store_false",
                        help="Do not send stream_options.include_usage with streaming requests "
                             "(use if the server rejects it; tokens/s then falls back to an estimate)")
    parser.add_argument("--prompt", type=str, help="Custom prompt (overrides --mode and --prompt-file)")
    parser.add_argument("--prompt-file", type=str, help="File containing a list of prompts (one per line) to cycle through")
    parser.add_argument("--mode", type=str, choices=["chat_short", "chat_long", "vibe_coding", "code_debug"],
                        default="chat_long", help="Predefined task type (ignored if --prompt or --prompt-file given)")
    parser.add_argument("--max-tokens", type=int, default=512, help="Maximum tokens to generate")
    parser.add_argument("--temperature", type=float, default=0.7, help="Sampling temperature")
    parser.add_argument("--warmup-requests", type=int, default=1, help="Number of warmup requests (not counted in results)")
    parser.add_argument("--output-json", type=str, help="Save detailed results to JSON file")
    args = parser.parse_args(argv)

    if args.concurrency < 1:
        parser.error("--concurrency must be >= 1")
    if args.num_requests < 1:
        parser.error("--num-requests must be >= 1")
    if args.warmup_requests < 0:
        parser.error("--warmup-requests must be >= 0")
    if args.timeout <= 0:
        parser.error("--timeout must be > 0")
    if args.max_tokens < 1:
        parser.error("--max-tokens must be >= 1")
    return args
| # ---------- Load prompts from file or use single prompt ---------- | |
def load_prompts(args) -> List[str]:
    """Return the prompts to benchmark with, in priority order.

    ``args.prompt`` (if non-empty) wins, then ``args.prompt_file`` (one prompt
    per line, blank lines skipped, surrounding whitespace stripped), then the
    built-in prompt for ``args.mode``.

    Raises:
        ValueError: If the prompt file has no non-blank lines, or ``args.mode``
            is not a known predefined mode.
    """
    if args.prompt:
        return [args.prompt]
    if args.prompt_file:
        # Read explicitly as UTF-8 so non-ASCII prompts do not depend on the
        # platform's locale encoding.
        with open(args.prompt_file, "r", encoding="utf-8") as f:
            prompts = [line.strip() for line in f if line.strip()]
        if not prompts:
            raise ValueError("Prompt file is empty")
        return prompts
    # Use predefined mode
    if args.mode not in DEFAULT_PROMPTS:
        raise ValueError(f"Unknown mode: {args.mode}")
    return [DEFAULT_PROMPTS[args.mode]]
| # ---------- Async request function (streaming or non-streaming) ---------- | |
_SSE_DONE = object()  # sentinel returned by _parse_sse_line for "data: [DONE]"


def _failure_result(request_id, error, total_time_ms):
    """Return the result dict used for every failed request."""
    return {
        "request_id": request_id,
        "success": False,
        "error": error,
        "total_time_ms": total_time_ms,
        "ttft_ms": None,
        "tokens_per_sec": None,
        "completion_tokens": 0,
    }


def _parse_sse_line(raw_line):
    """Parse one raw server-sent-events line (bytes) from a streaming response.

    Returns the decoded JSON object of a ``data:`` event, ``_SSE_DONE`` for the
    ``data: [DONE]`` terminator, and ``None`` for anything else (comments,
    keep-alives, blank lines, malformed or non-object JSON).
    """
    line = raw_line.decode("utf-8", errors="replace").strip()
    if not line.startswith("data:"):
        return None
    # The single space after "data:" is optional in the SSE wire format.
    data_str = line[len("data:"):].strip()
    if data_str == "[DONE]":
        return _SSE_DONE
    try:
        chunk = json.loads(data_str)
    except json.JSONDecodeError:
        return None
    return chunk if isinstance(chunk, dict) else None


def _delta_text(chunk):
    """Return the text generated in a streaming chunk's first choice, or ''.

    Tolerates ``"content": null`` (common on the role-only first chunk), a
    missing/null ``delta`` and empty ``choices``. Reasoning text is counted as
    generated output as well; the field name differs between servers
    (``reasoning_content`` / ``reasoning``) -- NOTE(review): confirm for yours.
    """
    choices = chunk.get("choices")
    if not isinstance(choices, list) or not choices or not isinstance(choices[0], dict):
        return ""
    delta = choices[0].get("delta")
    if not isinstance(delta, dict):
        return ""
    reasoning = delta.get("reasoning_content") or delta.get("reasoning")
    content = delta.get("content")
    return "".join(part for part in (reasoning, content) if isinstance(part, str))


async def _read_stream(lines):
    """Consume an SSE chat-completion stream.

    Args:
        lines: Async iterable of raw byte lines (e.g. ``resp.content``).

    Returns:
        ``(first_token_time, last_token_time, completion_tokens, text)``. The
        times are ``time.perf_counter()`` readings of the first/last chunk that
        carried generated text (falling back to the first/last JSON chunk when
        no text was seen; ``None`` if no JSON chunk arrived).
        ``completion_tokens`` is taken from the last ``usage`` object seen, or
        0 if the server reported none. ``text`` is the concatenated output.
    """
    first_chunk_time = last_chunk_time = None
    first_token_time = last_token_time = None
    completion_tokens = 0
    pieces = []
    async for raw_line in lines:
        chunk = _parse_sse_line(raw_line)
        if chunk is _SSE_DONE:
            break
        if chunk is None:
            continue
        now = time.perf_counter()
        if first_chunk_time is None:
            first_chunk_time = now
        last_chunk_time = now
        # "usage" can be present but null on every chunk except the last one,
        # so only a dict carries counts.
        usage = chunk.get("usage")
        if isinstance(usage, dict):
            completion_tokens = usage.get("completion_tokens") or 0
        # A chunk may carry both usage and text; handle them independently.
        text = _delta_text(chunk)
        if text:
            pieces.append(text)
            if first_token_time is None:
                first_token_time = now
            last_token_time = now
    if first_token_time is None:
        # No recognisable text (output in fields not parsed here, or none at
        # all): fall back to chunk arrival times.
        first_token_time, last_token_time = first_chunk_time, last_chunk_time
    return first_token_time, last_token_time, completion_tokens, "".join(pieces)


async def send_request(session, url, headers, payload, request_id, sem, stream_mode, timeout):
    """Send one chat-completions request and measure it.

    Args:
        session: aiohttp client session used for the POST.
        url: Chat-completions endpoint.
        headers: HTTP headers to send.
        payload: JSON request body.
        request_id: Identifier copied into the result.
        sem: Semaphore bounding how many requests are in flight at once.
        stream_mode: Parse the body as an SSE stream (True) or one JSON object.
        timeout: Total per-request timeout in seconds.

    Returns:
        A dict with ``request_id``, ``success``, ``error`` (``None`` on
        success), ``total_time_ms``, ``ttft_ms`` (time to first generated text
        when streaming; taken when the full body arrived otherwise),
        ``tokens_per_sec`` (completion tokens over the first-to-last token
        interval, streaming only, else ``None``) and ``completion_tokens``
        (server-reported when available, otherwise estimated at ~4 characters
        per token).

    Request errors (HTTP status, timeout, connection, parsing) are reported in
    the returned dict rather than raised.
    """
    async with sem:
        start_time = time.perf_counter()
        first_token_time = None
        last_token_time = None
        completion_tokens = 0
        error_msg = None
        try:
            async with session.post(
                url,
                json=payload,
                headers=headers,
                timeout=aiohttp.ClientTimeout(total=timeout),
            ) as resp:
                if resp.status != 200:
                    error_msg = f"HTTP {resp.status}: {await resp.text()}"
                    return _failure_result(
                        request_id, error_msg, (time.perf_counter() - start_time) * 1000
                    )
                if stream_mode:
                    (first_token_time, last_token_time,
                     completion_tokens, response_text) = await _read_stream(resp.content)
                    if completion_tokens == 0 and response_text:
                        # No usage reported: rough estimate, ~4 chars per token (English).
                        completion_tokens = len(response_text) // 4
                else:
                    data = await resp.json()
                    usage = data.get("usage")
                    if isinstance(usage, dict):
                        completion_tokens = usage.get("completion_tokens") or 0
                    else:
                        # Fallback: estimate from the returned message text.
                        choices = data.get("choices") or []
                        message = choices[0].get("message") if choices else None
                        content = message.get("content") if isinstance(message, dict) else None
                        if isinstance(content, str):
                            completion_tokens = len(content) // 4
                    # Not streaming: "first token" is when the whole body arrived.
                    first_token_time = time.perf_counter()
                    last_token_time = first_token_time
        except asyncio.TimeoutError:
            error_msg = f"Timeout after {timeout}s"
        except Exception as e:
            # str() of some exceptions is empty; keep the type name so the
            # failure is never mistaken for a success below.
            error_msg = str(e) or type(e).__name__
        total_time_ms = (time.perf_counter() - start_time) * 1000
        if error_msg is not None:
            return _failure_result(request_id, error_msg, total_time_ms)

        ttft_ms = (first_token_time - start_time) * 1000 if first_token_time is not None else None
        tps = None
        if stream_mode and first_token_time is not None and last_token_time is not None:
            generation_time_sec = last_token_time - first_token_time
            if generation_time_sec > 0 and completion_tokens > 0:
                tps = completion_tokens / generation_time_sec
        return {
            "request_id": request_id,
            "success": True,
            "total_time_ms": total_time_ms,
            "ttft_ms": ttft_ms,
            "tokens_per_sec": tps,
            "completion_tokens": completion_tokens,
            "error": None,
        }
| # ---------- Run benchmark ---------- | |
def _build_payload(args, prompt):
    """Return the chat-completions request body for one prompt."""
    payload = {
        "model": args.model,
        "messages": [{"role": "user", "content": prompt}],
        "stream": args.stream,
        "max_tokens": args.max_tokens,
        "temperature": args.temperature,
    }
    if args.stream and getattr(args, "stream_usage", True):
        # OpenAI-compatible servers only report token usage in a stream when
        # asked to; without it tokens/s falls back to a ~4 chars/token estimate.
        payload["stream_options"] = {"include_usage": True}
    return payload


def _compute_stats(results, args, wall_time_sec):
    """Aggregate per-request results into the summary statistics dict.

    ``wall_time_sec`` is the elapsed wall-clock time of the measured phase; it
    (not the sum of per-request latencies) is the denominator for throughput,
    so throughput stays correct when requests overlap under concurrency.
    """
    successful = [r for r in results if r["success"]]
    failed = [r for r in results if not r["success"]]
    total_completion_tokens = sum(r["completion_tokens"] for r in successful)
    total_requests = len(results)
    success_rate = len(successful) / total_requests * 100 if total_requests else 0
    latencies = [r["total_time_ms"] for r in successful]
    ttfts = [r["ttft_ms"] for r in successful if r["ttft_ms"] is not None]
    tps_values = [r["tokens_per_sec"] for r in successful if r["tokens_per_sec"] is not None]
    return {
        "total_requests": total_requests,
        "successful": len(successful),
        "failed": len(failed),
        "success_rate_%": success_rate,
        "concurrency": args.concurrency,
        "streaming": args.stream,
        "total_duration_sec": wall_time_sec,
        "throughput_req_per_sec": len(successful) / wall_time_sec if wall_time_sec > 0 else 0,
        "total_completion_tokens": total_completion_tokens,
        "output_throughput_tok_per_sec": (
            total_completion_tokens / wall_time_sec if wall_time_sec > 0 else 0
        ),
        "avg_total_latency_ms": mean(latencies) if latencies else None,
        "min_total_latency_ms": min(latencies) if latencies else None,
        "max_total_latency_ms": max(latencies) if latencies else None,
        "p50_total_latency_ms": percentile(latencies, 50) if latencies else None,
        "p95_total_latency_ms": percentile(latencies, 95) if latencies else None,
        "p99_total_latency_ms": percentile(latencies, 99) if latencies else None,
        "avg_ttft_ms": mean(ttfts) if ttfts else None,
        "p95_ttft_ms": percentile(ttfts, 95) if ttfts else None,
        "avg_tokens_per_sec": mean(tps_values) if tps_values else None,
        "min_tokens_per_sec": min(tps_values) if tps_values else None,
        "max_tokens_per_sec": max(tps_values) if tps_values else None,
    }


def _print_stats(stats):
    """Print the human-readable benchmark summary to stdout."""
    print("\n" + "=" * 60)
    print("Benchmark Results")
    print("=" * 60)
    print(f"Requests: {stats['total_requests']} (successful: {stats['successful']}, failed: {stats['failed']})")
    print(f"Success rate: {stats['success_rate_%']:.2f}%")
    print(f"Concurrency: {stats['concurrency']}")
    print(f"Total test duration: {stats['total_duration_sec']:.2f} s")
    print(f"Throughput: {stats['throughput_req_per_sec']:.2f} req/s")
    print(f"Output token throughput: {stats['output_throughput_tok_per_sec']:.2f} tok/s")
    print(f"Total completion tokens generated: {stats['total_completion_tokens']}")
    print()
    print("--- Latency (end-to-end) ---")
    if stats['avg_total_latency_ms'] is not None:
        print(f"  Avg: {stats['avg_total_latency_ms']:.2f} ms")
        print(f"  Min: {stats['min_total_latency_ms']:.2f} ms")
        print(f"  Max: {stats['max_total_latency_ms']:.2f} ms")
        print(f"  P50: {stats['p50_total_latency_ms']:.2f} ms")
        print(f"  P95: {stats['p95_total_latency_ms']:.2f} ms")
        print(f"  P99: {stats['p99_total_latency_ms']:.2f} ms")
    else:
        print("  (no successful requests)")
    print()
    if stats['streaming'] and stats['avg_ttft_ms'] is not None:
        print("--- Time to First Token (TTFT) ---")
        print(f"  Avg: {stats['avg_ttft_ms']:.2f} ms")
        print(f"  P95: {stats['p95_ttft_ms']:.2f} ms")
        print()
    if stats['avg_tokens_per_sec'] is not None:
        print("--- Generation Speed (tokens/s) ---")
        print(f"  Avg: {stats['avg_tokens_per_sec']:.2f} tok/s")
        print(f"  Min: {stats['min_tokens_per_sec']:.2f} tok/s")
        print(f"  Max: {stats['max_tokens_per_sec']:.2f} tok/s")
    else:
        print("--- Generation Speed: not available (enable stream and ensure usage field is returned)")
    print("=" * 60)


def _save_json(args, stats, results):
    """Write the configuration, summary stats and per-request results to ``args.output_json``."""
    output = {
        "config": {
            "url": args.url,
            "model": args.model,
            "concurrency": args.concurrency,
            "num_requests": args.num_requests,
            "stream": args.stream,
            "max_tokens": args.max_tokens,
            "temperature": args.temperature,
            "mode": args.mode,
        },
        "stats": stats,
        "details": results,
    }
    with open(args.output_json, "w", encoding="utf-8") as f:
        json.dump(output, f, indent=2)
    print(f"\nDetailed results saved to {args.output_json}")


async def _run_warmup(session, args, headers, sem, prompts):
    """Send ``args.warmup_requests`` unmeasured requests and report any failures."""
    print(f"Warming up with {args.warmup_requests} request(s)...")
    tasks = [
        send_request(session, args.url, headers, _build_payload(args, prompts[i % len(prompts)]),
                     i, sem, args.stream, args.timeout)
        for i in range(args.warmup_requests)
    ]
    warmup_results = await asyncio.gather(*tasks)
    failures = [r for r in warmup_results if not r["success"]]
    if failures:
        # Surface warmup problems (bad URL, auth, model name) before the real run.
        print(f"Warning: {len(failures)}/{len(warmup_results)} warmup request(s) failed; "
              f"first error: {failures[0]['error']}")
    print("Warmup done.")


async def run_benchmark(args):
    """Run warmup plus the measured benchmark, print a summary, optionally save JSON.

    Prompts from ``load_prompts(args)`` are cycled to cover every request.
    ``total_duration_sec`` and both throughput figures use wall-clock time
    from dispatch of the first measured request to completion of the last.
    """
    prompts = load_prompts(args)
    headers = {
        "Content-Type": "application/json",
    }
    if args.api_key:
        headers["Authorization"] = f"Bearer {args.api_key}"
    sem = asyncio.Semaphore(args.concurrency)
    # aiohttp's default connector allows only 100 simultaneous connections,
    # which would silently cap --concurrency above 100; the semaphore already
    # bounds in-flight requests, so lift the connector limit (0 = unlimited).
    connector = aiohttp.TCPConnector(limit=0)
    async with aiohttp.ClientSession(connector=connector) as session:
        if args.warmup_requests > 0:
            await _run_warmup(session, args, headers, sem, prompts)

        print(f"Starting benchmark: {args.num_requests} requests, concurrency={args.concurrency}, stream={args.stream}")
        tasks = [
            send_request(session, args.url, headers, _build_payload(args, prompts[i % len(prompts)]),
                         i, sem, args.stream, args.timeout)
            for i in range(args.num_requests)
        ]
        results = []
        bench_start = time.perf_counter()
        # tqdm shows progress as requests complete (in completion order).
        for coro in tqdm.asyncio.tqdm.as_completed(tasks, desc="Requests", total=args.num_requests):
            results.append(await coro)
        wall_time_sec = time.perf_counter() - bench_start

    # Restore request order so the JSON details are easy to read and diff.
    results.sort(key=lambda r: r["request_id"])
    stats = _compute_stats(results, args, wall_time_sec)
    _print_stats(stats)
    if args.output_json:
        _save_json(args, stats, results)
def percentile(data, p):
    """Return the p-th percentile (0-100) of *data* by linear interpolation.

    The value sits at fractional rank ``(len(data) - 1) * p / 100`` of the
    sorted data, interpolated between the two neighbouring elements.

    Args:
        data: Iterable of numbers (any iterable, including a generator); it is
            not modified.
        p: Percentile in the closed range [0, 100].

    Returns:
        The interpolated value, or ``None`` when *data* is empty.

    Raises:
        ValueError: If *p* is outside [0, 100] (or NaN).
    """
    # Sort first so emptiness is detected for generators too.
    data_sorted = sorted(data)
    if not data_sorted:
        return None
    if not 0 <= p <= 100:
        raise ValueError(f"percentile p must be in [0, 100], got {p!r}")
    rank = (len(data_sorted) - 1) * p / 100.0
    lower = int(rank)
    frac = rank - lower
    if lower + 1 < len(data_sorted):
        return data_sorted[lower] + frac * (data_sorted[lower + 1] - data_sorted[lower])
    return data_sorted[lower]
def main():
    """Command-line entry point: parse options and run the benchmark.

    Ctrl+C ends the run with exit status 130 (128 + SIGINT) and a short
    message instead of an asyncio traceback.
    """
    args = parse_args()
    try:
        asyncio.run(run_benchmark(args))
    except KeyboardInterrupt:
        print("\nInterrupted.", file=sys.stderr)
        sys.exit(130)


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.