Created
February 4, 2025 19:32
-
-
Save unbracketed/a898d409b76f51505f3da99db013eb35 to your computer and use it in GitHub Desktop.
Processes a directory of images using a few image-to-text models and outputs tables for comparing the processing times and generated captions.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# /// script
# requires-python = ">=3.12"
# dependencies = [
#     "click",
#     "torch",
#     "transformers",
#     "Pillow",
#     "rich",
#     "pandas",
# ]
# ///
import click | |
import time | |
from pathlib import Path | |
from typing import List, Dict | |
import torch | |
from transformers import pipeline | |
from rich.console import Console | |
from rich.table import Table | |
from rich import print as rprint | |
import pandas as pd | |
import warnings | |
warnings.filterwarnings('ignore') | |
# Default model identifiers compared when the user does not pass --models.
MODELS = [
    "ydshieh/vit-gpt2-coco-en",
    "Salesforce/blip-image-captioning-large",
    "microsoft/git-base-coco",
]
# Pipelines already loaded in this process, keyed by model name. Building a
# pipeline is much more expensive than captioning one image, so each model is
# loaded once and reused for every image instead of being rebuilt per image.
# Trade-off: every model that loaded successfully stays resident until exit.
_CAPTIONER_CACHE: Dict[str, object] = {}


def _get_captioner(model_name: str):
    """Return the pipeline for *model_name*, building and caching it on first use."""
    captioner = _CAPTIONER_CACHE.get(model_name)
    if captioner is None:
        # Use the first CUDA device when available, otherwise the CPU (-1).
        device = 0 if torch.cuda.is_available() else -1
        captioner = pipeline(model=model_name, device=device)
        _CAPTIONER_CACHE[model_name] = captioner
    return captioner


def process_image(image_path: Path, models: List[str]) -> Dict:
    """Caption a single image with every model in *models*.

    Args:
        image_path: Path of the image file; it is passed to each pipeline as a
            string.
        models: Model identifiers handed to ``pipeline(model=...)``.

    Returns:
        A dict with three keys:

        * ``'image'``: the file name of *image_path*.
        * ``'captions'``: model name -> generated caption, or an
          ``"Error: ..."`` string if loading or captioning failed.
        * ``'times'``: model name -> captioning time in seconds rounded to two
          decimals (model loading is not included), or ``-1`` on failure.

    Failures are recorded per model rather than raised, so one broken model
    or unreadable image does not abort the whole comparison.
    """
    results = {
        'image': image_path.name,
        'captions': {},
        'times': {}
    }
    for model_name in models:
        try:
            # A model that failed to load is not cached, so it is retried
            # (and reported again) for the next image.
            captioner = _get_captioner(model_name)

            # perf_counter is monotonic, unlike time.time(), so the measured
            # interval cannot be skewed by system clock adjustments.
            caption_start = time.perf_counter()
            caption = captioner(str(image_path))
            elapsed = time.perf_counter() - caption_start

            # A list result is unpacked to its first entry's 'generated_text';
            # any other result type is stored unchanged.
            results['captions'][model_name] = caption[0]['generated_text'] if isinstance(caption, list) else caption
            results['times'][model_name] = round(elapsed, 2)
        except Exception as e:
            # Deliberate best-effort: record the failure and move on.
            results['captions'][model_name] = f"Error: {str(e)}"
            results['times'][model_name] = -1
    return results
def create_summary_tables(all_results: List[Dict]) -> tuple:
    """Build the caption and timing DataFrames from per-image results.

    Each entry of ``all_results`` contributes one row to both frames: an
    ``'Image'`` column holding the entry's ``'image'`` value, followed by one
    column per key of its ``'captions'`` mapping (first frame) or ``'times'``
    mapping (second frame).

    Returns:
        A ``(caption_df, timing_df)`` pair of pandas DataFrames.
    """
    caption_rows = [
        {'Image': entry['image'], **entry['captions']}
        for entry in all_results
    ]
    timing_rows = [
        {'Image': entry['image'], **entry['times']}
        for entry in all_results
    ]
    return pd.DataFrame(caption_rows), pd.DataFrame(timing_rows)
def _frame_to_table(df: pd.DataFrame, title: str, column_options) -> "Table":
    """Build a rich table with one column per DataFrame column and one row per record.

    ``column_options`` is called with each column label and returns the extra
    keyword arguments passed to ``Table.add_column`` for that column.
    """
    table = Table(title=title, show_header=True, header_style="bold magenta")
    for col in df.columns:
        table.add_column(col, style="cyan", **column_options(col))
    for _, row in df.iterrows():
        # rich cells must be renderables, so every value is stringified.
        table.add_row(*[str(x) for x in row])
    return table


def display_results(caption_df: pd.DataFrame, timing_df: pd.DataFrame):
    """Print the caption table and the timing table to the console.

    Args:
        caption_df: One row per image; every column becomes a table column.
        timing_df: Same layout, holding the timing values.

    Only the ``'Image'`` column of the caption table is kept on one line.
    Caption columns are allowed to wrap (folding over-long words), because
    forcing them onto one line makes rich truncate long captions with an
    ellipsis, which defeats comparing them.
    """
    console = Console()

    def caption_column_options(col):
        # File names stay on one line; captions wrap instead of being cut off.
        if col == 'Image':
            return {"no_wrap": True}
        return {"overflow": "fold"}

    def timing_column_options(col):
        return {"justify": "right"}

    caption_table = _frame_to_table(caption_df, "Generated Captions", caption_column_options)
    timing_table = _frame_to_table(timing_df, "Processing Times (seconds)", timing_column_options)

    # Display tables, each preceded by blank lines for separation.
    console.print("\n")
    console.print(caption_table)
    console.print("\n")
    console.print(timing_table)
# Lower-cased file extensions treated as images when scanning the directory.
_IMAGE_SUFFIXES = frozenset({'.jpg', '.jpeg', '.png'})


def _find_image_files(image_dir: Path) -> List[Path]:
    """Return the image files directly inside *image_dir*, sorted by path.

    The extension check is case-insensitive so files such as ``IMG_0001.JPG``
    are found on case-sensitive filesystems; directories are skipped.
    """
    return sorted(
        entry for entry in image_dir.iterdir()
        if entry.is_file() and entry.suffix.lower() in _IMAGE_SUFFIXES
    )


@click.command()
@click.argument('image_dir', type=click.Path(exists=True, file_okay=False, dir_okay=True, path_type=Path))
@click.option('--models', '-m', multiple=True, help='Specific models to use (default: use all built-in models)')
@click.option('--output', '-o', type=click.Path(path_type=Path), help='Save results to CSV files')
def main(image_dir: Path, models: tuple, output: Path | None):
    """
    Process images in a directory through multiple image captioning models.
    Compares generated captions and processing times.
    """
    # NOTE: the docstring above is the text click shows for --help, so it is
    # user-facing; maintainer notes live in comments instead.

    # Use specified models or default list
    model_list = list(models) if models else MODELS
    rprint(f"[bold green]Processing images from: {image_dir}[/bold green]")
    rprint(f"[bold blue]Using models: {', '.join(model_list)}[/bold blue]\n")

    # Collect image files in a deterministic order so table rows are stable
    # between runs.
    image_files = _find_image_files(image_dir)
    if not image_files:
        rprint("[bold red]No image files (jpg/png) found in directory![/bold red]")
        return

    # Process all images, one progress-bar step per image.
    with click.progressbar(image_files, label='Processing images') as images:
        all_results = [process_image(img_path, model_list) for img_path in images]

    # Create and display summary tables
    caption_df, timing_df = create_summary_tables(all_results)
    display_results(caption_df, timing_df)

    # Save results if output path specified. with_suffix replaces the last
    # extension of --output, so "results.csv" and "results" both produce
    # "results.captions.csv" / "results.timing.csv".
    if output:
        captions_path = output.with_suffix('.captions.csv')
        timing_path = output.with_suffix('.timing.csv')
        output.parent.mkdir(parents=True, exist_ok=True)
        caption_df.to_csv(captions_path, index=False)
        timing_df.to_csv(timing_path, index=False)
        rprint(f"\n[bold green]Results saved to:{captions_path}"
               f" and {timing_path}[/bold green]")
# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment