#!/usr/bin/env python3
"""
OpenAI Prediction API Benchmark Tool
Benchmarks latency and throughput for the OpenAI Chat Completions API with prediction (predicted outputs) functionality.
Supports custom endpoints (e.g., localhost:8000) for testing vLLM implementations.
"""
import asyncio
import time
import statistics
import json
import argparse
from typing import List, Dict, Any, Tuple
from dataclasses import dataclass, asdict
from concurrent.futures import ThreadPoolExecutor
import aiohttp
import requests
from datetime import datetime


@dataclass
class BenchmarkConfig:
    """Configuration for benchmark runs"""
    base_url: str = "http://localhost:8000"
    api_key: str = "test-key"
    model: str = "gpt-3.5-turbo"
    concurrent_requests: int = 1
    total_requests: int = 10
    timeout: int = 60
    use_prediction: bool = True
    prediction_content: str = "The waves crash against the shore"
    log_outputs: bool = False
    log_file: str = "benchmark_outputs.log"
    temperature: float = 0.7
    max_tokens: int = 150


@dataclass
class RequestResult:
    """Results from a single API request"""
    latency: float
    tokens_generated: int
    success: bool
    error_message: str = ""
    prediction_accuracy: float = 0.0
    total_tokens: int = 0
    prompt_tokens: int = 0
    completion_tokens: int = 0
    response_content: str = ""
    request_id: str = ""
    accepted_prediction_tokens: int = 0
    rejected_prediction_tokens: int = 0


@dataclass
class BenchmarkResults:
    """Aggregated benchmark results"""
    total_requests: int
    successful_requests: int
    failed_requests: int
    avg_latency: float
    p50_latency: float
    p95_latency: float
    p99_latency: float
    min_latency: float
    max_latency: float
    total_tokens: int
    avg_tokens_per_second: float
    requests_per_second: float
    total_duration: float
    error_rate: float


class OpenAIPredictionBenchmark:
    """Benchmark tool for OpenAI Prediction API"""

    def __init__(self, config: BenchmarkConfig):
        self.config = config
        self.results: List[RequestResult] = []
        self.request_counter = 0
        # Set up logging if enabled
        if self.config.log_outputs:
            import logging
            # Create a custom logger for outputs
            self.output_logger = logging.getLogger('benchmark_outputs')
            self.output_logger.setLevel(logging.INFO)
            # Clear any existing handlers to avoid duplicates
            self.output_logger.handlers = []
            # Create file handler
            file_handler = logging.FileHandler(self.config.log_file, mode='w', encoding='utf-8')
            file_handler.setLevel(logging.INFO)
            # Create formatter (without request_id in the format string)
            formatter = logging.Formatter(
                '%(asctime)s - %(levelname)s - %(message)s',
                datefmt='%Y-%m-%d %H:%M:%S'
            )
            file_handler.setFormatter(formatter)
            # Add handler to logger
            self.output_logger.addHandler(file_handler)
            self.output_logger.info("="*80)
            self.output_logger.info("BENCHMARK OUTPUT LOGGING STARTED")
            self.output_logger.info(f"Configuration: {asdict(self.config)}")
            self.output_logger.info("="*80)
        else:
            self.output_logger = None

    def _log_request_response(self, request_id: int, prompt: str, result: RequestResult, raw_response: Dict[str, Any] = None):
        """Log request and response details if logging is enabled"""
        if not self.output_logger:
            return
        # Log request details
        self.output_logger.info(f"REQUEST #{request_id} DETAILS")
        self.output_logger.info(f"REQUEST #{request_id} - Prompt: {prompt[:200]}{'...' if len(prompt) > 200 else ''}")
        if self.config.use_prediction:
            self.output_logger.info(f"REQUEST #{request_id} - Prediction Content: {self.config.prediction_content[:200]}{'...' if len(self.config.prediction_content) > 200 else ''}")
        # Log response details
        self.output_logger.info(f"RESPONSE #{request_id} DETAILS")
        self.output_logger.info(f"RESPONSE #{request_id} - Success: {result.success}")
        if result.success:
            self.output_logger.info(f"RESPONSE #{request_id} - Latency: {result.latency:.4f}s")
            self.output_logger.info(f"RESPONSE #{request_id} - Tokens - Total: {result.total_tokens}, Prompt: {result.prompt_tokens}, Completion: {result.completion_tokens}")
            if result.accepted_prediction_tokens > 0 or result.rejected_prediction_tokens > 0:
                self.output_logger.info(f"RESPONSE #{request_id} - Prediction Tokens - Accepted: {result.accepted_prediction_tokens}, Rejected: {result.rejected_prediction_tokens}")
            self.output_logger.info(f"RESPONSE #{request_id} - Content: {result.response_content[:500]}{'...' if len(result.response_content) > 500 else ''}")
            # Log raw response if available (for debugging)
            if raw_response:
                self.output_logger.info(f"RESPONSE #{request_id} - Raw Response: {json.dumps(raw_response, indent=2)[:1000]}{'...' if len(json.dumps(raw_response, indent=2)) > 1000 else ''}")
        else:
            self.output_logger.info(f"RESPONSE #{request_id} - Error: {result.error_message}")
        self.output_logger.info("-" * 80)

    def _prepare_request_payload(self, prompt: str, use_prediction: bool = True) -> Dict[str, Any]:
        """Prepare the request payload for the API call"""
        payload = {
            "model": self.config.model,
            "messages": [
                {"role": "user", "content": prompt}
            ],
            "max_tokens": self.config.max_tokens,
            "temperature": self.config.temperature
        }
        if use_prediction and self.config.use_prediction:
            payload["prediction"] = {
                "type": "content",
                "content": self.config.prediction_content.strip()
            }
        return payload
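
    # The "prediction" field above follows the OpenAI Predicted Outputs shape
    # ({"type": "content", "content": "..."}). A sketch of the resulting payload,
    # assuming the defaults in BenchmarkConfig:
    #
    #   {
    #     "model": "gpt-3.5-turbo",
    #     "messages": [{"role": "user", "content": "Write a poem about the ocean."}],
    #     "max_tokens": 150,
    #     "temperature": 0.7,
    #     "prediction": {"type": "content", "content": "The waves crash against the shore"}
    #   }
    #
    # Servers that support the feature report accepted/rejected prediction tokens in
    # usage.completion_tokens_details; whether a given vLLM deployment does is implementation-dependent.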

    async def _make_async_request(self, session: aiohttp.ClientSession, prompt: str) -> RequestResult:
        """Make an async request to the API"""
        self.request_counter += 1
        request_id = self.request_counter
        start_time = time.time()
        try:
            payload = self._prepare_request_payload(prompt)
            headers = {
                "Authorization": f"Bearer {self.config.api_key}",
                "Content-Type": "application/json"
            }
            async with session.post(
                f"{self.config.base_url}/v1/chat/completions",
                json=payload,
                headers=headers,
                timeout=aiohttp.ClientTimeout(total=self.config.timeout)
            ) as response:
                end_time = time.time()
                latency = end_time - start_time
                if response.status == 200:
                    data = await response.json()
                    usage = data.get("usage", {})
                    completion_tokens_details = usage.get("completion_tokens_details", {})
                    # Extract response content
                    response_content = ""
                    if data.get("choices") and len(data["choices"]) > 0:
                        message = data["choices"][0].get("message", {})
                        response_content = message.get("content", "")
                    result = RequestResult(
                        latency=latency,
                        tokens_generated=usage.get("completion_tokens", 0),
                        success=True,
                        total_tokens=usage.get("total_tokens", 0),
                        prompt_tokens=usage.get("prompt_tokens", 0),
                        completion_tokens=usage.get("completion_tokens", 0),
                        response_content=response_content,
                        request_id=str(request_id),
                        accepted_prediction_tokens=completion_tokens_details.get("accepted_prediction_tokens", 0),
                        rejected_prediction_tokens=completion_tokens_details.get("rejected_prediction_tokens", 0)
                    )
                    # Log the request and response
                    self._log_request_response(request_id, prompt, result, data)
                    return result
                else:
                    error_text = await response.text()
                    result = RequestResult(
                        latency=latency,
                        tokens_generated=0,
                        success=False,
                        error_message=f"HTTP {response.status}: {error_text}",
                        request_id=str(request_id)
                    )
                    # Log the error
                    self._log_request_response(request_id, prompt, result)
                    return result
        except Exception as e:
            end_time = time.time()
            result = RequestResult(
                latency=end_time - start_time,
                tokens_generated=0,
                success=False,
                error_message=str(e),
                request_id=str(request_id)
            )
            # Log the exception
            self._log_request_response(request_id, prompt, result)
            return result

    def _make_sync_request(self, prompt: str) -> RequestResult:
        """Make a synchronous request to the API"""
        self.request_counter += 1
        request_id = self.request_counter
        start_time = time.time()
        try:
            payload = self._prepare_request_payload(prompt)
            headers = {
                "Authorization": f"Bearer {self.config.api_key}",
                "Content-Type": "application/json"
            }
            response = requests.post(
                f"{self.config.base_url}/v1/chat/completions",
                json=payload,
                headers=headers,
                timeout=self.config.timeout
            )
            end_time = time.time()
            latency = end_time - start_time
            if response.status_code == 200:
                data = response.json()
                usage = data.get("usage", {})
                completion_tokens_details = usage.get("completion_tokens_details", {})
                # Extract response content
                response_content = ""
                if data.get("choices") and len(data["choices"]) > 0:
                    message = data["choices"][0].get("message", {})
                    response_content = message.get("content", "")
                result = RequestResult(
                    latency=latency,
                    tokens_generated=usage.get("completion_tokens", 0),
                    success=True,
                    total_tokens=usage.get("total_tokens", 0),
                    prompt_tokens=usage.get("prompt_tokens", 0),
                    completion_tokens=usage.get("completion_tokens", 0),
                    response_content=response_content,
                    request_id=str(request_id),
                    accepted_prediction_tokens=completion_tokens_details.get("accepted_prediction_tokens", 0),
                    rejected_prediction_tokens=completion_tokens_details.get("rejected_prediction_tokens", 0)
                )
                # Log the request and response
                self._log_request_response(request_id, prompt, result, data)
                return result
            else:
                result = RequestResult(
                    latency=latency,
                    tokens_generated=0,
                    success=False,
                    error_message=f"HTTP {response.status_code}: {response.text}",
                    request_id=str(request_id)
                )
                # Log the error
                self._log_request_response(request_id, prompt, result)
                return result
        except Exception as e:
            end_time = time.time()
            result = RequestResult(
                latency=end_time - start_time,
                tokens_generated=0,
                success=False,
                error_message=str(e),
                request_id=str(request_id)
            )
            # Log the exception
            self._log_request_response(request_id, prompt, result)
            return result

    async def run_async_benchmark(self, prompts: List[str]) -> BenchmarkResults:
        """Run async benchmark with concurrent requests"""
        print(f"Running async benchmark with {self.config.concurrent_requests} concurrent requests...")
        start_time = time.time()
        connector = aiohttp.TCPConnector(limit=self.config.concurrent_requests * 2)
        async with aiohttp.ClientSession(connector=connector) as session:
            semaphore = asyncio.Semaphore(self.config.concurrent_requests)

            async def bounded_request(prompt):
                async with semaphore:
                    return await self._make_async_request(session, prompt)

            tasks = [bounded_request(prompt) for prompt in prompts]
            results = await asyncio.gather(*tasks)
        end_time = time.time()
        total_duration = end_time - start_time
        return self._calculate_metrics(results, total_duration)

    def run_sync_benchmark(self, prompts: List[str]) -> BenchmarkResults:
        """Run synchronous benchmark"""
        print(f"Running sync benchmark with {self.config.concurrent_requests} concurrent requests...")
        start_time = time.time()
        if self.config.concurrent_requests == 1:
            results = [self._make_sync_request(prompt) for prompt in prompts]
        else:
            with ThreadPoolExecutor(max_workers=self.config.concurrent_requests) as executor:
                results = list(executor.map(self._make_sync_request, prompts))
        end_time = time.time()
        total_duration = end_time - start_time
        return self._calculate_metrics(results, total_duration)

    def _calculate_metrics(self, results: List[RequestResult], total_duration: float) -> BenchmarkResults:
        """Calculate benchmark metrics from results"""
        # Store results for error reporting
        self.results = results
        successful_results = [r for r in results if r.success]
        failed_results = [r for r in results if not r.success]
        if not successful_results:
            latencies = [r.latency for r in results]
        else:
            latencies = [r.latency for r in successful_results]
        total_tokens = sum(r.total_tokens for r in successful_results)
        total_completion_tokens = sum(r.completion_tokens for r in successful_results)
        return BenchmarkResults(
            total_requests=len(results),
            successful_requests=len(successful_results),
            failed_requests=len(failed_results),
            avg_latency=statistics.mean(latencies) if latencies else 0,
            p50_latency=statistics.median(latencies) if latencies else 0,
            p95_latency=self._percentile(latencies, 0.95) if latencies else 0,
            p99_latency=self._percentile(latencies, 0.99) if latencies else 0,
            min_latency=min(latencies) if latencies else 0,
            max_latency=max(latencies) if latencies else 0,
            total_tokens=total_tokens,
            avg_tokens_per_second=total_completion_tokens / total_duration if total_duration > 0 else 0,
            requests_per_second=len(successful_results) / total_duration if total_duration > 0 else 0,
            total_duration=total_duration,
            error_rate=len(failed_results) / len(results) if results else 0
        )

    @staticmethod
    def _percentile(data: List[float], percentile: float) -> float:
        """Calculate percentile from data"""
        if not data:
            return 0
        sorted_data = sorted(data)
        index = int(percentile * len(sorted_data))
        if index >= len(sorted_data):
            return sorted_data[-1]
        return sorted_data[index]
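
    # Note: _percentile is a simple nearest-rank style estimate. With small sample
    # counts (e.g. the default 10 requests, where int(0.95 * 10) == int(0.99 * 10) == 9),
    # P95 and P99 both resolve to the slowest observed request.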

    def print_results(self, results: BenchmarkResults, show_errors: bool = False):
        """Print benchmark results in a formatted way"""
        print("\n" + "="*60)
        print("BENCHMARK RESULTS")
        print("="*60)
        print(f"Timestamp: {datetime.now().isoformat()}")
        print(f"Configuration:")
        print(f" Endpoint: {self.config.base_url}")
        print(f" Model: {self.config.model}")
        print(f" Concurrent Requests: {self.config.concurrent_requests}")
        print(f" Total Requests: {self.config.total_requests}")
        print(f" Use Prediction: {self.config.use_prediction}")
        print(f" Temperature: {self.config.temperature}")
        print(f" Max Tokens: {self.config.max_tokens}")
        print(f"\nRequest Statistics:")
        print(f" Total Requests: {results.total_requests}")
        print(f" Successful: {results.successful_requests}")
        print(f" Failed: {results.failed_requests}")
        print(f" Error Rate: {results.error_rate:.2%}")
        # Show error details if requested and there are failures
        if show_errors and results.failed_requests > 0:
            failed_results = [r for r in self.results if not r.success]
            print(f"\nError Details:")
            error_counts = {}
            for result in failed_results:
                error_msg = result.error_message
                if error_msg in error_counts:
                    error_counts[error_msg] += 1
                else:
                    error_counts[error_msg] = 1
            for error_msg, count in error_counts.items():
                print(f" [{count}x] {error_msg}")
        print(f"\nLatency Metrics (seconds):")
        print(f" Average: {results.avg_latency:.4f}")
        print(f" Median (P50): {results.p50_latency:.4f}")
        print(f" P95: {results.p95_latency:.4f}")
        print(f" P99: {results.p99_latency:.4f}")
        print(f" Min: {results.min_latency:.4f}")
        print(f" Max: {results.max_latency:.4f}")
        print(f"\nThroughput Metrics:")
        print(f" Requests per Second: {results.requests_per_second:.2f}")
        print(f" Tokens per Second: {results.avg_tokens_per_second:.2f}")
        print(f" Total Tokens Generated: {results.total_tokens}")
        print(f" Total Duration: {results.total_duration:.2f}s")
        # Show prediction token stats if available
        if self.results:
            total_accepted = sum(r.accepted_prediction_tokens for r in self.results if r.success)
            total_rejected = sum(r.rejected_prediction_tokens for r in self.results if r.success)
            if total_accepted > 0 or total_rejected > 0:
                print(f" Prediction Tokens - Accepted: {total_accepted}, Rejected: {total_rejected}")
                acceptance_rate = total_accepted / (total_accepted + total_rejected)
                print(f" Prediction Acceptance Rate: {acceptance_rate:.2%}")
        print("="*60)
        # Log completion if logging enabled
        if self.config.log_outputs:
            self.output_logger.info("="*80)
            self.output_logger.info("BENCHMARK COMPLETED")
            self.output_logger.info(f"Results summary: {asdict(results)}")
            self.output_logger.info("="*80)

    def save_results(self, results: BenchmarkResults, filename: str):
        """Save results to JSON file"""
        output = {
            "config": asdict(self.config),
            "results": asdict(results),
            "timestamp": datetime.now().isoformat()
        }
        with open(filename, 'w') as f:
            json.dump(output, f, indent=2)
        print(f"Results saved to {filename}")


def generate_test_prompts(count: int) -> List[str]:
    """Generate test prompts for benchmarking"""
    base_prompts = [
        "Write a poem about the ocean.",
        "Write a short story about a robot learning to paint.",
        "Explain the concept of machine learning in simple terms.",
        "What are the benefits and drawbacks of renewable energy?",
        "Describe the process of photosynthesis.",
        "Explain how blockchain technology works.",
        "What are the main causes of climate change?",
        "Describe the history of artificial intelligence.",
        "Write a recipe for chocolate chip cookies.",
        "Explain the theory of relativity."
    ]
    prompts = []
    for i in range(count):
        base_prompt = base_prompts[i % len(base_prompts)]
        prompts.append(f"{base_prompt} (Request #{i+1})")
    return prompts
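
# Example programmatic use (a sketch; assumes an OpenAI-compatible server is
# already running at the configured base_url):
#
#   config = BenchmarkConfig(base_url="http://localhost:8000", total_requests=5)
#   bench = OpenAIPredictionBenchmark(config)
#   prompts = generate_test_prompts(config.total_requests)
#   results = asyncio.run(bench.run_async_benchmark(prompts))
#   bench.print_results(results)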


async def main():
    parser = argparse.ArgumentParser(description="Benchmark OpenAI Prediction API")
    parser.add_argument("--url", default="http://localhost:8000", help="Base URL for the API")
    parser.add_argument("--api-key", default="test-key", help="API key")
    parser.add_argument("--model", default="gpt-3.5-turbo", help="Model name")
    parser.add_argument("--concurrent", type=int, default=1, help="Number of concurrent requests")
    parser.add_argument("--requests", type=int, default=10, help="Total number of requests")
    parser.add_argument("--timeout", type=int, default=60, help="Request timeout in seconds")
    parser.add_argument("--no-prediction", action="store_true", help="Disable prediction feature")
    parser.add_argument("--prediction-content", default="The waves crash against the shore,\nA symphony of sound and motion.\nThe salty breeze caresses my face,\nA reminder of the endless sea.\n\nThe ocean's vast expanse,\nA canvas of endless blue.\nThe horizon stretching out,\nA promise of adventure yet to come.\n\nThe sun sets, casting a golden glow,\nA reminder of the beauty that lies.\nThe moon rises, a beacon of light,\nA reminder of the endless night.\n\nThe ocean's power,\nA force to be reckoned with.\nThe tides, a constant dance,\nA testament to the ocean",
                        help="Content for prediction")
    parser.add_argument("--sync", action="store_true", help="Use synchronous requests instead of async")
    parser.add_argument("--profile", action="store_true", help="Profile the run (saves on server side)")
    parser.add_argument("--output", help="Output file for results (JSON format)")
    parser.add_argument("--show-errors", action="store_true", help="Show detailed error messages for failures")
    parser.add_argument("--verbose", "-v", action="store_true", help="Enable verbose output (alias for --show-errors)")
    parser.add_argument("--log-outputs", action="store_true", help="Log all request/response details to file")
    parser.add_argument("--log-file", default="benchmark_outputs.log", help="File to log outputs to (default: benchmark_outputs.log)")
    parser.add_argument("--temperature", type=float, default=0.7, help="Temperature for text generation (default: 0.7)")
    parser.add_argument("--max-tokens", type=int, default=150, help="Maximum tokens to generate per request (default: 150)")
    args = parser.parse_args()
    print(f"Using temperature: {args.temperature}")
    print(f"Using max_tokens: {args.max_tokens}")
    config = BenchmarkConfig(
        base_url=args.url,
        api_key=args.api_key,
        model=args.model,
        concurrent_requests=args.concurrent,
        total_requests=args.requests,
        timeout=args.timeout,
        use_prediction=not args.no_prediction,
        prediction_content=args.prediction_content,
        log_outputs=args.log_outputs,
        log_file=args.log_file,
        temperature=args.temperature,
        max_tokens=args.max_tokens
    )
    benchmark = OpenAIPredictionBenchmark(config)
    prompts = generate_test_prompts(config.total_requests)
    if args.profile:
        # Ask the target server to start profiling (e.g. vLLM's /start_profile endpoint)
        requests.post(f"{args.url}/start_profile")
    if args.sync:
        results = benchmark.run_sync_benchmark(prompts)
    else:
        results = await benchmark.run_async_benchmark(prompts)
    benchmark.print_results(results, show_errors=args.show_errors or args.verbose)
    if args.output:
        benchmark.save_results(results, args.output)
    if args.profile:
        requests.post(f"{args.url}/stop_profile")
    if args.log_outputs:
        print(f"\nDetailed request/response logs saved to: {args.log_file}")


if __name__ == "__main__":
    asyncio.run(main())