Created
November 7, 2023 02:56
-
-
Save elecnix/574e571204f79529286faed02415cb79 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
This module contains functions for estimating the token cost (a number of tokens, not a dollar amount) of images sent to OpenAI's vision (image input) API.
https://platform.openai.com/docs/guides/vision | |
Usage: | |
import openai_vision_cost | |
cost = openai_vision_cost.estimate_image_cost(image_width, image_height, detail_level) | |
Generated with the help of ChatGPT: | |
https://chat.openai.com/share/bc4e6202-f007-496e-ac18-2f888b93a80f | |
""" | |
def scale_to_fit_square(image_width, image_height, square_size):
    """Scale an image down to fit within a square, preserving aspect ratio.

    Images whose sides are both already <= ``square_size`` are returned
    unchanged; images are never scaled up.

    Args:
        image_width: Original width in pixels.
        image_height: Original height in pixels.
        square_size: Side length of the bounding square in pixels.

    Returns:
        A ``(width, height)`` tuple. When scaling happens, the longer side
        becomes exactly ``square_size`` and the other side is scaled
        proportionally, truncated to a whole pixel.
    """
    if image_width <= square_size and image_height <= square_size:
        # Already fits: keep the original dimensions.
        return image_width, image_height
    # Multiply before dividing so the truncated result is exact. The float
    # form square_size / (width / height) can fall just below a whole number
    # and lose a pixel. Comparing the sides directly (instead of computing an
    # aspect ratio) also avoids dividing by a zero height.
    if image_width > image_height:
        return square_size, int(square_size * image_height // image_width)
    return int(square_size * image_width // image_height), square_size
def scale_to_shortest_side(image_width, image_height, target_short_side_length):
    """Scale an image down so its shortest side equals the target length.

    Aspect ratio is preserved. If the shortest side is already
    <= ``target_short_side_length`` the image is returned unchanged (it is
    never scaled up).

    Args:
        image_width: Width in pixels.
        image_height: Height in pixels.
        target_short_side_length: Desired length of the shortest side.

    Returns:
        A ``(width, height)`` tuple with each side truncated to a whole
        pixel; when scaling happens, the shortest side is exactly
        ``target_short_side_length``.
    """
    shortest_side = min(image_width, image_height)
    if shortest_side <= target_short_side_length:
        return image_width, image_height
    # Multiply before dividing so the result is the exact truncated value;
    # multiplying by the float ratio target / shortest can land just short
    # of a whole number (e.g. leave the short side at target - 1).
    scaled_width = int(image_width * target_short_side_length // shortest_side)
    scaled_height = int(image_height * target_short_side_length // shortest_side)
    return scaled_width, scaled_height
def count_512px_squares(image_width, image_height):
    """Return how many 512x512 tiles are needed to cover an image.

    A partial tile along either edge counts as a whole tile, so each
    dimension is rounded up to the next multiple of 512.
    """
    def tiles_along(length):
        # Ceiling division expressed as negated floor division.
        return -(-length // 512)

    return tiles_along(image_width) * tiles_along(image_height)
def calculate_tokens(num_squares, detail_level):
    """Return the token cost for an image at the given detail level.

    'low' detail is a flat 85 tokens regardless of ``num_squares``.
    'high' detail costs 170 tokens per square plus the same 85-token base.

    Raises:
        ValueError: if ``detail_level`` is neither 'low' nor 'high'.
    """
    base_tokens = 85
    if detail_level == 'high':
        return base_tokens + 170 * num_squares
    if detail_level != 'low':
        raise ValueError("Detail level must be 'low' or 'high'.")
    return base_tokens
def estimate_image_cost(image_width, image_height, detail_level):
    """Estimate how many tokens an image costs at the given detail level.

    For 'high' detail, the image is processed in three steps:
    1. It is fitted inside a 2048x2048 square.
    2. It is scaled so its shortest side is at most 768px. Neither resize
       step enlarges the image.
    3. The resulting size is priced by the number of 512px squares it spans.

    Any other detail level is priced as a single square. Invalid levels are
    rejected by ``calculate_tokens``.

    Args:
        image_width: Original image width in pixels.
        image_height: Original image height in pixels.
        detail_level: 'low' or 'high'.

    Returns:
        The estimated token count.

    Raises:
        ValueError: propagated from ``calculate_tokens`` when
            ``detail_level`` is neither 'low' nor 'high'.
    """
    if detail_level != 'high':
        # Low detail skips resizing entirely and counts as one square.
        return calculate_tokens(1, detail_level)
    fitted_width, fitted_height = scale_to_fit_square(image_width, image_height, 2048)
    resized_width, resized_height = scale_to_shortest_side(fitted_width, fitted_height, 768)
    squares = count_512px_squares(resized_width, resized_height)
    return calculate_tokens(squares, detail_level)
def test_functions():
    """Run the module's self-checks and print a message when all pass.

    Each check calls a function with fixed arguments and compares the
    result to the expected value. The first mismatch raises AssertionError.
    """
    checks = [
        # scale_to_fit_square
        (scale_to_fit_square, (3000, 4000, 2048), (1536, 2048)),
        (scale_to_fit_square, (4000, 3000, 2048), (2048, 1536)),
        # scale_to_shortest_side
        (scale_to_shortest_side, (3000, 4000, 768), (768, 1024)),
        (scale_to_shortest_side, (4000, 3000, 768), (1024, 768)),
        # count_512px_squares
        (count_512px_squares, (1024, 768), 4),
        (count_512px_squares, (1536, 2048), 12),
        # calculate_tokens
        (calculate_tokens, (4, 'low'), 85),
        (calculate_tokens, (4, 'high'), 4 * 170 + 85),
        (calculate_tokens, (12, 'high'), 12 * 170 + 85),
        # estimate_image_cost
        (estimate_image_cost, (1000, 500, 'low'), 85),
        (estimate_image_cost, (512, 512, 'high'), 85 + 170),
        (estimate_image_cost, (1024, 2048, 'high'), 6 * 170 + 85),
    ]
    for func, args, expected in checks:
        assert func(*args) == expected
    print("All tests passed.")
if __name__ == "__main__":
    # Executing the module directly runs its built-in self-checks.
    test_functions()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment