Speed test of OpenAI models
import time
import os
from openai import OpenAI
import pandas as pd
import matplotlib.pyplot as plt
from dotenv import load_dotenv

# Install required packages:
# pip install openai matplotlib pandas python-dotenv
# Create a .env file in the same directory as this script
# and add your OpenAI API key to it as OPENAI_API_KEY.

# Load environment variables from the .env file
load_dotenv()

client = OpenAI(api_key=os.getenv('OPENAI_API_KEY'))

# Fetch the list of available models
def fetch_models():
    try:
        models = client.models.list()
        return [model.id for model in models.data]
    except Exception as e:
        print(f"Error fetching models: {e}")
        return []

# Sample messages for testing
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Write a short story about a robot learning to love."}
]

# Function to test a model: time one chat completion and compute throughput.
# Note: "tokens_per_second" is approximated from the whitespace-delimited
# word count of the response, not the API's reported token count.
def test_model(model_name, messages):
    try:
        start_time = time.time()
        response = client.chat.completions.create(model=model_name,
                                                  messages=messages,
                                                  max_tokens=100)
        end_time = time.time()
        response_time = end_time - start_time
        response_text = response.choices[0].message.content.strip()
        response_length = len(response_text.split())
        tokens_per_second = response_length / response_time
        return {
            'model': model_name,
            'response_time': response_time,
            'response_length': response_length,
            'tokens_per_second': tokens_per_second,
            'response_text': response_text
        }
    except Exception as e:
        print(f"Error testing model {model_name}: {e}")
        return None

# Fetch models and test them
models = fetch_models()
if not models:
    print("No models were fetched.")
else:
    print(f"Fetched models: {models}")
    results = []

    # Filter models to only include relevant ones (e.g., GPT models)
    relevant_models = [model for model in models if "gpt" in model]
    if not relevant_models:
        print("No relevant models found.")
    else:
        print(f"Relevant models: {relevant_models}")

        # Test all relevant models and collect results
        for model in relevant_models:
            result = test_model(model, messages)
            if result:
                results.append(result)
                # Display results as they are computed
                print(f"Model: {result['model']}")
                print(f"Response Time: {result['response_time']:.2f} seconds")
                print(f"Response Length: {result['response_length']} words")
                print(f"Tokens per Second: {result['tokens_per_second']:.2f}")
                print(f"Response: {result['response_text']}\n")

        # Check if any results were collected
        if not results:
            print("No results were collected.")
        else:
            # Create a DataFrame from the results and display it
            df = pd.DataFrame(results)
            print(df)

            # Plot tokens per second for each model
            plt.figure(figsize=(10, 6))
            plt.bar(df['model'], df['tokens_per_second'], color='skyblue')
            plt.xlabel('Model')
            plt.ylabel('Tokens per Second')
            plt.title('Comparison of Tokens per Second for Each Model')
            plt.xticks(rotation=45)
            plt.tight_layout()
            plt.show()
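
The script above approximates throughput from word count. If you want throughput in actual tokens, the Chat Completions response also carries a usage object with the completion token count. A minimal sketch of that variant (the function name test_model_token_throughput is hypothetical; it assumes the same client and messages defined above):

# Hypothetical variant of the timing function that uses the API-reported
# token count (response.usage.completion_tokens) instead of a word-count proxy.
def test_model_token_throughput(model_name, messages):
    start_time = time.time()
    response = client.chat.completions.create(model=model_name,
                                               messages=messages,
                                               max_tokens=100)
    elapsed = time.time() - start_time
    completion_tokens = response.usage.completion_tokens
    return {
        'model': model_name,
        'response_time': elapsed,
        'completion_tokens': completion_tokens,
        'tokens_per_second': completion_tokens / elapsed
    }

This reports genuine completion tokens per second, though the denominator still includes network and queueing latency, so it measures end-to-end throughput rather than raw generation speed.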