Skip to content

Instantly share code, notes, and snippets.

@dennisseah
Last active December 12, 2024 06:54
Show Gist options
  • Save dennisseah/b60e153931579e0c01362a1ab700a0d0 to your computer and use it in GitHub Desktop.
Save dennisseah/b60e153931579e0c01362a1ab700a0d0 to your computer and use it in GitHub Desktop.
GenerativeAI Vision - counting sheep
import asyncio
import base64
import json
import os
import re
from dataclasses import dataclass
from mimetypes import guess_type

from dotenv import load_dotenv
from openai import AsyncAzureOpenAI
from openai.types.chat.chat_completion import ChatCompletion
# load_dotenv() runs before the environment reads below so that any values it
# loads are visible to them.
load_dotenv()

# Azure OpenAI connection settings. Each one is read from the environment and
# falls back to an empty string when the variable is not set.
(
    azure_openai_endpoint,
    azure_openai_key,
    azure_openai_api_version,
    azure_openai_deployed_model_name,
) = (
    os.environ.get(variable, "")
    for variable in (
        "AZURE_OPENAI_ENDPOINT",
        "AZURE_OPENAI_KEY",
        "AZURE_OPENAI_API_VERSION",
        "AZURE_OPENAI_DEPLOYED_MODEL_NAME",
    )
)

# System prompt sent with every request: asks the model to answer with a
# single integer and nothing else.
prompt = """
You are a helpful assistant. Please help me to count how many sheep are in the image.
Generate your response in as a single integer number and DO NOT include any other information.
"""  # noqa: E501
@dataclass
class TestSet:
image_url: str
description: str
expected_sheep_count: int
predicted_sheep_count: int = 0
@staticmethod
def with_local_image(
image_path: str, description: str, expected_sheep_count: int
) -> "TestSet":
mime_type, _ = guess_type(image_path)
if mime_type is None:
mime_type = "application/octet-stream"
with open(image_path, "rb") as image_file:
base64_encoded_data = base64.b64encode(image_file.read()).decode("utf-8")
return TestSet(
image_url=f"data:{mime_type};base64,{base64_encoded_data}",
description=description,
expected_sheep_count=expected_sheep_count,
)
def asdict(self) -> dict[str, int | str]:
result = self.__dict__
if len(result["image_url"]) > 100:
result["image_url"] = "local image"
return result
# Evaluation cases: three remote images followed by one image read from disk.
# NOTE: the local case opens images/cow.jpeg (relative to the working
# directory) as soon as this statement runs.
test_sets = [
    *(
        TestSet(image_url=url, description=note, expected_sheep_count=count)
        for url, note, count in (
            (
                "https://images.pexels.com/photos/2157028/pexels-photo-2157028.jpeg",
                "Simple image with 2 sheep",
                2,
            ),
            (
                "https://images.pexels.com/photos/1153756/pexels-photo-1153756.jpeg",
                "Image with 8 sheep. Sorry, this is a complicated one",
                8,
            ),
            (
                "https://images.pexels.com/photos/69466/sunset-sheep-dike-nordfriesland-69466.jpeg",
                "silhouette",
                1,
            ),
        )
    ),
    # Presumably the local cow image was downloaded from (TODO confirm):
    # "https://images.pexels.com/photos/14191871/pexels-photo-14191871.jpeg"
    TestSet.with_local_image(
        image_path=os.path.join("images", "cow.jpeg"),
        description="a cow, no sheep",
        expected_sheep_count=0,
    ),
]
def _parse_sheep_count(content: str | None) -> int:
    """Turn the model's reply text into an integer sheep count.

    Returns 0 for a missing or empty reply. A reply that is not a bare
    integer (for example "2 sheep") yields the first run of digits it
    contains, or 0 when it contains none.
    """
    if not content:
        return 0
    try:
        return int(content)
    except ValueError:
        # The model ignored the "single integer" instruction; salvage the
        # first number instead of aborting every concurrent request.
        found = re.search(r"\d+", content)
        return int(found.group()) if found else 0


async def count_sheep(testset: TestSet) -> None:
    """Ask the deployed model how many sheep are in one test image.

    Sends the system prompt plus the test case's image to the Azure OpenAI
    chat-completions endpoint, stores the parsed answer in
    ``testset.predicted_sheep_count`` and prints the test case as JSON.

    Args:
        testset: The test case to evaluate; updated in place.
    """
    # The context manager closes the client's HTTP resources once the
    # request finishes, instead of leaking one open client per call.
    async with AsyncAzureOpenAI(
        azure_endpoint=azure_openai_endpoint,
        api_key=azure_openai_key,
        api_version=azure_openai_api_version,
    ) as client:
        result: ChatCompletion = await client.chat.completions.create(
            model=azure_openai_deployed_model_name,
            messages=[
                {"role": "system", "content": prompt},
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "image_url",
                            "image_url": {"url": testset.image_url},
                        }
                    ],
                },
            ],
        )

    # No choices at all is treated the same as an empty reply.
    content = result.choices[0].message.content if result.choices else None
    testset.predicted_sheep_count = _parse_sheep_count(content)

    print(json.dumps(testset.asdict(), indent=2))
async def main() -> None:
    """Run ``count_sheep`` concurrently for every entry in ``test_sets``."""
    await asyncio.gather(*(count_sheep(test) for test in test_sets))
if __name__ == "__main__":
    # Script entry point: evaluate every test case when run directly.
    asyncio.run(main())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment