Last active
December 12, 2024 06:54
-
-
Save dennisseah/b60e153931579e0c01362a1ab700a0d0 to your computer and use it in GitHub Desktop.
GenerativeAI Vision - counting sheep
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import asyncio
import base64
import json
import os
import re
from dataclasses import dataclass
from mimetypes import guess_type

from dotenv import load_dotenv
from openai import AsyncAzureOpenAI
from openai.types.chat.chat_completion import ChatCompletion
# Populate the process environment from a .env file (python-dotenv) before the
# Azure OpenAI settings below are read.
load_dotenv()

# Azure OpenAI connection settings. Each falls back to "" when the variable is
# unset, so a missing setting is not detected here -- only later, when the
# values are handed to the client.
azure_openai_endpoint = os.getenv("AZURE_OPENAI_ENDPOINT", "")
azure_openai_key = os.getenv("AZURE_OPENAI_KEY", "")
azure_openai_api_version = os.getenv("AZURE_OPENAI_API_VERSION", "")
azure_openai_deployed_model_name = os.getenv("AZURE_OPENAI_DEPLOYED_MODEL_NAME", "")

# System prompt sent with every image. It asks for a bare integer so the reply
# can be converted directly into a sheep count.
prompt = """
You are a helpful assistant. Please help me to count how many sheep are in the image.
Generate your response as a single integer number and DO NOT include any other information.
"""  # noqa: E501
@dataclass | |
class TestSet: | |
image_url: str | |
description: str | |
expected_sheep_count: int | |
predicted_sheep_count: int = 0 | |
@staticmethod | |
def with_local_image( | |
image_path: str, description: str, expected_sheep_count: int | |
) -> "TestSet": | |
mime_type, _ = guess_type(image_path) | |
if mime_type is None: | |
mime_type = "application/octet-stream" | |
with open(image_path, "rb") as image_file: | |
base64_encoded_data = base64.b64encode(image_file.read()).decode("utf-8") | |
return TestSet( | |
image_url=f"data:{mime_type};base64,{base64_encoded_data}", | |
description=description, | |
expected_sheep_count=expected_sheep_count, | |
) | |
def asdict(self) -> dict[str, int | str]: | |
result = self.__dict__ | |
if len(result["image_url"]) > 100: | |
result["image_url"] = "local image" | |
return result | |
# Evaluation cases: three hosted photos with known sheep counts, plus one local
# image of a cow that checks the model answers 0 when no sheep are present.
# The local image file is read while this list is built, i.e. at import time.
test_sets = [
    TestSet(
        image_url="https://images.pexels.com/photos/2157028/pexels-photo-2157028.jpeg",
        description="Simple image with 2 sheep",
        expected_sheep_count=2,
    ),
    TestSet(
        image_url="https://images.pexels.com/photos/1153756/pexels-photo-1153756.jpeg",
        description="Image with 8 sheep. Sorry, this is a complicated one",
        expected_sheep_count=8,
    ),
    TestSet(
        image_url="https://images.pexels.com/photos/69466/sunset-sheep-dike-nordfriesland-69466.jpeg",
        description="silhouette",
        expected_sheep_count=1,
    ),
    # NOTE(review): this URL looks like the origin of images/cow.jpeg -- confirm;
    # kept for provenance.
    # "https://images.pexels.com/photos/14191871/pexels-photo-14191871.jpeg"
    TestSet.with_local_image(
        # Resolve against this script's directory rather than the current
        # working directory, so the script can be launched from anywhere.
        image_path=os.path.join(
            os.path.dirname(os.path.abspath(__file__)), "images", "cow.jpeg"
        ),
        description="a cow, no sheep",
        expected_sheep_count=0,
    ),
]
def _parse_sheep_count(content: str | None) -> int:
    """Convert the model's text reply into a sheep count.

    A missing or empty reply counts as 0. Otherwise the first run of digits is
    used, so replies such as ``"2."`` or ``"There are 2 sheep"`` still parse
    even though the prompt asks for a bare integer.

    Raises:
        ValueError: If a non-empty reply contains no digits.
    """
    if not content:
        return 0
    match = re.search(r"\d+", content)
    if match is None:
        raise ValueError(f"could not find a sheep count in model reply: {content!r}")
    return int(match.group())


async def count_sheep(testset: TestSet):
    """Ask the deployed vision model how many sheep are in ``testset``'s image.

    Stores the answer in ``testset.predicted_sheep_count`` and prints the case
    as JSON.

    Args:
        testset: The case to evaluate; updated in place.

    Raises:
        ValueError: If the model's reply is non-empty but contains no integer.
    """
    # A client is created per call, so close it (and its HTTP connections)
    # deterministically via ``async with`` instead of leaking it.
    async with AsyncAzureOpenAI(
        azure_endpoint=azure_openai_endpoint,
        api_key=azure_openai_key,
        api_version=azure_openai_api_version,
    ) as client:
        result: ChatCompletion = await client.chat.completions.create(
            model=azure_openai_deployed_model_name,
            messages=[
                {"role": "system", "content": prompt},
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "image_url",
                            "image_url": {"url": testset.image_url},
                        }
                    ],
                },
            ],
        )
    content = result.choices[0].message.content if result.choices else None
    testset.predicted_sheep_count = _parse_sheep_count(content)
    print(json.dumps(testset.asdict(), indent=2))
async def main():
    """Evaluate every case in ``test_sets`` concurrently.

    ``return_exceptions=True`` lets one failing case (network error,
    unparseable reply) finish without cancelling the others. Failures are
    reported once all requests are done, and the first one is then re-raised
    so the run still ends in an error.
    """
    outcomes = await asyncio.gather(
        *(count_sheep(test) for test in test_sets), return_exceptions=True
    )
    failures = [
        (test, outcome)
        for test, outcome in zip(test_sets, outcomes)
        if isinstance(outcome, BaseException)
    ]
    for test, error in failures:
        print(f"{test.description}: failed with {error!r}")
    if failures:
        raise failures[0][1]
# Script entry point. ``asyncio`` is imported at module level, so ``main()``
# also works when this module is imported and called from elsewhere.
if __name__ == "__main__":
    asyncio.run(main())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment