Skip to content

Instantly share code, notes, and snippets.

@trashhalo
Created January 25, 2025 15:19
Show Gist options
  • Save trashhalo/eb1854b6e84309c396b930b0276de592 to your computer and use it in GitHub Desktop.
Save trashhalo/eb1854b6e84309c396b930b0276de592 to your computer and use it in GitHub Desktop.
#!/usr/bin/env python3
# /// script
# requires-python = ">=3.8"
# dependencies = [
# "torch",
# "transformers",
# "Pillow",
# "rich",
# ]
# [project]
# name = "look"
# version = "0.1.0"
#
# [project.scripts]
# look = "look:main"
# ///
import os
import sys
import argparse
from pathlib import Path
from PIL import Image
from transformers import AutoProcessor, AutoModelForVision2Seq
from transformers.image_utils import load_image
import logging
import torch
from rich.console import Console
from rich.progress import Progress, SpinnerColumn, TextColumn
from rich.logging import RichHandler
from io import BytesIO
console = Console()

# Log records go through a Rich handler bound to the same console that the
# progress spinners use; timestamps are disabled and only the bare message is
# rendered.
logging.basicConfig(
    handlers=[RichHandler(show_time=False, console=console)],
    format="%(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(name=__name__)
def setup_model(model_id="HuggingFaceTB/SmolVLM-256M-Instruct", device=None):
    """Load the vision-language processor and model, then move the model to a device.

    Args:
        model_id: Hugging Face model id passed to ``from_pretrained`` for both
            the processor and the model. Defaults to the SmolVLM 256M instruct
            checkpoint this tool was written for.
        device: torch device string to run on. When ``None`` the best available
            backend is chosen: CUDA, then Apple MPS, then CPU.

    Returns:
        A ``(processor, model, device)`` tuple. ``device`` is the string the
        model was moved to, so callers can move their inputs to match.
    """
    if device is None:
        if torch.cuda.is_available():
            device = "cuda"
        elif torch.backends.mps.is_available():
            device = "mps"
        else:
            device = "cpu"

    with Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        console=console,
    ) as progress:
        # Indeterminate spinner (total=None) while the weights load.
        progress.add_task("Loading model...", total=None)
        processor = AutoProcessor.from_pretrained(model_id)
        model = AutoModelForVision2Seq.from_pretrained(
            model_id,
            # NOTE(review): float32 is kept on every device, as before; half
            # precision on CUDA could save memory but is untested here.
            torch_dtype=torch.float32,
        ).to(device)
    return processor, model, device
def load_image_from_stdin():
    """Read an entire image from standard input and decode it.

    The whole stream is buffered in memory first, so this waits until the
    writer closes stdin (e.g. ``screencapture -x - | look --at -``).

    Returns:
        A fully decoded ``PIL.Image.Image``.

    Exits with status 1 (after printing an error) when stdin is an
    interactive terminal, when no bytes arrive, or when the bytes cannot be
    decoded as an image.
    """
    # Binary image data cannot sensibly be typed at a terminal; without this
    # check the read below would block silently waiting for EOF.
    if sys.stdin.isatty():
        console.print(
            "[red]Error:[/red] stdin is a terminal; pipe or redirect an image into `look --at -`"
        )
        sys.exit(1)
    data = sys.stdin.buffer.read()
    if not data:
        console.print("[red]Error:[/red] No data received on stdin")
        sys.exit(1)
    try:
        image = Image.open(BytesIO(data))
        # Image.open is lazy and only identifies the format from the header;
        # force the full decode here so truncated or corrupt input is reported
        # by this handler rather than failing later during captioning.
        image.load()
    except Exception as e:
        console.print(f"[red]Error:[/red] Invalid image data received on stdin: {e}")
        sys.exit(1)
    return image
def validate_image_path(path):
    """Check that *path* names an existing regular file and return it as a Path.

    On failure an error is printed to the console and the process exits with
    status 1.
    """
    candidate = Path(path)
    problem = None
    if not candidate.exists():
        problem = f"Image file '{path}' not found"
    elif not candidate.is_file():
        problem = f"'{path}' is not a file"
    if problem is not None:
        console.print(f"[red]Error:[/red] {problem}")
        sys.exit(1)
    return candidate
def generate_caption(image, prompt, processor, model, device, max_new_tokens=500):
    """Ask the vision-language model about *image* and return its answer.

    Args:
        image: PIL image to send to the model.
        prompt: instruction or question sent alongside the image.
        processor: processor returned by ``setup_model``; builds the chat
            prompt, prepares tensors and decodes token ids.
        model: vision-to-sequence model returned by ``setup_model``.
        device: device the model lives on; inputs are moved there.
        max_new_tokens: upper bound on generated tokens (default 500, the
            previously hard-coded limit).

    Returns:
        The model's answer as plain text with surrounding whitespace removed.
        Multi-line answers are returned whole.
    """
    with Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        console=console,
    ) as progress:
        task = progress.add_task("Looking at image...", total=None)
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {"type": "text", "text": prompt},
                ],
            },
        ]
        chat_prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
        inputs = processor(text=chat_prompt, images=[image], return_tensors="pt")
        inputs = inputs.to(device)
        generated_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
        # generate() returns the prompt tokens followed by the continuation.
        # Decode only the continuation: filtering the full transcript for lines
        # starting with "Assistant:" silently dropped every later line of a
        # multi-line answer.
        prompt_length = inputs["input_ids"].shape[-1]
        new_token_ids = generated_ids[:, prompt_length:]
        answer = processor.batch_decode(new_token_ids, skip_special_tokens=True)[0]
        progress.update(task, completed=True)
    return answer.strip()
def extract_assistant_text(log):
    """Return the assistant turns from a decoded ``User:``/``Assistant:`` transcript.

    A turn starts on a line beginning with ``Assistant:``. The prefix and one
    following space, if present, are removed. Subsequent lines that do not
    start a new ``User:``/``System:`` turn are kept as continuation lines of
    the same answer, so multi-line replies survive intact. Turns are joined
    with newlines. A transcript without any assistant turn yields ``""``.
    """
    prefix = "Assistant:"
    kept = []
    in_answer = False
    for line in log.splitlines():
        if line.startswith(prefix):
            text = line[len(prefix):]
            # Strip only a real separator space; the old fixed [11:] slice
            # ate the first answer character when the space was missing.
            if text.startswith(" "):
                text = text[1:]
            kept.append(text)
            in_answer = True
        elif line.startswith(("User:", "System:")):
            # NOTE(review): "System:" assumes the chat template renders a
            # system role this way — confirm against the model's template.
            in_answer = False
        elif in_answer:
            kept.append(line)
    return "\n".join(kept)
def _load_image_arg(source):
    """Turn the ``--at`` value into a PIL image: '-' reads stdin, http(s) is fetched, else a local path."""
    if source == '-':
        return load_image_from_stdin()
    # transformers' load_image fetches only http:// and https:// sources, so
    # test for exactly those instead of any "://" substring.
    if source.startswith(("http://", "https://")):
        return load_image(source)
    return Image.open(validate_image_path(source))


def main():
    """Command-line entry point: parse arguments, load the image, print a caption.

    Returns:
        Process exit status: 0 on success, 1 on error or Ctrl-C. Argument
        errors and ``--version`` exit via argparse before this returns.
    """
    parser = argparse.ArgumentParser(
        description='Look at images and describe what you see',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
look --at photo.jpg
look --at https://example.com/photo.jpg
look --at - < image.jpg
screencapture -x - | look --at -
"""
    )
    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument('--at', metavar='IMAGE',
                       help="Local image file, URL, or '-' to read from stdin")
    parser.add_argument('--prompt', default="Describe what you see in this image",
                        help='Specific aspect to look at')
    parser.add_argument('--version', action='version', version='%(prog)s 0.1.0')
    args = parser.parse_args()
    try:
        # Resolve the image before the slow model load so a bad path, an
        # unreachable URL or empty stdin fails immediately.
        image = _load_image_arg(args.at)
        processor, model, device = setup_model()
        caption = generate_caption(image, args.prompt, processor, model, device)
        console.print(f"\n{caption}\n")
    except KeyboardInterrupt:
        console.print("\n[yellow]Stopped looking[/yellow]")
        return 1
    except Exception as e:
        # Top-level boundary: report the message and signal failure.
        console.print(f"\n[red]Error:[/red] {e}")
        return 1
    return 0
if __name__ == '__main__':
    # sys.exit rather than the bare exit(): exit() is injected by the site
    # module for the interactive shell and is missing under ``python -S``.
    sys.exit(main())
@trashhalo
Copy link
Author

Screenshot 2025-01-25 at 10 17 43 AM

set tmp (mktemp) && screencapture -s "$tmp" && pipx run look.py --at "$tmp" && rm "$tmp"

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment