Created
December 4, 2024 21:59
-
-
Save cpfiffer/d67118bc9471b29c2c3b69fa534be824 to your computer and use it in GitHub Desktop.
Reading road signs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
pip install outlines torch==2.4.0 transformers accelerate pillow rich | |
sudo apt-get install poppler-utils | |
""" | |
from enum import Enum | |
from PIL import Image | |
import outlines | |
import torch | |
from transformers import AutoProcessor | |
from pydantic import BaseModel, Field | |
from typing import List, Optional, Literal | |
from rich import print | |
# To use Pixtral: | |
# from transformers import LlavaForConditionalGeneration | |
# model_name="mistral-community/pixtral-12b" # original magnet model is able to be loaded without issue | |
# model_class=LlavaForConditionalGeneration | |
# To use Qwen-2-VL: | |
from transformers import Qwen2VLForConditionalGeneration | |
model_name = "Qwen/Qwen2-VL-7B-Instruct" | |
model_class = Qwen2VLForConditionalGeneration | |
model = outlines.models.transformers_vision( | |
model_name, | |
model_class=model_class, | |
model_kwargs={ | |
"device_map": "auto", | |
"torch_dtype": torch.bfloat16, | |
}, | |
processor_kwargs={ | |
"device": "cuda", | |
}, | |
) | |
def load_and_resize_image(image_path, max_size=1024):
    """
    Load an image from disk, downscaling it so that neither dimension
    exceeds ``max_size`` while preserving the aspect ratio.

    Args:
        image_path: Path to the image file.
        max_size: Maximum dimension (width or height) of the output image.

    Returns:
        PIL.Image.Image: The resized image (or the original image object
        if it already fits within ``max_size``).
    """
    image = Image.open(image_path)

    # Current dimensions.
    width, height = image.size

    # Scaling factor that fits the longer side into max_size.
    scale = min(max_size / width, max_size / height)

    # Only shrink; never upscale images that already fit.
    if scale < 1:
        new_width = int(width * scale)
        new_height = int(height * scale)
        image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)

    return image
# Load and resize the input image. 2048 keeps enough resolution for the
# model to read text on signs while bounding memory use.
image = load_and_resize_image(
    "zoomed.png",
    max_size=2048,
)
# Define the schema | |
class TrafficSign(BaseModel):
    """Schema for a single traffic sign the model finds in the image."""

    sign_type: str    # model-chosen category label for the sign
    description: str  # free-form description of the sign
    raw_text: str     # text the model reads off the sign, if any
class Description(BaseModel):
    """Top-level structured output: scene description plus all signs found."""

    description: str
    traffic_signs: List[TrafficSign]
# Constrained generator: output is guaranteed to parse as a Description.
description_generator = outlines.generate.json(
    model,
    Description,
)

# Chat-style message combining the image and the instruction prompt.
# The JSON schema is embedded in the prompt so the model knows the target
# structure (generation is additionally constrained by outlines above).
messages = [
    {
        "role": "user",
        "content": [
            {
                "type": "image",
                "image": image,
            },
            {"type": "text", "text": f"""
Please describe this scene, with a focus on traffic signs.

Try to find all the traffic signs in the image and describe them.

Respond in JSON format, using the following schema:

{Description.model_json_schema()}
"""},
        ],
    }
]

# Render the messages into the model's chat-template prompt string.
processor = AutoProcessor.from_pretrained(model_name)
text = processor.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
print(text)

# Run generation; returns a validated Description instance.
description = description_generator(text, [image])
print(description)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment