Created
January 25, 2025 15:19
-
-
Save trashhalo/eb1854b6e84309c396b930b0276de592 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3
# /// script
# requires-python = ">=3.8"
# dependencies = [
# "torch",
# "transformers",
# "Pillow",
# "rich",
# ]
# [project]
# name = "look"
# version = "0.1.0"
#
# [project.scripts]
# look = "look:main"
# ///
import os | |
import sys | |
import argparse | |
from pathlib import Path | |
from PIL import Image | |
from transformers import AutoProcessor, AutoModelForVision2Seq | |
from transformers.image_utils import load_image | |
import logging | |
import torch | |
from rich.console import Console | |
from rich.progress import Progress, SpinnerColumn, TextColumn | |
from rich.logging import RichHandler | |
from io import BytesIO | |
# One shared Rich console; the functions below print messages and draw
# their progress spinners through it.
console = Console()

# Root logging goes through a RichHandler bound to that same console.
logging.basicConfig(
    handlers=[RichHandler(console=console, show_time=False)],
    format="%(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
def setup_model():
    """Load the SmolVLM processor and model while showing a spinner.

    Returns:
        A ``(processor, model, device)`` tuple. ``device`` is ``"mps"`` when
        ``torch.backends.mps.is_available()`` is true and ``"cpu"`` otherwise;
        the model has already been moved to that device.
    """
    checkpoint = "HuggingFaceTB/SmolVLM-256M-Instruct"
    spinner = Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        console=console,
    )
    with spinner:
        # total=None: the task is registered without a known total.
        spinner.add_task("Loading model...", total=None)

        if torch.backends.mps.is_available():
            device = "mps"
        else:
            device = "cpu"

        processor = AutoProcessor.from_pretrained(checkpoint)
        loaded = AutoModelForVision2Seq.from_pretrained(
            checkpoint,
            torch_dtype=torch.float32,
        )
        model = loaded.to(device)
    return processor, model, device
def load_image_from_stdin():
    """Open the bytes piped on stdin as a PIL image.

    All of stdin is read up front and wrapped in a ``BytesIO`` before being
    handed to ``Image.open``.

    Returns:
        The object returned by ``Image.open``.

    Exits with status 1 (after printing an error to the console) when stdin
    is empty or when ``Image.open`` raises.
    """
    raw = sys.stdin.buffer.read()
    if len(raw) == 0:
        console.print("[red]Error:[/red] No data received on stdin")
        sys.exit(1)

    try:
        image = Image.open(BytesIO(raw))
    except Exception as err:
        console.print(f"[red]Error:[/red] Invalid image data received on stdin: {str(err)}")
        sys.exit(1)
    return image
def validate_image_path(path):
    """Return ``path`` as a ``Path`` if it names an existing regular file.

    Args:
        path: Filesystem location of the image, as given by the user.

    Returns:
        ``Path(path)`` when it is a file.

    Exits with status 1 after printing an error to the console when the
    path does not exist or exists but is not a file.
    """
    candidate = Path(path)
    if candidate.is_file():
        return candidate

    # Not a usable file: pick the message that explains why.
    if candidate.exists():
        problem = f"[red]Error:[/red] '{path}' is not a file"
    else:
        problem = f"[red]Error:[/red] Image file '{path}' not found"
    console.print(problem)
    sys.exit(1)
def generate_caption(image, prompt, processor, model, device):
    """Ask the model about ``image`` and return the assistant's reply text.

    Args:
        image: The image passed to the processor (as a one-element list).
        prompt: The user's text instruction placed after the image slot.
        processor: Object providing ``apply_chat_template``, ``__call__`` and
            ``batch_decode``.
        model: Object providing ``generate``.
        device: Device the processed inputs are moved to before generation.

    Returns:
        The result of ``extract_assistant_text`` applied to the first decoded
        sequence.
    """
    # Single user turn: an image placeholder followed by the text prompt.
    conversation = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": prompt},
            ],
        },
    ]

    spinner = Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        console=console,
    )
    with spinner:
        task_id = spinner.add_task("Looking at image...", total=None)

        chat_text = processor.apply_chat_template(conversation, add_generation_prompt=True)
        model_inputs = processor(text=chat_text, images=[image], return_tensors="pt").to(device)

        # Generation is capped at 500 new tokens.
        output_ids = model.generate(**model_inputs, max_new_tokens=500)
        decoded = processor.batch_decode(output_ids, skip_special_tokens=True)

        spinner.update(task_id, completed=True)
    return extract_assistant_text(decoded[0])
