Created March 14, 2024 20:12
Save anadim/6827bd0744feebc8c67f18ac3a942d1e to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Import necessary modules
import os
import random
import re

import anthropic
import matplotlib.pyplot as plt
import openai
def generate_prompt(a, b):
    """Build the question sent to every model, e.g. "What is 2 + 3?"."""
    template = "What is {} + {}?"
    return template.format(a, b)
def extract_result(text):
    """Return the last integer found in a model response, or None.

    Commas are removed first so thousands separators ("1,234") don't split
    a number. Periods are deliberately kept: stripping them would glue a
    fractional part onto the integer part ("579.0" -> 5790). Instead an
    optional ".<digits>" tail is matched and discarded, so "579.0" -> 579,
    and a sentence-ending period ("... is 579.") is simply ignored.

    Args:
        text: Model response text. None or "" (e.g. an empty completion)
            yields None instead of raising.

    Returns:
        The integer part of the last number in *text*, or None if *text*
        contains no digits.
    """
    if not text:
        return None
    text = text.replace(",", "")
    # One capture group -> findall returns only the integer parts.
    integer_parts = re.findall(r"(\d+)(?:\.\d+)?", text)
    return int(integer_parts[-1]) if integer_parts else None
def test_model(model_name, prompt, anthropic_client):
    """Send *prompt* to one model and return ``(raw_text, parsed_result)``.

    Model names starting with "claude" are sent through
    ``anthropic_client.messages.create``; every other name goes to OpenAI
    chat completions. Both use temperature 0.

    Args:
        model_name: Provider model identifier.
        prompt: User message text.
        anthropic_client: Client used for Claude models only.

    Returns:
        A tuple of the model's reply text ("" if the reply had no text) and
        ``extract_result`` applied to that text (an int, or None).
    """
    messages = [{"role": "user", "content": prompt}]
    if model_name.startswith("claude"):
        response = anthropic_client.messages.create(
            model=model_name,
            max_tokens=1000,
            temperature=0.0,
            messages=messages,
        )
        # Guard against a reply with no content blocks.
        text = response.content[0].text if response.content else ""
    else:
        if hasattr(openai, "OpenAI"):
            # openai>=1.0 removed openai.ChatCompletion; the module-level
            # client proxy uses the key assigned to openai.api_key.
            response = openai.chat.completions.create(
                model=model_name,
                messages=messages,
                temperature=0,
            )
        else:
            # Legacy openai<1.0 interface.
            response = openai.ChatCompletion.create(
                model=model_name,
                messages=messages,
                temperature=0,
            )
        # message.content can be None; normalise so parsing never crashes.
        text = response.choices[0].message.content or ""
    return text, extract_result(text)
def run_experiment(models, digit_lengths, anthropic_client, num_samples=10, seed=None):
    """Measure each model's average relative error on random n-digit additions.

    For every length in *digit_lengths*, draws *num_samples* pairs of random
    integers that each have exactly that many digits. It asks every model in
    *models* for their sum via ``test_model`` and prints per-sample and
    per-length results.

    Args:
        models: Model names forwarded to ``test_model``.
        digit_lengths: Digit counts to test (each expected to be >= 1).
        anthropic_client: Client forwarded to ``test_model``.
        num_samples: Random additions per digit length; must be >= 1.
        seed: Optional seed for a private ``random.Random`` so runs are
            reproducible. When None, the module-level ``random`` functions
            are used, as before.

    Returns:
        dict mapping each model name to a list of average relative errors,
        one entry per digit length in iteration order.

    Raises:
        ValueError: if ``num_samples`` is less than 1.
    """
    if num_samples < 1:
        # The per-length average divides by num_samples; fail before any API calls.
        raise ValueError(f"num_samples must be at least 1, got {num_samples}")
    # Only swap in a private generator when a seed is given, so callers that
    # already seed the global `random` module keep their behavior.
    rng = random if seed is None else random.Random(seed)

    results = {model: [] for model in models}
    for length in digit_lengths:
        print(f"\n\n\n\n ********** Digit Length: {length} **********")
        # Smallest and largest integers with exactly `length` digits.
        low, high = 10 ** (length - 1), 10 ** length - 1
        total_relative_error = {model: 0 for model in models}
        for sample in range(num_samples):
            a = rng.randint(low, high)
            b = rng.randint(low, high)
            expected_result = a + b
            prompt = generate_prompt(a, b)
            print(f"\n\nSample {sample + 1}:")
            print(f"Expected Result: {expected_result}")
            for model in models:
                model_output, predicted_result = test_model(model, prompt, anthropic_client)
                # An unparseable answer is scored as 100% relative error.
                if predicted_result is None:
                    relative_error = 1.0
                else:
                    relative_error = abs(predicted_result - expected_result) / expected_result
                total_relative_error[model] += relative_error
                print(f"\nModel: {model}")
                print(f"Output: {model_output}")
                print(f"Predicted Result: {predicted_result}")
                print(f"\n Relative Error: {relative_error:.4f}")
        for model in models:
            avg_relative_error = total_relative_error[model] / num_samples
            results[model].append(avg_relative_error)
            print(f"\nModel: {model}")
            print(f"Average Relative Error: {avg_relative_error:.4f}")
    return results
def plot_results(results, digit_lengths):
    """Plot one line per model of relative error against digit count, then show it.

    Args:
        results: Mapping of model name -> list of errors aligned with
            *digit_lengths* (as returned by ``run_experiment``).
        digit_lengths: The x-axis values, one per error entry.
    """
    axes = plt.gca()
    for model_name in results:
        axes.plot(digit_lengths, results[model_name], label=model_name)
    axes.set_xlabel("Number of Digits")
    axes.set_ylabel("Relative Error")
    axes.set_title("Relative Error vs. Number of Digits")
    axes.legend()
    plt.show()
if __name__ == "__main__":
    # Read API keys from the environment rather than hard-coding secrets in source.
    anthropic_api_key = os.environ.get("ANTHROPIC_API_KEY")
    openai_api_key = os.environ.get("OPENAI_API_KEY")
    if not anthropic_api_key or not openai_api_key:
        raise SystemExit("Set the ANTHROPIC_API_KEY and OPENAI_API_KEY environment variables.")

    # Initialize the Anthropic client and the OpenAI module-level key.
    anthropic_client = anthropic.Anthropic(api_key=anthropic_api_key)
    openai.api_key = openai_api_key

    # Models to test. Other candidates: "claude-instant-1.2", "claude-2.0",
    # "claude-2.1", "claude-3-opus-20240229", "gpt-4".
    models = ["gpt-3.5-turbo", "claude-3-sonnet-20240229"]

    # Operand lengths to test: 100, 105, ..., 120 digits. Extend as needed.
    digit_lengths = range(100, 125, 5)

    results = run_experiment(models, digit_lengths, anthropic_client)
    plot_results(results, digit_lengths)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.