Skip to content

Instantly share code, notes, and snippets.

@cpfiffer
Created December 4, 2024 21:59
Show Gist options
  • Save cpfiffer/d67118bc9471b29c2c3b69fa534be824 to your computer and use it in GitHub Desktop.
Reading road signs
"""
pip install outlines torch==2.4.0 transformers accelerate pillow rich
sudo apt-get install poppler-utils
"""
from enum import Enum
from PIL import Image
import outlines
import torch
from transformers import AutoProcessor
from pydantic import BaseModel, Field
from typing import List, Optional, Literal
from rich import print
# Model selection: exactly one vision-language model is active at a time.
# To use Pixtral, uncomment the three lines below and comment out the Qwen block:
# from transformers import LlavaForConditionalGeneration
# model_name="mistral-community/pixtral-12b" # original magnet model is able to be loaded without issue
# model_class=LlavaForConditionalGeneration
# To use Qwen-2-VL:
from transformers import Qwen2VLForConditionalGeneration
# HF hub id of the checkpoint; also reused below to build the AutoProcessor.
model_name = "Qwen/Qwen2-VL-7B-Instruct"
# Concrete transformers class outlines will instantiate for this checkpoint.
model_class = Qwen2VLForConditionalGeneration
# Build the outlines vision-model wrapper around the transformers checkpoint.
# Downloads/loads the weights at import time (module-level side effect).
model = outlines.models.transformers_vision(
    model_name,
    model_class=model_class,
    model_kwargs={
        "device_map": "auto",  # let accelerate place weights across available devices
        "torch_dtype": torch.bfloat16,  # halves memory vs fp32; needs bf16-capable hardware
    },
    processor_kwargs={
        "device": "cuda",  # NOTE(review): assumes a CUDA GPU is present — confirm
    },
)
def load_and_resize_image(image_path, max_size=1024):
    """
    Load an image from disk and downscale it to fit within a square bound,
    preserving aspect ratio.

    Args:
        image_path: Path to the image file.
        max_size: Maximum dimension (width or height) of the output image.

    Returns:
        PIL Image: The loaded image, resized with LANCZOS resampling if
        either dimension exceeded ``max_size``; otherwise the original size.
    """
    # PIL's open() is lazy and holds the file handle open; use a context
    # manager and force-load the pixel data so the handle is always released
    # (the original leaked it when no resize was needed).
    with Image.open(image_path) as img:
        img.load()
        width, height = img.size
        # Uniform scale factor that fits the longer side into max_size.
        scale = min(max_size / width, max_size / height)
        # Only shrink; never upscale smaller images.
        if scale < 1:
            new_size = (int(width * scale), int(height * scale))
            img = img.resize(new_size, Image.Resampling.LANCZOS)
    return img
# Load and resize the image that will be passed to the model.
# max_size=2048 overrides the function's 1024 default to keep more detail
# for reading small sign text.
image = load_and_resize_image(
    "zoomed.png",
    max_size=2048
)
# Define the schema the model's JSON output must conform to.
class TrafficSign(BaseModel):
    # Kind of sign (e.g. "stop", "speed limit") — presumably free-form; verify outputs.
    sign_type: str
    # Natural-language description of the individual sign.
    description: str
    # Verbatim text printed on the sign, if any.
    raw_text: str
class Description(BaseModel):
    # Overall description of the scene.
    description: str
    # One entry per traffic sign the model finds in the image.
    traffic_signs: List[TrafficSign]
# Constrained generator: decoding is restricted so the output always parses
# as a Description instance.
description_generator = outlines.generate.json(
    model,
    Description,
)
# Single-turn chat message: the image plus a text instruction that embeds
# the Pydantic JSON schema so the model knows the expected output shape.
messages = [
    {
        "role": "user",
        "content": [
            {
                "type": "image",
                "image": image,
            },
            {"type": "text", "text": f"""
Please describe this scene, with a focus on traffic signs.
Try to find all the traffic signs in the image and describe them.
Respond in JSON format, using the following schema:
{Description.model_json_schema()}
"""},
        ],
    }
]
# Convert the messages to the final prompt string using the model's own
# chat template (tokenize=False returns text, not token ids).
processor = AutoProcessor.from_pretrained(model_name)
text = processor.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
print(text)
# Run constrained generation: prompt text plus the list of images it references.
description = description_generator(text, [image])
print(description)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment