Demo of using the OpenAI CLIP and SigLIP 2 models to do text and image embedding search.
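The gist contains four files: a standalone CLIP demo script, a pip requirements list, a SQLite-backed image/video search tool (sqlite-vec vector tables plus FTS5 category search, with a choice of SigLIP 2 or CLIP embeddings), and a standalone SigLIP 2 demo script.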
| """ | |
| Simple demo showing how to generate and compare text and image embeddings. | |
| Uses CLIP (openai/clip-vit-base-patch16) model for both text and image embeddings. | |
| """ | |
| import torch | |
| from PIL import Image | |
| from transformers import CLIPProcessor, CLIPModel | |
| import numpy as np | |
| import torch.nn.functional as F | |
| # Image paths | |
| IMAGE_PATHS = [ | |
| "./lake-smaller.jpg", | |
| "./cats.jpg", | |
| "./bmw-i7.jpg", | |
| "./beech-tree.jpg", | |
| ] | |
| # Global model instances (lazy loaded) | |
| clip_model = None | |
| clip_processor = None | |
| def get_device(): | |
| """Get the best available device (MPS for Mac, CUDA for NVIDIA, or CPU).""" | |
| if torch.backends.mps.is_available(): | |
| return torch.device("mps") | |
| elif torch.cuda.is_available(): | |
| return torch.device("cuda") | |
| return torch.device("cpu") | |
| def get_clip_model(): | |
| """Load and return the CLIP model and processor.""" | |
| global clip_model, clip_processor | |
| if clip_model is None: | |
| device = get_device() | |
| print(f"Loading CLIP model on {device}...") | |
| clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch16") | |
| clip_model = clip_model.to(device) | |
| clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16") | |
| print("Model loaded!\n") | |
| return clip_model, clip_processor | |


def generate_image_embedding(image_path):
    """
    Generate an L2-normalized image embedding.
    Args:
        image_path: Path to the image file
    Returns:
        Tuple of (filename, embedding vector as numpy array)
    """
    device = get_device()
    model, processor = get_clip_model()

    # Load and convert image to RGB
    image = Image.open(image_path).convert("RGB")

    # Generate embedding
    with torch.no_grad():
        inputs = processor(images=image, return_tensors="pt").to(device)
        features = model.get_image_features(**inputs)
        # L2 normalize for proper cosine similarity
        features = F.normalize(features, p=2, dim=-1)
        embedding = features.cpu().numpy()[0]

    # Extract just the filename from the path
    filename = image_path.split("/")[-1]
    return filename, embedding


def generate_text_embedding(query_text):
    """
    Generate an L2-normalized text embedding.
    Args:
        query_text: Text query string
    Returns:
        Embedding vector as numpy array
    """
    device = get_device()
    model, processor = get_clip_model()

    with torch.no_grad():
        inputs = processor(text=[query_text], return_tensors="pt", padding=True).to(device)
        features = model.get_text_features(**inputs)
        # L2 normalize
        features = F.normalize(features, p=2, dim=-1)
        embedding = features.cpu().numpy()[0]
    return embedding


def cosine_similarity(vec1, vec2):
    """
    Calculate cosine similarity between two vectors.
    Since the vectors are already L2-normalized, this is just the dot product.
    Args:
        vec1, vec2: Numpy arrays
    Returns:
        Similarity score in [-1, 1] (higher is better; text-image
        matches are typically positive)
    """
    return np.dot(vec1, vec2)


def load_images():
    """
    Load all images and generate their embeddings.
    Returns:
        List of tuples: [(filename, embedding), ...]
    """
    print("Loading and embedding images...")
    embeddings = []
    for image_path in IMAGE_PATHS:
        print(f"  Processing {image_path.split('/')[-1]}...")
        filename, embedding = generate_image_embedding(image_path)
        embeddings.append((filename, embedding))
    print(f"\nGenerated embeddings for {len(embeddings)} images\n")
    return embeddings


def search(query_text, image_embeddings):
    """
    Search for images matching the text query.
    Args:
        query_text: Text description to search for
        image_embeddings: List of (filename, embedding) tuples
    Returns:
        List of (filename, similarity_score) tuples sorted by similarity (best first)
    """
    print(f"Searching for: '{query_text}'")
    # Generate text embedding
    text_embedding = generate_text_embedding(query_text)

    # Compare with all image embeddings
    results = []
    for filename, image_embedding in image_embeddings:
        similarity = cosine_similarity(text_embedding, image_embedding)
        results.append((filename, similarity))

    # Sort by similarity (highest first)
    results.sort(key=lambda x: x[1], reverse=True)

    print("\nResults:")
    print("-" * 50)
    for filename, similarity in results:
        print(f"{filename:30s} | Similarity: {similarity:.4f}")
    print("-" * 50)
    print()
    return results


def main():
    """Main demo function."""
    print("=" * 50)
    print("Text and Image Embedding Demo (CLIP)")
    print("=" * 50)
    print()

    # Load all images and generate embeddings
    image_embeddings = load_images()

    # Run some example searches
    print("=" * 50)
    print("Example Searches")
    print("=" * 50)
    print()
    search("a beautiful lake with mountains", image_embeddings)
    search("cats sitting together", image_embeddings)
    search("luxury car", image_embeddings)
    search("tree in nature", image_embeddings)
    search("water and nature", image_embeddings)


if __name__ == "__main__":
    main()
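Because the embeddings are L2-normalized, the same dot product also gives image-to-image similarity. A minimal sketch reusing the demo's own helpers (not part of the original gist; the file names are the demo's sample images):

# Image-to-image similarity, reusing the demo's own helpers.
_, lake_emb = generate_image_embedding("./lake-smaller.jpg")
_, tree_emb = generate_image_embedding("./beech-tree.jpg")
print(f"lake vs beech tree: {cosine_similarity(lake_emb, tree_emb):.4f}")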
Dependencies (pip requirements):

torch
torchvision
transformers
pillow
numpy
sqlite-vec
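To install them, assuming the list above is saved as a standard requirements file: pip install -r requirements.txt. Note that sqlite-vec (and the optional pysqlite3 fallback imported below) is only needed by the database-backed search script; the two standalone demos only need the model stack.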
| """ | |
| Search based on images categories and image embeddings saved in a local SQLite database. | |
| Each frame of a video will have multiple categories. | |
| However, it will only have one embedding vector. | |
| Image categories are saved with a video natural key, timestamp/frame number and a category name. | |
| Image embeddings are saved with a video natural key, timestamp/frame number and an embedding vector. | |
| The image search system will use either the image categories and embeddings to search for images that match the query. | |
| You provide a query, which is either matched against the image categories or tokenized and matched against the image embeddings using the cosine similarity score. | |
| The image search system will return a list of images that match the query. | |
| Each image will have the following information: | |
| - The video natural key (which video is this from) | |
| - The frame number (which frame is this from) | |
| Then either the category names if searching by category. | |
| Or the cosine similarity score if searching by embedding. | |
| """ | |
| import os | |
| # Use pysqlite3 instead of built-in sqlite3 for extension support | |
| try: | |
| import pysqlite3 as sqlite3 | |
| except ImportError: | |
| import sqlite3 | |
| import sqlite_vec | |
| import json | |
| from typing import List, Tuple, Optional | |
| import argparse | |
| from transformers import pipeline | |
| from typing import Dict | |
| import time | |
| import subprocess | |
| import tempfile | |
| import torch | |
| from PIL import Image | |
| from transformers import AutoProcessor, AutoModel, CLIPProcessor, CLIPModel | |
| import numpy as np | |
| import torch.nn.functional as F | |

clf = None


def classify_image(image_path: str) -> List[Dict[str, float]]:
    """
    Classify an image using a pre-trained model.
    """
    global clf
    if clf is None:
        device = get_device()
        # Map device to pipeline device format (pipeline uses -1 for cpu, 0+ for gpu indices)
        if device.type == "cuda":
            pipeline_device = 0
        elif device.type == "mps":
            pipeline_device = 0  # MPS also uses 0
        else:
            pipeline_device = -1  # CPU
        clf = pipeline("image-classification", "google/vit-large-patch16-224", device=pipeline_device, use_fast=True)

    start_time = time.time()
    # Classify the image
    results = clf(image_path)
    # Filter out results with a score of 0.1 or less
    results = [result for result in results if result["score"] > 0.1]
    end_time = time.time()
    print(f"Time taken to classify image: {end_time - start_time} seconds")

    # Return the results
    return results


# Model instances (lazy loaded)
siglip_model = None
siglip_processor = None
clip_model = None
clip_processor = None


def get_siglip_model() -> Tuple[AutoModel, AutoProcessor]:
    """
    Get the SigLIP2 model and processor.
    Uses AutoProcessor for both images and text (correct approach from reference projects).
    """
    global siglip_model, siglip_processor
    if siglip_model is None:
        device = get_device()
        siglip_model = AutoModel.from_pretrained("google/siglip2-so400m-patch16-naflex")
        siglip_model = siglip_model.to(device)
        siglip_processor = AutoProcessor.from_pretrained("google/siglip2-so400m-patch16-naflex")
    return siglip_model, siglip_processor


def get_clip_model() -> Tuple[CLIPModel, CLIPProcessor]:
    """
    Get the CLIP model and processor.
    """
    global clip_model, clip_processor
    if clip_model is None:
        device = get_device()
        clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch16")
        clip_model = clip_model.to(device)
        clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16")
    return clip_model, clip_processor


def get_device() -> torch.device:
    """
    Get the device to use for the model.
    """
    if torch.backends.mps.is_available():
        return torch.device("mps")
    elif torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")


def generate_image_embedding(image_path: str, model: str = "siglip2") -> np.ndarray:
    """
    Generate an image embedding using either SigLIP2 or CLIP.
    Embeddings are L2-normalized for proper cosine similarity comparison.
    Args:
        image_path: Path to the image file
        model: Model to use - "siglip2" or "clip" (default: "siglip2")
    Returns:
        L2-normalized embedding vector as numpy array
    """
    device = get_device()
    image = Image.open(image_path).convert("RGB")

    start_time = time.time()
    if model == "clip":
        clip_model, clip_processor = get_clip_model()
        with torch.no_grad():
            inputs = clip_processor(images=image, return_tensors="pt").to(device)
            features = clip_model.get_image_features(**inputs)
            # L2 normalize
            features = F.normalize(features, p=2, dim=-1)
            image_embedding = features.cpu().numpy()[0]
    else:  # siglip2
        siglip_model, siglip_processor = get_siglip_model()
        with torch.no_grad():
            inputs = siglip_processor(images=image, return_tensors="pt").to(device)
            features = siglip_model.get_image_features(**inputs)
            # L2 normalize
            features = F.normalize(features, p=2, dim=-1)
            image_embedding = features.cpu().numpy()[0]
    end_time = time.time()
    print(f"Time taken to generate {model} image embedding: {end_time - start_time} seconds")

    return image_embedding


def get_text_embedding(query_text: str, model: str = "siglip2") -> np.ndarray:
    """
    Generate a text embedding to query against the image embeddings.
    Embeddings are L2-normalized for proper cosine similarity comparison.
    Args:
        query_text: Text query string
        model: Model to use - "siglip2" or "clip" (default: "siglip2")
    Returns:
        L2-normalized embedding vector as numpy array
    """
    device = get_device()
    start_time = time.time()
    if model == "clip":
        clip_model, clip_processor = get_clip_model()
        with torch.no_grad():
            inputs = clip_processor(text=[query_text], return_tensors="pt", padding=True).to(device)
            features = clip_model.get_text_features(**inputs)
            # L2 normalize
            features = F.normalize(features, p=2, dim=-1)
            query_text_embedding = features.cpu().numpy()[0]
    else:  # siglip2
        siglip_model, siglip_processor = get_siglip_model()
        with torch.no_grad():
            # Use padding="max_length", max_length=64
            # This is important! Using just padding=True produces very different embeddings
            # (cosine similarity ~0.7 vs 1.0). Two reference projects use max_length=64.
            inputs = siglip_processor(text=[query_text], return_tensors="pt", padding="max_length", max_length=64).to(device)
            features = siglip_model.get_text_features(**inputs)
            # L2 normalize
            features = F.normalize(features, p=2, dim=-1)
            query_text_embedding = features.cpu().numpy()[0]
    end_time = time.time()
    print(f"Time taken to generate {model} text embedding: {end_time - start_time} seconds")

    return query_text_embedding
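

# The padding caveat above can be verified empirically. The helper below is a
# hypothetical addition (not part of the original gist): it embeds the same
# query with padding=True and padding="max_length" and prints their cosine
# similarity, which should come out well below 1.0 if the caveat holds.
def _compare_siglip_padding_modes(query_text: str = "a sunset over mountains") -> float:
    device = get_device()
    model, processor = get_siglip_model()
    with torch.no_grad():
        a = processor(text=[query_text], return_tensors="pt", padding=True).to(device)
        b = processor(text=[query_text], return_tensors="pt", padding="max_length", max_length=64).to(device)
        fa = F.normalize(model.get_text_features(**a), p=2, dim=-1)
        fb = F.normalize(model.get_text_features(**b), p=2, dim=-1)
    # Dot product of two unit vectors = cosine similarity
    similarity = (fa * fb).sum().item()
    print(f"Cosine similarity between padding modes: {similarity:.4f}")
    return similarity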


# MARK: - Image Search Database

class ImageSearch:
    """Image search system using FTS5 for categories and vector embeddings."""

    db_path: str

    def __init__(self, db_path: Optional[str] = None):
        """
        Initialize the image search system.
        Args:
            db_path: Path to the SQLite database file
        """
        self.db_path = db_path if db_path else os.path.join(os.path.dirname(__file__), "images_database.db")
        self._init_db()

    def _init_db(self) -> None:
        """Initialize the SQLite database with FTS5 and vector tables."""
        conn = sqlite3.connect(self.db_path)
        conn.enable_load_extension(True)
        sqlite_vec.load(conn)
        conn.enable_load_extension(False)

        # Create FTS5 virtual table for category search
        # Categories are searchable text labels (e.g., "person", "landscape", "indoor")
        conn.execute("""
            CREATE VIRTUAL TABLE IF NOT EXISTS image_categories
            USING fts5(
                natural_key,
                frame_number,
                category_name,
                score
            )
        """)

        # Create vector table for SigLIP2 embeddings
        # SigLIP2 embeddings: 1152 dimensions (siglip2-so400m-patch16-naflex)
        conn.execute("""
            CREATE VIRTUAL TABLE IF NOT EXISTS image_embeddings
            USING vec0(
                natural_key TEXT,
                frame_number INTEGER,
                embedding float[1152]
            )
        """)

        # Create vector table for CLIP embeddings
        # CLIP embeddings: 512 dimensions (clip-vit-base-patch16)
        conn.execute("""
            CREATE VIRTUAL TABLE IF NOT EXISTS clip_image_embeddings
            USING vec0(
                natural_key TEXT,
                frame_number INTEGER,
                embedding float[512]
            )
        """)

        conn.commit()
        conn.close()

    def _get_db_connection(self) -> sqlite3.Connection:
        """
        Get a database connection with the sqlite-vec extension loaded.
        Returns:
            A SQLite connection object with the sqlite-vec extension enabled
        """
        conn = sqlite3.connect(self.db_path)
        conn.enable_load_extension(True)
        sqlite_vec.load(conn)
        conn.enable_load_extension(False)
        return conn

    # MARK: - Loading/Indexing

    def save_image_categories(self, natural_key: str, frame_number: int,
                              categories: List[Dict[str, float]]) -> bool:
        """
        Save image categories to the database.
        Args:
            natural_key: Video natural key (e.g., "pub-osg_108_VIDEO")
            frame_number: Frame number in the video
            categories: List of dicts with 'label' and 'score' keys
        Returns:
            True if saved successfully, False otherwise
        """
        if not categories:
            return False
        conn = self._get_db_connection()
        try:
            # Insert each category as a separate row
            for category in categories:
                category_name = category.get('label', '')
                score = category.get('score', 0.0)
                conn.execute("""
                    INSERT INTO image_categories (natural_key, frame_number, category_name, score)
                    VALUES (?, ?, ?, ?)
                """, (natural_key, frame_number, category_name, score))
            conn.commit()
            return True
        except Exception as e:
            print(f"Error saving categories for {natural_key} frame {frame_number}: {e}")
            return False
        finally:
            conn.close()

    def save_image_embedding(self, natural_key: str, frame_number: int,
                             embedding: np.ndarray, model: str = "siglip2") -> bool:
        """
        Save an image embedding to the database.
        Args:
            natural_key: Video natural key (e.g., "pub-osg_108_VIDEO")
            frame_number: Frame number in the video
            embedding: Image embedding vector (numpy array)
            model: Model type - "siglip2" or "clip" (default: "siglip2")
        Returns:
            True if saved successfully, False otherwise
        """
        if embedding is None or len(embedding) == 0:
            return False
        conn = self._get_db_connection()
        try:
            # Convert embedding to a JSON string for sqlite-vec
            embedding_json = json.dumps(embedding.tolist() if isinstance(embedding, np.ndarray) else embedding)
            # Select table based on model type
            table_name = "clip_image_embeddings" if model == "clip" else "image_embeddings"
            # Insert or replace embedding
            conn.execute(f"""
                INSERT OR REPLACE INTO {table_name} (natural_key, frame_number, embedding)
                VALUES (?, ?, ?)
            """, (natural_key, frame_number, embedding_json))
            conn.commit()
            return True
        except Exception as e:
            print(f"Error saving {model} embedding for {natural_key} frame {frame_number}: {e}")
            return False
        finally:
            conn.close()

    def remove_image_data(self, natural_key: str, frame_number: int = None) -> None:
        """
        Remove image data (categories and embeddings) from the database.
        Args:
            natural_key: Video natural key
            frame_number: Specific frame number, or None to remove all frames for the video
        """
        conn = self._get_db_connection()
        try:
            if frame_number is not None:
                # Remove specific frame
                conn.execute("DELETE FROM image_categories WHERE natural_key = ? AND frame_number = ?",
                             (natural_key, frame_number))
                conn.execute("DELETE FROM image_embeddings WHERE natural_key = ? AND frame_number = ?",
                             (natural_key, frame_number))
                conn.execute("DELETE FROM clip_image_embeddings WHERE natural_key = ? AND frame_number = ?",
                             (natural_key, frame_number))
            else:
                # Remove all frames for this video
                conn.execute("DELETE FROM image_categories WHERE natural_key = ?", (natural_key,))
                conn.execute("DELETE FROM image_embeddings WHERE natural_key = ?", (natural_key,))
                conn.execute("DELETE FROM clip_image_embeddings WHERE natural_key = ?", (natural_key,))
            conn.commit()
        finally:
            conn.close()

    # MARK: - Search Methods

    def search_by_category(self, query: str, limit: int = 20) -> List[Tuple[str, int, str, float]]:
        """
        Search images by category name using FTS5 with fallback strategies.
        Args:
            query: Category search query (e.g., "person", "landscape")
            limit: Max results to return
        Returns:
            List of tuples: [(natural_key, frame_number, category_name, score)]
        """
        conn = self._get_db_connection()
        results = []
        # Strategy 1: Try a normal FTS5 MATCH query first
        try:
            cursor = conn.execute("""
                SELECT natural_key, frame_number, category_name, score
                FROM image_categories
                WHERE category_name MATCH ?
                ORDER BY score DESC
                LIMIT ?
            """, (query, limit))
            results = cursor.fetchall()
        except Exception:
            # Strategy 2: If the normal query fails, try a quote-escaped FTS5 query
            try:
                escaped_query = '"' + query.replace('"', '""') + '"'
                cursor = conn.execute("""
                    SELECT natural_key, frame_number, category_name, score
                    FROM image_categories
                    WHERE category_name MATCH ?
                    ORDER BY score DESC
                    LIMIT ?
                """, (escaped_query, limit))
                results = cursor.fetchall()
            except Exception:
                # Strategy 3: If FTS5 still fails, fall back to a LIKE query
                cursor = conn.execute("""
                    SELECT natural_key, frame_number, category_name, score
                    FROM image_categories
                    WHERE category_name LIKE ?
                    ORDER BY score DESC
                    LIMIT ?
                """, (f"%{query}%", limit))
                results = cursor.fetchall()
        conn.close()
        return results

    def search_by_embedding(self, query_text: str, limit: int = 20, model: str = "siglip2") -> List[Tuple[str, int, float]]:
        """
        Search images by text query using text-to-image similarity.
        Args:
            query_text: Text description of the desired image (e.g., "a sunset over mountains")
            limit: Max results to return
            model: Model to use - "siglip2" or "clip" (default: "siglip2")
        Returns:
            List of tuples: [(natural_key, frame_number, similarity_score)]
        """
        start_time = time.time()
        # Generate text embedding using the specified model
        text_embedding = get_text_embedding(query_text, model=model)
        text_embedding_json = json.dumps(text_embedding.tolist() if isinstance(text_embedding, np.ndarray) else text_embedding)

        # Search the database using vector similarity
        conn = self._get_db_connection()
        # Select table based on model type
        table_name = "clip_image_embeddings" if model == "clip" else "image_embeddings"
        cursor = conn.execute(f"""
            SELECT natural_key, frame_number, distance
            FROM {table_name}
            WHERE embedding MATCH ? AND k = ?
            ORDER BY distance
        """, (text_embedding_json, limit))
        results = []
        for natural_key, frame_number, distance in cursor.fetchall():
            # Convert distance to a similarity score in the [0, 1] range.
            # The embeddings are L2-normalized, so the distance falls in [0, 2].
            similarity = 1.0 - (distance / 2.0)
            results.append((natural_key, frame_number, similarity))
        conn.close()
        end_time = time.time()
        print(f"Time taken to search by {model} embedding: {end_time - start_time} seconds")
        return results

    def get_frame_categories(self, natural_key: str, frame_number: int) -> List[Dict[str, float]]:
        """
        Get all categories for a specific frame.
        Args:
            natural_key: Video natural key
            frame_number: Frame number
        Returns:
            List of dicts with 'label' and 'score' keys
        """
        conn = self._get_db_connection()
        cursor = conn.execute("""
            SELECT category_name, score
            FROM image_categories
            WHERE natural_key = ? AND frame_number = ?
            ORDER BY score DESC
        """, (natural_key, frame_number))
        results = [{'label': category_name, 'score': score}
                   for category_name, score in cursor.fetchall()]
        conn.close()
        return results

    def get_top_categories(self, limit: int = 100) -> List[Tuple[str, int]]:
        """
        Get the most common category labels from the database.
        Args:
            limit: Maximum number of categories to return (default: 100)
        Returns:
            List of tuples: [(category_name, count)]
        """
        conn = self._get_db_connection()
        cursor = conn.execute("""
            SELECT category_name, COUNT(*) as count
            FROM image_categories
            GROUP BY category_name
            ORDER BY count DESC
            LIMIT ?
        """, (limit,))
        results = cursor.fetchall()
        conn.close()
        return results

    # MARK: - Add Images and Videos

    def add_image(self, image_path: str, model: str = "siglip2", include_categories: bool = True) -> bool:
        """
        Add a single image to the database.
        Uses the absolute file path as the natural key and frame_number = 0.
        Args:
            image_path: Path to the image file
            model: Model to use - "siglip2" or "clip" (default: "siglip2")
            include_categories: Whether to generate and save image categories (default: True)
        Returns:
            True if successful, False otherwise
        """
        # Convert to absolute path
        abs_path = os.path.abspath(image_path)
        if not os.path.exists(abs_path):
            print(f"Error: Image not found at {abs_path}")
            return False
        print(f"Processing image: {abs_path}")
        try:
            # Generate image embedding
            embedding = generate_image_embedding(abs_path, model=model)
            # Save embedding
            success = self.save_image_embedding(abs_path, 0, embedding, model=model)
            if not success:
                print(f"Failed to save embedding for {abs_path}")
                return False
            # Generate and save categories if requested
            if include_categories:
                categories = classify_image(abs_path)
                if categories:
                    self.save_image_categories(abs_path, 0, categories)
                    print(f"Saved {len(categories)} categories")
            print(f"Successfully added image: {abs_path}")
            return True
        except Exception as e:
            print(f"Error processing image {abs_path}: {e}")
            return False

    def add_video(self, video_path: str, model: str = "siglip2", interval_seconds: int = 5, include_categories: bool = True) -> bool:
        """
        Add frames from a video to the database.
        Extracts frames at regular intervals and processes each one.
        Args:
            video_path: Path to the video file
            model: Model to use - "siglip2" or "clip" (default: "siglip2")
            interval_seconds: Extract a frame every N seconds (default: 5)
            include_categories: Whether to generate and save image categories (default: True)
        Returns:
            True if successful, False otherwise
        """
        # Convert to absolute path
        abs_path = os.path.abspath(video_path)
        if not os.path.exists(abs_path):
            print(f"Error: Video not found at {abs_path}")
            return False
        print(f"Processing video: {abs_path}")
        try:
            # Get video duration using ffprobe
            duration_cmd = [
                "ffprobe",
                "-v", "error",
                "-show_entries", "format=duration",
                "-of", "default=noprint_wrappers=1:nokey=1",
                abs_path
            ]
            duration_result = subprocess.run(duration_cmd, capture_output=True, text=True)
            if duration_result.returncode != 0:
                print(f"Error getting video duration: {duration_result.stderr}")
                return False
            duration = float(duration_result.stdout.strip())
            print(f"Video duration: {duration:.2f} seconds")

            # Calculate the number of frames to extract
            num_frames = int(duration / interval_seconds)
            print(f"Extracting {num_frames} frames at {interval_seconds}s intervals")

            # Create a temporary directory for frames
            with tempfile.TemporaryDirectory() as temp_dir:
                frame_count = 0
                # Extract frames at intervals
                for i in range(num_frames + 1):
                    timestamp = i * interval_seconds
                    # Skip if the timestamp exceeds the video duration
                    if timestamp > duration:
                        break
                    # Extract the frame at this timestamp
                    frame_path = os.path.join(temp_dir, f"frame_{i:05d}.jpg")
                    extract_cmd = [
                        "ffmpeg",
                        "-ss", str(timestamp),
                        "-i", abs_path,
                        "-frames:v", "1",
                        "-q:v", "2",
                        "-y",
                        frame_path
                    ]
                    result = subprocess.run(extract_cmd, capture_output=True, text=True)
                    if result.returncode != 0:
                        print(f"Warning: Failed to extract frame at {timestamp}s")
                        continue
                    # Process this frame
                    print(f"Processing frame {i+1}/{num_frames+1} (at {timestamp}s)")
                    # Generate embedding
                    embedding = generate_image_embedding(frame_path, model=model)
                    # Save embedding (frame_number = timestamp in seconds)
                    self.save_image_embedding(abs_path, int(timestamp), embedding, model=model)
                    # Generate and save categories if requested
                    if include_categories:
                        categories = classify_image(frame_path)
                        if categories:
                            self.save_image_categories(abs_path, int(timestamp), categories)
                    frame_count += 1
            print(f"Successfully processed {frame_count} frames from video: {abs_path}")
            return True
        except Exception as e:
            print(f"Error processing video {abs_path}: {e}")
            import traceback
            traceback.print_exc()
            return False

    # MARK: - Statistics

    def get_stats(self) -> dict:
        """
        Get database statistics.
        Returns:
            Dict with stats: {
                total_categories: int,
                total_siglip2_embeddings: int,
                total_clip_embeddings: int,
                unique_videos: int,
                db_size_mb: float
            }
        """
        conn = self._get_db_connection()
        # Count categories
        cursor = conn.execute("SELECT COUNT(*) FROM image_categories")
        total_categories = cursor.fetchone()[0]
        # Count SigLIP2 embeddings
        cursor = conn.execute("SELECT COUNT(*) FROM image_embeddings")
        total_siglip2_embeddings = cursor.fetchone()[0]
        # Count CLIP embeddings
        cursor = conn.execute("SELECT COUNT(*) FROM clip_image_embeddings")
        total_clip_embeddings = cursor.fetchone()[0]
        # Count unique videos (from the SigLIP2 table)
        cursor = conn.execute("SELECT COUNT(DISTINCT natural_key) FROM image_embeddings")
        unique_videos = cursor.fetchone()[0]
        conn.close()
        # Get the database file size
        db_size_mb = 0.0
        if os.path.exists(self.db_path):
            db_size_mb = os.path.getsize(self.db_path) / (1024 * 1024)
        return {
            "total_categories": total_categories,
            "total_siglip2_embeddings": total_siglip2_embeddings,
            "total_clip_embeddings": total_clip_embeddings,
            "unique_videos": unique_videos,
            "db_size_mb": round(db_size_mb, 2)
        }


# MARK: - CLI Interface

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Image Search System")
    subparsers = parser.add_subparsers(dest="command", help="Commands")

    # Add image command
    add_img_parser = subparsers.add_parser("add-image", help="Add a single image to the database")
    add_img_parser.add_argument("image_path", help="Path to the image file")
    add_img_parser.add_argument("--model", type=str, default="siglip2", choices=["siglip2", "clip"], help="Model to use (default: siglip2)")
    add_img_parser.add_argument("--no-categories", action="store_true", help="Skip category generation")

    # Add video command
    add_vid_parser = subparsers.add_parser("add-video", help="Add frames from a video to the database")
    add_vid_parser.add_argument("video_path", help="Path to the video file")
    add_vid_parser.add_argument("--model", type=str, default="siglip2", choices=["siglip2", "clip"], help="Model to use (default: siglip2)")
    add_vid_parser.add_argument("--interval", type=int, default=5, help="Extract a frame every N seconds (default: 5)")
    add_vid_parser.add_argument("--no-categories", action="store_true", help="Skip category generation")

    # Remove command
    remove_parser = subparsers.add_parser("remove", help="Remove all data for a specific file from the database")
    remove_parser.add_argument("file_path", help="Path to the file (will be converted to an absolute path)")

    # Search by category command
    search_cat_parser = subparsers.add_parser("search-category", help="Search by category")
    search_cat_parser.add_argument("query", help="Category query (e.g., 'person')")
    search_cat_parser.add_argument("--limit", type=int, default=20, help="Max results")

    # Search by embedding command
    search_emb_parser = subparsers.add_parser("search-embedding", help="Search by text-to-image similarity")
    search_emb_parser.add_argument("query", help="Text description (e.g., 'sunset over mountains')")
    search_emb_parser.add_argument("--limit", type=int, default=20, help="Max results")
    search_emb_parser.add_argument("--model", type=str, default="siglip2", choices=["siglip2", "clip"], help="Model to use (default: siglip2)")

    # Stats command
    stats_parser = subparsers.add_parser("stats", help="Get database statistics")

    args = parser.parse_args()
    if not args.command:
        parser.print_help()
        exit(1)

    # Initialize the search system
    search = ImageSearch()

    if args.command == "add-image":
        success = search.add_image(
            args.image_path,
            model=args.model,
            include_categories=not args.no_categories
        )
        exit(0 if success else 1)
    elif args.command == "add-video":
        success = search.add_video(
            args.video_path,
            model=args.model,
            interval_seconds=args.interval,
            include_categories=not args.no_categories
        )
        exit(0 if success else 1)
    elif args.command == "remove":
        # Convert to absolute path
        abs_path = os.path.abspath(args.file_path)
        print(f"Removing all data for: {abs_path}")
        search.remove_image_data(abs_path)
        print(f"Successfully removed all data for: {abs_path}")
    elif args.command == "search-category":
        results = search.search_by_category(args.query, args.limit)
        print(f"Found {len(results)} results for category: {args.query}")
        for natural_key, frame_number, category_name, score in results:
            print(f"  {natural_key} @ frame {frame_number}: {category_name} (score: {score:.3f})")
    elif args.command == "search-embedding":
        results = search.search_by_embedding(args.query, args.limit, model=args.model)
        print(f"Found {len(results)} results for: {args.query} (using {args.model})")
        for natural_key, frame_number, similarity in results:
            print(f"  {natural_key} @ frame {frame_number}: similarity {similarity:.3f}")
    elif args.command == "stats":
        stats = search.get_stats()
        print("Image Search Database Statistics:")
        print(f"  Total categories: {stats['total_categories']}")
        print(f"  Total SigLIP2 embeddings: {stats['total_siglip2_embeddings']}")
        print(f"  Total CLIP embeddings: {stats['total_clip_embeddings']}")
        print(f"  Unique videos: {stats['unique_videos']}")
        print(f"  Database size: {stats['db_size_mb']} MB")
| """ | |
| Simple demo showing how to generate and compare text and image embeddings. | |
| Uses SigLIP2 model for both text and image embeddings. | |
| """ | |
| import torch | |
| from PIL import Image | |
| from transformers import AutoProcessor, AutoModel | |
| import numpy as np | |
| import torch.nn.functional as F | |
| # Image paths | |
| IMAGE_PATHS = [ | |
| "./lake-smaller.jpg", | |
| "./cats.jpg", | |
| "./bmw-i7.jpg", | |
| "./beech-tree.jpg", | |
| ] | |
| # Global model instances (lazy loaded) | |
| siglip_model = None | |
| siglip_processor = None | |
| def get_device(): | |
| """Get the best available device (MPS for Mac, CUDA for NVIDIA, or CPU).""" | |
| if torch.backends.mps.is_available(): | |
| return torch.device("mps") | |
| elif torch.cuda.is_available(): | |
| return torch.device("cuda") | |
| return torch.device("cpu") | |
| def get_siglip_model(): | |
| """Load and return the SigLIP2 model and processor.""" | |
| global siglip_model, siglip_processor | |
| if siglip_model is None: | |
| device = get_device() | |
| print(f"Loading SigLIP2 model on {device}...") | |
| siglip_model = AutoModel.from_pretrained("google/siglip2-so400m-patch16-naflex") | |
| siglip_model = siglip_model.to(device) | |
| siglip_processor = AutoProcessor.from_pretrained("google/siglip2-so400m-patch16-naflex") | |
| print("Model loaded!\n") | |
| return siglip_model, siglip_processor | |


def generate_image_embedding(image_path):
    """
    Generate an L2-normalized image embedding.
    Args:
        image_path: Path to the image file
    Returns:
        Tuple of (filename, embedding vector as numpy array)
    """
    device = get_device()
    model, processor = get_siglip_model()

    # Load and convert image to RGB
    image = Image.open(image_path).convert("RGB")

    # Generate embedding
    with torch.no_grad():
        inputs = processor(images=image, return_tensors="pt").to(device)
        features = model.get_image_features(**inputs)
        # L2 normalize for proper cosine similarity
        features = F.normalize(features, p=2, dim=-1)
        embedding = features.cpu().numpy()[0]

    # Extract just the filename from the path
    filename = image_path.split("/")[-1]
    return filename, embedding


def generate_text_embedding(query_text):
    """
    Generate an L2-normalized text embedding.
    Args:
        query_text: Text query string
    Returns:
        Embedding vector as numpy array
    """
    device = get_device()
    model, processor = get_siglip_model()

    with torch.no_grad():
        # Important: use padding="max_length", max_length=64 for consistent embeddings
        inputs = processor(text=[query_text], return_tensors="pt", padding="max_length", max_length=64).to(device)
        features = model.get_text_features(**inputs)
        # L2 normalize
        features = F.normalize(features, p=2, dim=-1)
        embedding = features.cpu().numpy()[0]
    return embedding


def cosine_similarity(vec1, vec2):
    """
    Calculate cosine similarity between two vectors.
    Since the vectors are already L2-normalized, this is just the dot product.
    Args:
        vec1, vec2: Numpy arrays
    Returns:
        Similarity score in [-1, 1] (higher is better; text-image
        matches are typically positive)
    """
    return np.dot(vec1, vec2)


def load_images():
    """
    Load all images and generate their embeddings.
    Returns:
        List of tuples: [(filename, embedding), ...]
    """
    print("Loading and embedding images...")
    embeddings = []
    for image_path in IMAGE_PATHS:
        print(f"  Processing {image_path.split('/')[-1]}...")
        filename, embedding = generate_image_embedding(image_path)
        embeddings.append((filename, embedding))
    print(f"\nGenerated embeddings for {len(embeddings)} images\n")
    return embeddings


def search(query_text, image_embeddings):
    """
    Search for images matching the text query.
    Args:
        query_text: Text description to search for
        image_embeddings: List of (filename, embedding) tuples
    Returns:
        List of (filename, similarity_score) tuples sorted by similarity (best first)
    """
    print(f"Searching for: '{query_text}'")
    # Generate text embedding
    text_embedding = generate_text_embedding(query_text)

    # Compare with all image embeddings
    results = []
    for filename, image_embedding in image_embeddings:
        similarity = cosine_similarity(text_embedding, image_embedding)
        results.append((filename, similarity))

    # Sort by similarity (highest first)
    results.sort(key=lambda x: x[1], reverse=True)

    print("\nResults:")
    print("-" * 50)
    for filename, similarity in results:
        print(f"{filename:30s} | Similarity: {similarity:.4f}")
    print("-" * 50)
    print()
    return results


def main():
    """Main demo function."""
    print("=" * 50)
    print("Text and Image Embedding Demo")
    print("=" * 50)
    print()

    # Load all images and generate embeddings
    image_embeddings = load_images()

    # Run some example searches
    print("=" * 50)
    print("Example Searches")
    print("=" * 50)
    print()
    search("a beautiful lake with mountains", image_embeddings)
    search("cats sitting together", image_embeddings)
    search("luxury car", image_embeddings)
    search("tree in nature", image_embeddings)
    search("water and nature", image_embeddings)


if __name__ == "__main__":
    main()
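Side by side with the CLIP demo at the top, this script differs only in the checkpoint it loads and in how the text is padded; everything else (normalization, dot-product similarity, the search loop) is identical:

# CLIP (first demo): dynamic padding is fine
inputs = processor(text=[query_text], return_tensors="pt", padding=True)

# SigLIP2 (this demo): fixed-length padding to 64 tokens is needed for
# consistent embeddings, as noted in generate_text_embedding() above
inputs = processor(text=[query_text], return_tensors="pt", padding="max_length", max_length=64)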