#!/usr/bin/env python3
"""
OpenAI Prediction API Benchmark Tool

Benchmarks latency and throughput for the OpenAI Chat Completions API with
prediction (predicted outputs) functionality. Supports custom endpoints
(e.g., localhost:8000) for testing vLLM implementations.
"""
import argparse
import asyncio
import json
import logging
import statistics
import time
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, asdict
from datetime import datetime
from typing import Any, Dict, List

import aiohttp
import requests


@dataclass
class BenchmarkConfig:
    """Configuration for benchmark runs"""
    base_url: str = "http://localhost:8000"
    api_key: str = "test-key"
    model: str = "gpt-3.5-turbo"
    concurrent_requests: int = 1
    total_requests: int = 10
    timeout: int = 60
    use_prediction: bool = True
    prediction_content: str = "The waves crash against the shore"
    log_outputs: bool = False
    log_file: str = "benchmark_outputs.log"
    temperature: float = 0.7
    max_tokens: int = 150


@dataclass
class RequestResult:
    """Results from a single API request"""
    latency: float
    tokens_generated: int
    success: bool
    error_message: str = ""
    prediction_accuracy: float = 0.0
    total_tokens: int = 0
    prompt_tokens: int = 0
    completion_tokens: int = 0
    response_content: str = ""
    request_id: str = ""
    accepted_prediction_tokens: int = 0
    rejected_prediction_tokens: int = 0


@dataclass
class BenchmarkResults:
    """Aggregated benchmark results"""
    total_requests: int
    successful_requests: int
    failed_requests: int
    avg_latency: float
    p50_latency: float
    p95_latency: float
    p99_latency: float
    min_latency: float
    max_latency: float
    total_tokens: int
    avg_tokens_per_second: float
    requests_per_second: float
    total_duration: float
    error_rate: float


class OpenAIPredictionBenchmark:
    """Benchmark tool for OpenAI Prediction API"""

    def __init__(self, config: BenchmarkConfig):
        self.config = config
        self.results: List[RequestResult] = []
        self.request_counter = 0

        # Set up logging if enabled
        if self.config.log_outputs:
            # Create a custom logger for outputs
            self.output_logger = logging.getLogger('benchmark_outputs')
            self.output_logger.setLevel(logging.INFO)
            # Clear any existing handlers to avoid duplicates
            self.output_logger.handlers = []
            # Create file handler
            file_handler = logging.FileHandler(self.config.log_file, mode='w', encoding='utf-8')
            file_handler.setLevel(logging.INFO)
            # Create formatter (without request_id in the format string)
            formatter = logging.Formatter(
                '%(asctime)s - %(levelname)s - %(message)s',
                datefmt='%Y-%m-%d %H:%M:%S'
            )
            file_handler.setFormatter(formatter)
            # Add handler to logger
            self.output_logger.addHandler(file_handler)
            self.output_logger.info("=" * 80)
            self.output_logger.info("BENCHMARK OUTPUT LOGGING STARTED")
            self.output_logger.info(f"Configuration: {asdict(self.config)}")
            self.output_logger.info("=" * 80)
        else:
            self.output_logger = None

    def _log_request_response(self, request_id: int, prompt: str, result: RequestResult, raw_response: Dict[str, Any] = None):
        """Log request and response details if logging is enabled"""
        if not self.output_logger:
            return

        # Log request details
        self.output_logger.info(f"REQUEST #{request_id} DETAILS")
        self.output_logger.info(f"REQUEST #{request_id} - Prompt: {prompt[:200]}{'...' if len(prompt) > 200 else ''}")
        if self.config.use_prediction:
            self.output_logger.info(f"REQUEST #{request_id} - Prediction Content: {self.config.prediction_content[:200]}{'...' if len(self.config.prediction_content) > 200 else ''}")

        # Log response details
        self.output_logger.info(f"RESPONSE #{request_id} DETAILS")
        self.output_logger.info(f"RESPONSE #{request_id} - Success: {result.success}")
        if result.success:
            self.output_logger.info(f"RESPONSE #{request_id} - Latency: {result.latency:.4f}s")
            self.output_logger.info(f"RESPONSE #{request_id} - Tokens - Total: {result.total_tokens}, Prompt: {result.prompt_tokens}, Completion: {result.completion_tokens}")
            if result.accepted_prediction_tokens > 0 or result.rejected_prediction_tokens > 0:
                self.output_logger.info(f"RESPONSE #{request_id} - Prediction Tokens - Accepted: {result.accepted_prediction_tokens}, Rejected: {result.rejected_prediction_tokens}")
            self.output_logger.info(f"RESPONSE #{request_id} - Content: {result.response_content[:500]}{'...' if len(result.response_content) > 500 else ''}")
            # Log raw response if available (for debugging)
            if raw_response:
                raw_json = json.dumps(raw_response, indent=2)
                self.output_logger.info(f"RESPONSE #{request_id} - Raw Response: {raw_json[:1000]}{'...' if len(raw_json) > 1000 else ''}")
        else:
            self.output_logger.info(f"RESPONSE #{request_id} - Error: {result.error_message}")
        self.output_logger.info("-" * 80)

    def _prepare_request_payload(self, prompt: str, use_prediction: bool = True) -> Dict[str, Any]:
        """Prepare the request payload for the API call"""
        payload = {
            "model": self.config.model,
            "messages": [
                {"role": "user", "content": prompt}
            ],
            "max_tokens": self.config.max_tokens,
            "temperature": self.config.temperature
        }
        if use_prediction and self.config.use_prediction:
            payload["prediction"] = {
                "type": "content",
                "content": self.config.prediction_content.strip()
            }
        return payload
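
    # For reference, a payload built with the defaults above looks roughly like
    # the following (the "prediction" block uses the {"type": "content", ...}
    # shape that OpenAI's predicted-outputs feature expects):
    #
    #   {
    #       "model": "gpt-3.5-turbo",
    #       "messages": [{"role": "user", "content": "Write a poem about the ocean."}],
    #       "max_tokens": 150,
    #       "temperature": 0.7,
    #       "prediction": {"type": "content", "content": "The waves crash against the shore"}
    #   }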

    async def _make_async_request(self, session: aiohttp.ClientSession, prompt: str) -> RequestResult:
        """Make an async request to the API"""
        self.request_counter += 1
        request_id = self.request_counter
        start_time = time.time()
        try:
            payload = self._prepare_request_payload(prompt)
            headers = {
                "Authorization": f"Bearer {self.config.api_key}",
                "Content-Type": "application/json"
            }
            async with session.post(
                f"{self.config.base_url}/v1/chat/completions",
                json=payload,
                headers=headers,
                timeout=aiohttp.ClientTimeout(total=self.config.timeout)
            ) as response:
                end_time = time.time()
                latency = end_time - start_time
                if response.status == 200:
                    data = await response.json()
                    usage = data.get("usage", {})
                    completion_tokens_details = usage.get("completion_tokens_details", {})
                    # Extract response content
                    response_content = ""
                    if data.get("choices") and len(data["choices"]) > 0:
                        message = data["choices"][0].get("message", {})
                        response_content = message.get("content", "")
                    result = RequestResult(
                        latency=latency,
                        tokens_generated=usage.get("completion_tokens", 0),
                        success=True,
                        total_tokens=usage.get("total_tokens", 0),
                        prompt_tokens=usage.get("prompt_tokens", 0),
                        completion_tokens=usage.get("completion_tokens", 0),
                        response_content=response_content,
                        request_id=str(request_id),
                        accepted_prediction_tokens=completion_tokens_details.get("accepted_prediction_tokens", 0),
                        rejected_prediction_tokens=completion_tokens_details.get("rejected_prediction_tokens", 0)
                    )
                    # Log the request and response
                    self._log_request_response(request_id, prompt, result, data)
                    return result
                else:
                    error_text = await response.text()
                    result = RequestResult(
                        latency=latency,
                        tokens_generated=0,
                        success=False,
                        error_message=f"HTTP {response.status}: {error_text}",
                        request_id=str(request_id)
                    )
                    # Log the error
                    self._log_request_response(request_id, prompt, result)
                    return result
        except Exception as e:
            end_time = time.time()
            result = RequestResult(
                latency=end_time - start_time,
                tokens_generated=0,
                success=False,
                error_message=str(e),
                request_id=str(request_id)
            )
            # Log the exception
            self._log_request_response(request_id, prompt, result)
            return result

    def _make_sync_request(self, prompt: str) -> RequestResult:
        """Make a synchronous request to the API"""
        self.request_counter += 1
        request_id = self.request_counter
        start_time = time.time()
        try:
            payload = self._prepare_request_payload(prompt)
            headers = {
                "Authorization": f"Bearer {self.config.api_key}",
                "Content-Type": "application/json"
            }
            response = requests.post(
                f"{self.config.base_url}/v1/chat/completions",
                json=payload,
                headers=headers,
                timeout=self.config.timeout
            )
            end_time = time.time()
            latency = end_time - start_time
            if response.status_code == 200:
                data = response.json()
                usage = data.get("usage", {})
                completion_tokens_details = usage.get("completion_tokens_details", {})
                # Extract response content
                response_content = ""
                if data.get("choices") and len(data["choices"]) > 0:
                    message = data["choices"][0].get("message", {})
                    response_content = message.get("content", "")
                result = RequestResult(
                    latency=latency,
                    tokens_generated=usage.get("completion_tokens", 0),
                    success=True,
                    total_tokens=usage.get("total_tokens", 0),
                    prompt_tokens=usage.get("prompt_tokens", 0),
                    completion_tokens=usage.get("completion_tokens", 0),
                    response_content=response_content,
                    request_id=str(request_id),
                    accepted_prediction_tokens=completion_tokens_details.get("accepted_prediction_tokens", 0),
                    rejected_prediction_tokens=completion_tokens_details.get("rejected_prediction_tokens", 0)
                )
                # Log the request and response
                self._log_request_response(request_id, prompt, result, data)
                return result
            else:
                result = RequestResult(
                    latency=latency,
                    tokens_generated=0,
                    success=False,
                    error_message=f"HTTP {response.status_code}: {response.text}",
                    request_id=str(request_id)
                )
                # Log the error
                self._log_request_response(request_id, prompt, result)
                return result
        except Exception as e:
            end_time = time.time()
            result = RequestResult(
                latency=end_time - start_time,
                tokens_generated=0,
                success=False,
                error_message=str(e),
                request_id=str(request_id)
            )
            # Log the exception
            self._log_request_response(request_id, prompt, result)
            return result

    async def run_async_benchmark(self, prompts: List[str]) -> BenchmarkResults:
        """Run async benchmark with concurrent requests"""
        print(f"Running async benchmark with {self.config.concurrent_requests} concurrent requests...")
        start_time = time.time()

        connector = aiohttp.TCPConnector(limit=self.config.concurrent_requests * 2)
        async with aiohttp.ClientSession(connector=connector) as session:
            semaphore = asyncio.Semaphore(self.config.concurrent_requests)

            async def bounded_request(prompt):
                async with semaphore:
                    return await self._make_async_request(session, prompt)

            tasks = [bounded_request(prompt) for prompt in prompts]
            results = await asyncio.gather(*tasks)

        end_time = time.time()
        total_duration = end_time - start_time
        return self._calculate_metrics(results, total_duration)

    def run_sync_benchmark(self, prompts: List[str]) -> BenchmarkResults:
        """Run synchronous benchmark"""
        print(f"Running sync benchmark with {self.config.concurrent_requests} concurrent requests...")
        start_time = time.time()

        if self.config.concurrent_requests == 1:
            results = [self._make_sync_request(prompt) for prompt in prompts]
        else:
            with ThreadPoolExecutor(max_workers=self.config.concurrent_requests) as executor:
                results = list(executor.map(self._make_sync_request, prompts))

        end_time = time.time()
        total_duration = end_time - start_time
        return self._calculate_metrics(results, total_duration)

    def _calculate_metrics(self, results: List[RequestResult], total_duration: float) -> BenchmarkResults:
        """Calculate benchmark metrics from results"""
        # Store results for error reporting
        self.results = results

        successful_results = [r for r in results if r.success]
        failed_results = [r for r in results if not r.success]

        if not successful_results:
            latencies = [r.latency for r in results]
        else:
            latencies = [r.latency for r in successful_results]

        total_tokens = sum(r.total_tokens for r in successful_results)
        total_completion_tokens = sum(r.completion_tokens for r in successful_results)

        return BenchmarkResults(
            total_requests=len(results),
            successful_requests=len(successful_results),
            failed_requests=len(failed_results),
            avg_latency=statistics.mean(latencies) if latencies else 0,
            p50_latency=statistics.median(latencies) if latencies else 0,
            p95_latency=self._percentile(latencies, 0.95) if latencies else 0,
            p99_latency=self._percentile(latencies, 0.99) if latencies else 0,
            min_latency=min(latencies) if latencies else 0,
            max_latency=max(latencies) if latencies else 0,
            total_tokens=total_tokens,
            avg_tokens_per_second=total_completion_tokens / total_duration if total_duration > 0 else 0,
            requests_per_second=len(successful_results) / total_duration if total_duration > 0 else 0,
            total_duration=total_duration,
            error_rate=len(failed_results) / len(results) if results else 0
        )

    @staticmethod
    def _percentile(data: List[float], percentile: float) -> float:
        """Calculate percentile from data"""
        if not data:
            return 0
        sorted_data = sorted(data)
        index = int(percentile * len(sorted_data))
        if index >= len(sorted_data):
            return sorted_data[-1]
        return sorted_data[index]
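
    # Worked example of the nearest-rank scheme above: with ten latencies,
    # _percentile(data, 0.95) uses index int(0.95 * 10) == 9, i.e. the largest
    # value, while _percentile(data, 0.50) uses index 5, the upper of the two
    # middle elements (slightly different from statistics.median used for P50).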

    def print_results(self, results: BenchmarkResults, show_errors: bool = False):
        """Print benchmark results in a formatted way"""
        print("\n" + "=" * 60)
        print("BENCHMARK RESULTS")
        print("=" * 60)
        print(f"Timestamp: {datetime.now().isoformat()}")
        print("Configuration:")
        print(f" Endpoint: {self.config.base_url}")
        print(f" Model: {self.config.model}")
        print(f" Concurrent Requests: {self.config.concurrent_requests}")
        print(f" Total Requests: {self.config.total_requests}")
        print(f" Use Prediction: {self.config.use_prediction}")
        print(f" Temperature: {self.config.temperature}")
        print(f" Max Tokens: {self.config.max_tokens}")

        print("\nRequest Statistics:")
        print(f" Total Requests: {results.total_requests}")
        print(f" Successful: {results.successful_requests}")
        print(f" Failed: {results.failed_requests}")
        print(f" Error Rate: {results.error_rate:.2%}")

        # Show error details if requested and there are failures
        if show_errors and results.failed_requests > 0:
            failed_results = [r for r in self.results if not r.success]
            print("\nError Details:")
            error_counts = {}
            for result in failed_results:
                error_counts[result.error_message] = error_counts.get(result.error_message, 0) + 1
            for error_msg, count in error_counts.items():
                print(f" [{count}x] {error_msg}")

        print("\nLatency Metrics (seconds):")
        print(f" Average: {results.avg_latency:.4f}")
        print(f" Median (P50): {results.p50_latency:.4f}")
        print(f" P95: {results.p95_latency:.4f}")
        print(f" P99: {results.p99_latency:.4f}")
        print(f" Min: {results.min_latency:.4f}")
        print(f" Max: {results.max_latency:.4f}")

        print("\nThroughput Metrics:")
        print(f" Requests per Second: {results.requests_per_second:.2f}")
        print(f" Tokens per Second: {results.avg_tokens_per_second:.2f}")
        print(f" Total Tokens Generated: {results.total_tokens}")
        print(f" Total Duration: {results.total_duration:.2f}s")

        # Show prediction token stats if available
        if self.results:
            total_accepted = sum(r.accepted_prediction_tokens for r in self.results if r.success)
            total_rejected = sum(r.rejected_prediction_tokens for r in self.results if r.success)
            if total_accepted > 0 or total_rejected > 0:
                print(f" Prediction Tokens - Accepted: {total_accepted}, Rejected: {total_rejected}")
                if total_accepted + total_rejected > 0:
                    acceptance_rate = total_accepted / (total_accepted + total_rejected)
                    print(f" Prediction Acceptance Rate: {acceptance_rate:.2%}")

        print("=" * 60)

        # Log completion if logging enabled
        if self.config.log_outputs:
            self.output_logger.info("=" * 80)
            self.output_logger.info("BENCHMARK COMPLETED")
            self.output_logger.info(f"Results summary: {asdict(results)}")
            self.output_logger.info("=" * 80)

    def save_results(self, results: BenchmarkResults, filename: str):
        """Save results to JSON file"""
        output = {
            "config": asdict(self.config),
            "results": asdict(results),
            "timestamp": datetime.now().isoformat()
        }
        with open(filename, 'w') as f:
            json.dump(output, f, indent=2)
        print(f"Results saved to {filename}")


def generate_test_prompts(count: int) -> List[str]:
    """Generate test prompts for benchmarking"""
    base_prompts = [
        "Write a poem about the ocean.",
        "Write a short story about a robot learning to paint.",
        "Explain the concept of machine learning in simple terms.",
        "What are the benefits and drawbacks of renewable energy?",
        "Describe the process of photosynthesis.",
        "Explain how blockchain technology works.",
        "What are the main causes of climate change?",
        "Describe the history of artificial intelligence.",
        "Write a recipe for chocolate chip cookies.",
        "Explain the theory of relativity."
    ]
    prompts = []
    for i in range(count):
        base_prompt = base_prompts[i % len(base_prompts)]
        prompts.append(f"{base_prompt} (Request #{i+1})")
    return prompts


async def main():
    parser = argparse.ArgumentParser(description="Benchmark OpenAI Prediction API")
    parser.add_argument("--url", default="http://localhost:8000", help="Base URL for the API")
    parser.add_argument("--api-key", default="test-key", help="API key")
    parser.add_argument("--model", default="gpt-3.5-turbo", help="Model name")
    parser.add_argument("--concurrent", type=int, default=1, help="Number of concurrent requests")
    parser.add_argument("--requests", type=int, default=10, help="Total number of requests")
    parser.add_argument("--timeout", type=int, default=60, help="Request timeout in seconds")
    parser.add_argument("--no-prediction", action="store_true", help="Disable prediction feature")
    parser.add_argument("--prediction-content",
                        default="The waves crash against the shore,\nA symphony of sound and motion.\nThe salty breeze caresses my face,\nA reminder of the endless sea.\n\nThe ocean's vast expanse,\nA canvas of endless blue.\nThe horizon stretching out,\nA promise of adventure yet to come.\n\nThe sun sets, casting a golden glow,\nA reminder of the beauty that lies.\nThe moon rises, a beacon of light,\nA reminder of the endless night.\n\nThe ocean's power,\nA force to be reckoned with.\nThe tides, a constant dance,\nA testament to the ocean",
                        help="Content for prediction")
    parser.add_argument("--sync", action="store_true", help="Use synchronous requests instead of async")
    parser.add_argument("--profile", action="store_true", help="Profile the run (saves on server side)")
    parser.add_argument("--output", help="Output file for results (JSON format)")
    parser.add_argument("--show-errors", action="store_true", help="Show detailed error messages for failures")
    parser.add_argument("--verbose", "-v", action="store_true", help="Enable verbose output (alias for --show-errors)")
    parser.add_argument("--log-outputs", action="store_true", help="Log all request/response details to file")
    parser.add_argument("--log-file", default="benchmark_outputs.log", help="File to log outputs to (default: benchmark_outputs.log)")
    parser.add_argument("--temperature", type=float, default=0.7, help="Temperature for text generation (default: 0.7)")
    parser.add_argument("--max-tokens", type=int, default=150, help="Maximum tokens to generate per request (default: 150)")
    args = parser.parse_args()

    print(f"Using temperature: {args.temperature}")
    print(f"Using max_tokens: {args.max_tokens}")

    config = BenchmarkConfig(
        base_url=args.url,
        api_key=args.api_key,
        model=args.model,
        concurrent_requests=args.concurrent,
        total_requests=args.requests,
        timeout=args.timeout,
        use_prediction=not args.no_prediction,
        prediction_content=args.prediction_content,
        log_outputs=args.log_outputs,
        log_file=args.log_file,
        temperature=args.temperature,
        max_tokens=args.max_tokens
    )

    benchmark = OpenAIPredictionBenchmark(config)
    prompts = generate_test_prompts(config.total_requests)

    if args.profile:
        requests.post(f"{args.url}/start_profile")

    if args.sync:
        results = benchmark.run_sync_benchmark(prompts)
    else:
        results = await benchmark.run_async_benchmark(prompts)

    benchmark.print_results(results, show_errors=args.show_errors or args.verbose)

    if args.output:
        benchmark.save_results(results, args.output)

    if args.profile:
        requests.post(f"{args.url}/stop_profile")

    if args.log_outputs:
        print(f"\nDetailed request/response logs saved to: {args.log_file}")


if __name__ == "__main__":
    asyncio.run(main())