def extract_assistant_text(log):
    """Pull the assistant's reply out of a decoded ``User:``/``Assistant:`` transcript.

    A line starting with ``Assistant:`` opens a reply; its text (minus the
    marker and at most one separating space) is kept, together with every
    following line up to the next line starting with ``User:``. This keeps
    multi-line replies intact instead of only their first line, and no longer
    loses the first character when no space follows the colon.

    Args:
        log: The decoded transcript text.

    Returns:
        The reply lines joined with newlines, or ``""`` when the transcript
        contains no ``Assistant:`` line.
    """
    assistant_marker = "Assistant:"
    user_marker = "User:"

    reply_lines = []
    in_reply = False
    for line in log.splitlines():
        if line.startswith(assistant_marker):
            in_reply = True
            text = line[len(assistant_marker):]
            # Drop only the single space that conventionally follows the marker.
            if text.startswith(" "):
                text = text[1:]
            reply_lines.append(text)
        elif line.startswith(user_marker):
            # A new user turn ends the current reply.
            in_reply = False
        elif in_reply:
            # Continuation line of a multi-line reply.
            reply_lines.append(line)
    return "\n".join(reply_lines)
def main():
    """Command-line entry point: describe the image named by ``--at``.

    ``--at`` accepts ``-`` (read the image from stdin), a value containing
    ``://`` (treated as a URL and fetched with ``load_image``), or a local
    file path. ``--prompt`` overrides the default instruction.

    The image is resolved before the model is loaded, so a bad path or empty
    stdin is reported without first paying for a model load.

    Returns:
        0 on success; 1 if interrupted with Ctrl-C or if any other exception
        is raised while loading or captioning. The stdin and path helpers
        exit the process with status 1 themselves on invalid input.
    """
    # rich is already a dependency of this file; escape() keeps arbitrary
    # text from being parsed as console markup.
    from rich.markup import escape

    parser = argparse.ArgumentParser(
        description='Look at images and describe what you see',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
look --at photo.jpg
look --at https://example.com/photo.jpg
look --at - < image.jpg
screencapture -x - | look --at -
"""
    )
    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument('--at', metavar='IMAGE', help='Local image file or URL to look at')
    parser.add_argument('--prompt', default="Describe what you see in this image",
                        help='Specific aspect to look at')
    parser.add_argument('--version', action='version', version='%(prog)s 0.1.0')
    args = parser.parse_args()

    try:
        # Load the image first: it is cheap and fails fast on bad input.
        if args.at == '-':
            image = load_image_from_stdin()
        elif "://" in args.at:  # URL check
            image = load_image(args.at)
        else:
            image_path = validate_image_path(args.at)
            image = Image.open(image_path)

        # Only now pay for the model load.
        processor, model, device = setup_model()

        caption = generate_caption(image, args.prompt, processor, model, device)
        # The caption is model output, not markup: print it verbatim so
        # square brackets in it are neither swallowed nor raise MarkupError.
        console.print(f"\n{caption}\n", markup=False)
    except KeyboardInterrupt:
        console.print("\n[yellow]Stopped looking[/yellow]")
        return 1
    except Exception as e:
        # Escape the message so brackets in it cannot break the error print.
        console.print(f"\n[red]Error:[/red] {escape(str(e))}")
        return 1
    return 0
# Run the CLI when executed as a script. sys.exit is used rather than the
# site-provided exit() helper, which is not defined when Python runs
# without the site module (python -S, some frozen builds).
if __name__ == '__main__':
    sys.exit(main())
Author
trashhalo
commented
Jan 25, 2025
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment