Created
January 10, 2024 08:06
-
-
Save vatsalsaglani/5d05b3e40d20c5a8b1b53b9d9eaa103c to your computer and use it in GitHub Desktop.
OpenAI GPT Vision Token Counting
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import base64
import math
import re
from io import BytesIO
from typing import Dict, List, Literal
from urllib import request

from PIL import Image
def getImageDimensions(image: str):  # base64 or url
    """Return the ``(width, height)`` of an image given as a URL or data URI.

    Args:
        image: Either an ``http(s)://`` URL pointing at an image, or a
            ``data:image/<type>;base64,<payload>`` data URI.

    Returns:
        The ``size`` attribute of the opened Pillow image, i.e. a
        ``(width, height)`` tuple.

    Raises:
        ValueError: If ``image`` is neither a recognised URL nor a base64
            image data URI.
    """
    url_regex = r'https?:\/\/(www\.)?[-a-zA-Z0-9@:%._\+~#=]{1,256}\.[a-zA-Z0-9()]{1,6}\b([-a-zA-Z0-9()@:%_\+.~#?&//=]*)'
    if re.match(url_regex, image):
        # Both the HTTP response and the Pillow image hold OS resources;
        # the context managers release them as soon as the size is known.
        with request.urlopen(image) as response, \
                Image.open(response) as remote_image:
            return remote_image.size

    if re.match(r'data:image\/\w+;base64', image):
        # Strip the data-URI header so only the base64 payload is decoded.
        payload = re.sub(r'data:image\/\w+;base64,', '', image)
        with Image.open(BytesIO(base64.b64decode(payload))) as inline_image:
            return inline_image.size

    raise ValueError("Image must be a URL or Base64 string")
def calculateImageTokens(image: str, detail: Literal["high", "low"] = "low"):
    """Estimate how many tokens an image costs at the given detail level.

    ``"low"`` detail is a flat cost and never inspects the image.  ``"high"``
    detail reads the image dimensions, scales them to fit within 2048 px on
    the longest side and 768 px on the shortest side, counts the 512 px
    tiles needed to cover the result, and charges a per-tile cost plus a
    fixed base cost.

    Args:
        image: URL or base64 data URI of the image.  Only read when
            ``detail`` is ``"high"``.
        detail: ``"low"`` (default) or ``"high"``.

    Returns:
        The estimated token count as an ``int``.

    Raises:
        ValueError: If ``detail`` is not ``"low"`` or ``"high"``, or (for
            ``"high"``) if the image reference is not a URL / data URI.
    """
    LOW_DETAIL_COST = 85
    HIGH_DETAIL_COST_PER_TILE = 170
    ADDITIONAL_COST = 85

    if detail == "low":
        return LOW_DETAIL_COST
    if detail != "high":
        # Raised explicitly rather than via ``assert`` so the check survives
        # ``python -O`` and callers always see a ValueError.
        raise ValueError(
            f"Invalid detail option: {detail}. Use 'low' or 'high'")

    width, height = getImageDimensions(image)

    # Step 1: shrink so the longest side is at most 2048 px.
    longest = max(width, height)
    if longest >= 2048:
        ratio = 2048 / longest
        width = int(width * ratio)
        height = int(height * ratio)

    # Step 2: shrink so the shortest side is at most 768 px.
    shortest = min(width, height)
    if shortest > 768:
        ratio = 768 / shortest
        width = int(width * ratio)
        height = int(height * ratio)

    # Step 3: one charge per 512 px tile, plus the fixed base cost.
    num_squares = math.ceil(width / 512) * math.ceil(height / 512)
    return num_squares * HIGH_DETAIL_COST_PER_TILE + ADDITIONAL_COST
class NaiveContextManagement:
    """Trim a chat history to a token budget, keeping the newest messages."""

    def __init__(self, model_name: str):
        """Store the tiktoken encoding that matches ``model_name``.

        Args:
            model_name: Model identifier passed to
                ``tiktoken.encoding_for_model``.
        """
        # Imported locally: the module's top-level imports do not bring
        # ``tiktoken`` into scope, so the bare name would raise NameError.
        import tiktoken

        self.encoding = tiktoken.encoding_for_model(model_name)

    def __count_tokens__(self, content: str):
        """Return the encoded length of ``content`` plus 4.

        The constant 4 is added to every piece of text; presumably it
        approximates per-message framing overhead (confirm against the
        target model's chat format).
        """
        return len(self.encoding.encode(content)) + 4

    def _count_message_tokens(self, message: Dict) -> int:
        """Return the token estimate for one chat message.

        String content is counted as text.  List content is summed part by
        part: ``"text"`` parts as text, ``"image_url"`` parts through
        ``calculateImageTokens``.  Any other content contributes 0.
        """
        content = message.get("content")
        if isinstance(content, str):
            return self.__count_tokens__(content)
        if not isinstance(content, list):
            # e.g. ``None`` content: nothing to count.
            return 0

        total = 0
        for part in content:
            part_type = part.get("type")
            if part_type == "image_url":
                image_ref = part.get("image_url")
                # Honour the requested detail level; anything other than
                # "high"/"low" (missing, "auto", ...) is costed as "low".
                detail = image_ref.get("detail")
                if detail not in ("high", "low"):
                    detail = "low"
                total += calculateImageTokens(image_ref.get("url"), detail)
            elif part_type == "text":
                total += self.__count_tokens__(part.get("text"))
        return total

    def __pad_messages__(self, messages: List[Dict], max_length: int):
        """Return the longest suffix of ``messages`` that fits ``max_length``.

        Messages are visited newest-first and accumulated until adding the
        next one would push the running total past ``max_length``; that
        message and everything older is dropped.

        Args:
            messages: Chat messages in chronological order.
            max_length: Token budget for the returned messages combined.

        Returns:
            A new list with the retained messages in their original order.
        """
        curr_length = 0
        output_messages = []
        for message in reversed(messages):
            message_tokens = self._count_message_tokens(message)
            # Compare the running total, not just this message, to the budget.
            if curr_length + message_tokens > max_length:
                print(
                    f'Context length exceeded the allowed context length of: {max_length}.'
                )
                break
            output_messages.append(message)
            curr_length += message_tokens
        output_messages.reverse()
        return output_messages
if __name__ == "__main__":
    import json

    ctx = NaiveContextManagement("gpt-4-vision-preview")

    # Demo 1: a single multimodal user message (text + image) with a budget
    # large enough that nothing needs to be dropped.
    example_messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "describe what is in this image?",
                },
                {
                    "type": "image_url",
                    "image_url": {
                        "url": "https://venturebeat.com/wp-content/uploads/2019/03/openai-1.png",
                        "detail": "high",
                    },
                },
            ],
        },
    ]
    trimmed = ctx.__pad_messages__(example_messages, 128_000)
    print(json.dumps(trimmed, indent=4))

    # Demo 2: plain-text messages with a tiny budget, to show trimming.
    example_messages = [
        {"role": "system", "content": "You're a helpful assistant"},
        {"role": "user", "content": "Hello"},
    ]
    trimmed = ctx.__pad_messages__(example_messages, 5)
    print(json.dumps(trimmed, indent=4))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment