Skip to content

Instantly share code, notes, and snippets.

@unbracketed
Created February 4, 2025 19:32
Show Gist options
  • Save unbracketed/a898d409b76f51505f3da99db013eb35 to your computer and use it in GitHub Desktop.
Save unbracketed/a898d409b76f51505f3da99db013eb35 to your computer and use it in GitHub Desktop.
Processes a directory of images with several image-to-text models and prints tables comparing each model's generated captions and processing times.
# /// script
# requires-python = ">=3.12"
# dependencies = [
# "click",
# "torch",
# "transformers",
# "Pillow",
# "rich",
# "pandas",
# ]
# ///
import click
import time
from pathlib import Path
from typing import List, Dict
import torch
from transformers import pipeline
from rich.console import Console
from rich.table import Table
from rich import print as rprint
import pandas as pd
import warnings
warnings.filterwarnings('ignore')
# Model identifiers handed to transformers.pipeline; main() falls back to this
# list when no --models option is supplied.
MODELS = [
    "ydshieh/vit-gpt2-coco-en",
    "Salesforce/blip-image-captioning-large",
    "microsoft/git-base-coco",
]
# Captioning pipelines already loaded in this process, keyed by model name.
# Entries are never evicted, so every model used stays in memory until exit.
_CAPTIONER_CACHE: Dict[str, object] = {}


def _get_captioner(model_name: str):
    """Return the captioning pipeline for *model_name*, loading it on first use."""
    captioner = _CAPTIONER_CACHE.get(model_name)
    if captioner is None:
        # Use the first CUDA device when one is available, otherwise the CPU (-1).
        device = 0 if torch.cuda.is_available() else -1
        captioner = pipeline(model=model_name, device=device)
        _CAPTIONER_CACHE[model_name] = captioner
    return captioner


def process_image(image_path: Path, models: List[str]) -> Dict:
    """Caption a single image with each model and record the text and timing.

    Args:
        image_path: Path of the image file; it is passed to each pipeline as a
            string.
        models: Model names to run, in order.

    Returns:
        A dict with three keys:
        'image' is the file name of ``image_path``;
        'captions' maps each model name to its generated text, or to
        ``"Error: <message>"`` if loading or captioning raised;
        'times' maps each model name to the captioning time in seconds
        rounded to two decimals, or to ``-1`` on failure.

    Model loading is not included in the reported time. Each pipeline is
    loaded once per process and reused for later images instead of being
    rebuilt on every call.
    """
    results = {
        'image': image_path.name,
        'captions': {},
        'times': {},
    }

    for model_name in models:
        try:
            captioner = _get_captioner(model_name)

            # perf_counter is monotonic, so the interval cannot be skewed by
            # system clock adjustments the way time.time() can.
            started = time.perf_counter()
            caption = captioner(str(image_path))
            elapsed = time.perf_counter() - started

            # A list result carries the text under 'generated_text' of its
            # first item; anything else is stored unchanged.
            if isinstance(caption, list):
                text = caption[0]['generated_text']
            else:
                text = caption
            results['captions'][model_name] = text
            results['times'][model_name] = round(elapsed, 2)
        except Exception as e:
            # Best effort: one failing model must not abort the comparison,
            # so the failure is reported in the table instead of raised.
            results['captions'][model_name] = f"Error: {str(e)}"
            results['times'][model_name] = -1

    return results
def create_summary_tables(all_results: List[Dict]) -> tuple:
    """Turn per-image results into a caption DataFrame and a timing DataFrame.

    Every result contributes one row to each frame. A row starts with an
    'Image' column holding the result's 'image' value, followed by one column
    per key of the result's 'captions' mapping (caption frame) or 'times'
    mapping (timing frame).

    Returns:
        ``(caption_df, timing_df)``.
    """
    caption_rows = [
        {'Image': entry['image'], **entry['captions']}
        for entry in all_results
    ]
    timing_rows = [
        {'Image': entry['image'], **entry['times']}
        for entry in all_results
    ]
    return pd.DataFrame(caption_rows), pd.DataFrame(timing_rows)
def display_results(caption_df: pd.DataFrame, timing_df: pd.DataFrame):
    """Print the caption table and the timing table to the terminal with rich.

    Args:
        caption_df: One row per image; each cell is rendered with ``str()``.
        timing_df: One row per image; each cell is rendered with ``str()`` and
            right-aligned.

    Only the 'Image' column of the caption table is kept on a single line.
    The other caption columns wrap (and fold over-long words), so long
    captions and error messages stay fully readable instead of being cut off
    with an ellipsis when the table is wider than the terminal.
    """
    console = Console()

    # Caption results table
    caption_table = Table(title="Generated Captions", show_header=True, header_style="bold magenta")
    for col in caption_df.columns:
        if col == 'Image':
            # Short identifier column: keep it on one line.
            caption_table.add_column(col, style="cyan", no_wrap=True)
        else:
            # Caption text: wrap inside the cell rather than truncating.
            caption_table.add_column(col, style="cyan", overflow="fold")
    for _, row in caption_df.iterrows():
        caption_table.add_row(*[str(cell) for cell in row])

    # Timing results table
    timing_table = Table(title="Processing Times (seconds)", show_header=True, header_style="bold magenta")
    for col in timing_df.columns:
        timing_table.add_column(col, style="cyan", justify="right")
    for _, row in timing_df.iterrows():
        timing_table.add_row(*[str(cell) for cell in row])

    # Blank spacing before each table
    console.print("\n")
    console.print(caption_table)
    console.print("\n")
    console.print(timing_table)
# Lower-case file extensions treated as images when scanning the input directory.
IMAGE_SUFFIXES = {'.jpg', '.jpeg', '.png'}


def _find_images(image_dir: Path) -> List[Path]:
    """Return the image files directly inside *image_dir*, sorted by path."""
    return sorted(
        entry
        for entry in image_dir.iterdir()
        if entry.is_file() and entry.suffix.lower() in IMAGE_SUFFIXES
    )


@click.command()
@click.argument('image_dir', type=click.Path(exists=True, file_okay=False, dir_okay=True, path_type=Path))
@click.option('--models', '-m', multiple=True, help='Specific models to use (default: use all built-in models)')
@click.option('--output', '-o', type=click.Path(path_type=Path), help='Save results to CSV files')
def main(image_dir: Path, models: tuple, output: Path | None):
    """
    Process images in a directory through multiple image captioning models.
    Compares generated captions and processing times.
    """
    # Imported here so the fix does not depend on the file's import block;
    # escape() keeps "[...]" inside paths or model names from being parsed
    # as rich markup tags.
    from rich.markup import escape

    # Use the models given on the command line, or the built-in defaults.
    model_list = list(models) if models else MODELS

    rprint(f"[bold green]Processing images from: {escape(str(image_dir))}[/bold green]")
    rprint(f"[bold blue]Using models: {escape(', '.join(model_list))}[/bold blue]\n")

    # Extension matching is case-insensitive and the result is sorted, so the
    # same directory yields the same rows in the same order on every platform.
    image_files = _find_images(image_dir)
    if not image_files:
        rprint("[bold red]No image files (jpg/png) found in directory![/bold red]")
        return

    # Run every image through every model, with a progress bar.
    with click.progressbar(image_files, label='Processing images') as images:
        all_results = [process_image(img_path, model_list) for img_path in images]

    # Create and display the summary tables.
    caption_df, timing_df = create_summary_tables(all_results)
    display_results(caption_df, timing_df)

    # Save both tables as CSV when an output path was given. with_suffix
    # replaces the final suffix of the given path with the new one.
    if output:
        captions_path = output.with_suffix('.captions.csv')
        timing_path = output.with_suffix('.timing.csv')
        output.parent.mkdir(parents=True, exist_ok=True)
        caption_df.to_csv(captions_path, index=False)
        timing_df.to_csv(timing_path, index=False)
        rprint(f"\n[bold green]Results saved to:{escape(str(captions_path))}"
               f" and {escape(str(timing_path))}[/bold green]")
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment