@kylehowells
Created October 27, 2025 06:15
Demo of using the OpenAI CLIP and Google SigLIP 2 models for text and image embedding search.
"""
Simple demo showing how to generate and compare text and image embeddings.
Uses the CLIP model (openai/clip-vit-base-patch16) for both text and image embeddings.
"""
import torch
from PIL import Image
from transformers import CLIPProcessor, CLIPModel
import numpy as np
import torch.nn.functional as F
# Image paths
IMAGE_PATHS = [
"./lake-smaller.jpg",
"./cats.jpg",
"./bmw-i7.jpg",
"./beech-tree.jpg",
]
# Global model instances (lazy loaded)
clip_model = None
clip_processor = None
def get_device():
"""Get the best available device (MPS for Mac, CUDA for NVIDIA, or CPU)."""
if torch.backends.mps.is_available():
return torch.device("mps")
elif torch.cuda.is_available():
return torch.device("cuda")
return torch.device("cpu")
def get_clip_model():
"""Load and return the CLIP model and processor."""
global clip_model, clip_processor
if clip_model is None:
device = get_device()
print(f"Loading CLIP model on {device}...")
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch16")
clip_model = clip_model.to(device)
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16")
print("Model loaded!\n")
return clip_model, clip_processor
def generate_image_embedding(image_path):
"""
Generate an L2-normalized image embedding.
Args:
image_path: Path to the image file
Returns:
Tuple of (filename, embedding vector as numpy array)
"""
device = get_device()
model, processor = get_clip_model()
# Load and convert image to RGB
image = Image.open(image_path).convert("RGB")
# Generate embedding
with torch.no_grad():
inputs = processor(images=image, return_tensors="pt").to(device)
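# get_image_features returns a (1, 512) tensor for clip-vit-base-patch16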
features = model.get_image_features(**inputs)
# L2 normalize for proper cosine similarity
features = F.normalize(features, p=2, dim=-1)
embedding = features.cpu().numpy()[0]
# Extract just the filename from the path
filename = image_path.split("/")[-1]
return filename, embedding
def generate_text_embedding(query_text):
"""
Generate an L2-normalized text embedding.
Args:
query_text: Text query string
Returns:
Embedding vector as numpy array
"""
device = get_device()
model, processor = get_clip_model()
with torch.no_grad():
inputs = processor(text=[query_text], return_tensors="pt", padding=True).to(device)
features = model.get_text_features(**inputs)
# L2 normalize
features = F.normalize(features, p=2, dim=-1)
embedding = features.cpu().numpy()[0]
return embedding
def cosine_similarity(vec1, vec2):
"""
Calculate cosine similarity between two vectors.
Since vectors are already L2-normalized, this is just the dot product.
Args:
vec1, vec2: Numpy arrays
Returns:
Cosine similarity score in [-1, 1] (higher means more similar)
"""
return np.dot(vec1, vec2)
def load_images():
"""
Load all images and generate their embeddings.
Returns:
List of tuples: [(filename, embedding), ...]
"""
print("Loading and embedding images...")
embeddings = []
for image_path in IMAGE_PATHS:
print(f" Processing {image_path.split('/')[-1]}...")
filename, embedding = generate_image_embedding(image_path)
embeddings.append((filename, embedding))
print(f"\nGenerated embeddings for {len(embeddings)} images\n")
return embeddings
def search(query_text, image_embeddings):
"""
Search for images matching the text query.
Args:
query_text: Text description to search for
image_embeddings: List of (filename, embedding) tuples
Returns:
List of (filename, similarity_score) tuples sorted by similarity (best first)
"""
print(f"Searching for: '{query_text}'")
# Generate text embedding
text_embedding = generate_text_embedding(query_text)
# Compare with all image embeddings
results = []
for filename, image_embedding in image_embeddings:
similarity = cosine_similarity(text_embedding, image_embedding)
results.append((filename, similarity))
# Sort by similarity (highest first)
results.sort(key=lambda x: x[1], reverse=True)
print("\nResults:")
print("-" * 50)
for filename, similarity in results:
print(f"{filename:30s} | Similarity: {similarity:.4f}")
print("-" * 50)
print()
return results
def main():
"""Main demo function."""
print("=" * 50)
print("Text and Image Embedding Demo (CLIP)")
print("=" * 50)
print()
# Load all images and generate embeddings
image_embeddings = load_images()
# Run some example searches
print("=" * 50)
print("Example Searches")
print("=" * 50)
print()
search("a beautiful lake with mountains", image_embeddings)
search("cats sitting together", image_embeddings)
search("luxury car", image_embeddings)
search("tree in nature", image_embeddings)
search("water and nature", image_embeddings)
if __name__ == "__main__":
main()
torch
torchvision
transformers
pillow
numpy
sqlite-vec
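These dependencies can be installed with pip, for example: pip install torch torchvision transformers pillow numpy sqlite-vec. The database script below also prefers pysqlite3 (available on PyPI as pysqlite3-binary) for SQLite extension support, falling back to the built-in sqlite3 module if it is not installed.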
"""
Search based on image categories and image embeddings saved in a local SQLite database.
Each frame of a video will have multiple categories.
However, it will only have one embedding vector.
Image categories are saved with a video natural key, timestamp/frame number and a category name.
Image embeddings are saved with a video natural key, timestamp/frame number and an embedding vector.
The image search system uses either the image categories or the image embeddings to find frames that match a query.
You provide a query, which is either matched against the category names (full-text search) or embedded and compared against the image embeddings using cosine similarity.
The search returns a list of matching frames, each with the following information:
- The video natural key (which video the frame is from)
- The frame number (which frame it is)
- Either the category names (when searching by category) or the cosine similarity score (when searching by embedding)
"""
import os
# Use pysqlite3 instead of built-in sqlite3 for extension support
try:
import pysqlite3 as sqlite3
except ImportError:
import sqlite3
import sqlite_vec
import json
from typing import List, Tuple, Optional
import argparse
from transformers import pipeline
from typing import Dict
import time
import subprocess
import tempfile
import torch
from PIL import Image
from transformers import AutoProcessor, AutoModel, CLIPProcessor, CLIPModel
import numpy as np
import torch.nn.functional as F
clf = None
def classify_image(image_path: str) -> List[Dict[str, float]]:
"""
Classify an image using a pre-trained model.
"""
global clf
if clf is None:
device = get_device()
# Map device to pipeline device format (pipeline uses -1 for cpu, 0+ for gpu indices)
if device.type == "cuda":
pipeline_device = 0
elif device.type == "mps":
pipeline_device = 0 # MPS also uses 0
else:
pipeline_device = -1 # CPU
clf = pipeline("image-classification", "google/vit-large-patch16-224", device=pipeline_device, use_fast=True)
start_time = time.time()
# Classify the image
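# The image-classification pipeline returns a list of {"label": str, "score": float} dicts, sorted by score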
results = clf(image_path)
# Filter out results with score less than 0.1
results = [result for result in results if result["score"] > 0.1]
end_time = time.time()
print(f"Time taken to classify image: {end_time - start_time} seconds")
# Return the results
return results
# Model instances (lazy loaded)
siglip_model = None
siglip_processor = None
clip_model = None
clip_processor = None
def get_siglip_model() -> Tuple[AutoModel, AutoProcessor]:
"""
Get the SigLIP2 model and processor.
Uses AutoProcessor for both images and text (correct approach from reference projects).
"""
global siglip_model, siglip_processor
if siglip_model is None:
device = get_device()
siglip_model = AutoModel.from_pretrained("google/siglip2-so400m-patch16-naflex")
siglip_model = siglip_model.to(device)
siglip_processor = AutoProcessor.from_pretrained("google/siglip2-so400m-patch16-naflex")
return siglip_model, siglip_processor
def get_clip_model() -> Tuple[CLIPModel, CLIPProcessor]:
"""
Get the CLIP model and processor.
"""
global clip_model, clip_processor
if clip_model is None:
device = get_device()
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch16")
clip_model = clip_model.to(device)
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16")
return clip_model, clip_processor
def get_device() -> torch.device:
"""
Get the device to use for the model.
"""
if torch.backends.mps.is_available():
return torch.device("mps")
elif torch.cuda.is_available():
return torch.device("cuda")
return torch.device("cpu")
def generate_image_embedding(image_path: str, model: str = "siglip2") -> np.ndarray:
"""
Generate an image embedding using either SigLIP2 or CLIP.
Embeddings are L2-normalized for proper cosine similarity comparison.
Args:
image_path: Path to the image file
model: Model to use - "siglip2" or "clip" (default: "siglip2")
Returns:
L2-normalized embedding vector as numpy array
"""
device = get_device()
image = Image.open(image_path).convert("RGB")
start_time = time.time()
if model == "clip":
clip_model, clip_processor = get_clip_model()
with torch.no_grad():
inputs = clip_processor(images=image, return_tensors="pt").to(device)
features = clip_model.get_image_features(**inputs)
# L2 normalize
features = F.normalize(features, p=2, dim=-1)
image_embedding = features.cpu().numpy()[0]
else: # siglip2
siglip_model, siglip_processor = get_siglip_model()
with torch.no_grad():
inputs = siglip_processor(images=image, return_tensors="pt").to(device)
features = siglip_model.get_image_features(**inputs)
# L2 normalize
features = F.normalize(features, p=2, dim=-1)
image_embedding = features.cpu().numpy()[0]
end_time = time.time()
print(f"Time taken to generate {model} image embedding: {end_time - start_time} seconds")
return image_embedding
def get_text_embedding(query_text: str, model: str = "siglip2") -> np.ndarray:
"""
Generate a text embedding to query against the image embeddings.
Embeddings are L2-normalized for proper cosine similarity comparison.
Args:
query_text: Text query string
model: Model to use - "siglip2" or "clip" (default: "siglip2")
Returns:
L2-normalized embedding vector as numpy array
"""
device = get_device()
start_time = time.time()
if model == "clip":
clip_model, clip_processor = get_clip_model()
with torch.no_grad():
inputs = clip_processor(text=[query_text], return_tensors="pt", padding=True).to(device)
features = clip_model.get_text_features(**inputs)
# L2 normalize
features = F.normalize(features, p=2, dim=-1)
query_text_embedding = features.cpu().numpy()[0]
else: # siglip2
siglip_model, siglip_processor = get_siglip_model()
with torch.no_grad():
# Use padding="max_length", max_length=64
# This is important! Using just padding=True produces very different embeddings
# (cosine similarity ~0.7 vs 1.0). Two reference projects use max_length=64.
inputs = siglip_processor(text=[query_text], return_tensors="pt", padding="max_length", max_length=64).to(device)
features = siglip_model.get_text_features(**inputs)
# L2 normalize
features = F.normalize(features, p=2, dim=-1)
query_text_embedding = features.cpu().numpy()[0]
end_time = time.time()
print(f"Time taken to generate {model} text embedding: {end_time - start_time} seconds")
return query_text_embedding
# MARK: - Image Search Database
class ImageSearch:
"""Image search system using FTS5 for categories and vector embeddings."""
db_path: str
def __init__(self, db_path: Optional[str] = None):
"""
Initialize the image search system.
Args:
db_path: Path to the SQLite database file
"""
self.db_path = db_path if db_path else os.path.join(os.path.dirname(__file__), "images_database.db")
self._init_db()
def _init_db(self) -> None:
"""Initialize the SQLite database with FTS5 and vector tables."""
conn = sqlite3.connect(self.db_path)
conn.enable_load_extension(True)
sqlite_vec.load(conn)
conn.enable_load_extension(False)
# Create FTS5 virtual table for category search
# Categories are searchable text labels (e.g., "person", "landscape", "indoor")
conn.execute("""
CREATE VIRTUAL TABLE IF NOT EXISTS image_categories
USING fts5(
natural_key,
frame_number,
category_name,
score
)
""")
# Create vector table for SigLIP2 embeddings
# SigLIP2 embeddings: 1152 dimensions (siglip2-so400m-patch16-naflex)
conn.execute("""
CREATE VIRTUAL TABLE IF NOT EXISTS image_embeddings
USING vec0(
natural_key TEXT,
frame_number INTEGER,
embedding float[1152]
)
""")
# Create vector table for CLIP embeddings
# CLIP embeddings: 512 dimensions (clip-vit-base-patch16)
conn.execute("""
CREATE VIRTUAL TABLE IF NOT EXISTS clip_image_embeddings
USING vec0(
natural_key TEXT,
frame_number INTEGER,
embedding float[512]
)
""")
conn.commit()
conn.close()
def _get_db_connection(self) -> sqlite3.Connection:
"""
Get a database connection with sqlite-vec extension loaded.
Returns:
A SQLite connection object with sqlite-vec extension enabled
"""
conn = sqlite3.connect(self.db_path)
conn.enable_load_extension(True)
sqlite_vec.load(conn)
conn.enable_load_extension(False)
return conn
# MARK: - Loading/Indexing
def save_image_categories(self, natural_key: str, frame_number: int,
categories: List[Dict[str, float]]) -> bool:
"""
Save image categories to the database.
Args:
natural_key: Video natural key (e.g., "pub-osg_108_VIDEO")
frame_number: Frame number in the video
categories: List of dicts with 'label' and 'score' keys
Returns:
True if saved successfully, False otherwise
"""
if not categories:
return False
conn = self._get_db_connection()
try:
# Insert each category as a separate row
for category in categories:
category_name = category.get('label', '')
score = category.get('score', 0.0)
conn.execute("""
INSERT INTO image_categories (natural_key, frame_number, category_name, score)
VALUES (?, ?, ?, ?)
""", (natural_key, frame_number, category_name, score))
conn.commit()
return True
except Exception as e:
print(f"Error saving categories for {natural_key} frame {frame_number}: {e}")
return False
finally:
conn.close()
def save_image_embedding(self, natural_key: str, frame_number: int,
embedding: np.ndarray, model: str = "siglip2") -> bool:
"""
Save image embedding to the database.
Args:
natural_key: Video natural key (e.g., "pub-osg_108_VIDEO")
frame_number: Frame number in the video
embedding: Image embedding vector (numpy array)
model: Model type - "siglip2" or "clip" (default: "siglip2")
Returns:
True if saved successfully, False otherwise
"""
if embedding is None or len(embedding) == 0:
return False
conn = self._get_db_connection()
try:
# Convert embedding to JSON string for SQLite-Vec
embedding_json = json.dumps(embedding.tolist() if isinstance(embedding, np.ndarray) else embedding)
# Select table based on model type
table_name = "clip_image_embeddings" if model == "clip" else "image_embeddings"
# Insert or replace embedding
conn.execute(f"""
INSERT OR REPLACE INTO {table_name} (natural_key, frame_number, embedding)
VALUES (?, ?, ?)
""", (natural_key, frame_number, embedding_json))
conn.commit()
return True
except Exception as e:
print(f"Error saving {model} embedding for {natural_key} frame {frame_number}: {e}")
return False
finally:
conn.close()
def remove_image_data(self, natural_key: str, frame_number: Optional[int] = None) -> None:
"""
Remove image data (categories and embeddings) from database.
Args:
natural_key: Video natural key
frame_number: Specific frame number, or None to remove all frames for the video
"""
conn = self._get_db_connection()
try:
if frame_number is not None:
# Remove specific frame
conn.execute("DELETE FROM image_categories WHERE natural_key = ? AND frame_number = ?",
(natural_key, frame_number))
conn.execute("DELETE FROM image_embeddings WHERE natural_key = ? AND frame_number = ?",
(natural_key, frame_number))
conn.execute("DELETE FROM clip_image_embeddings WHERE natural_key = ? AND frame_number = ?",
(natural_key, frame_number))
else:
# Remove all frames for this video
conn.execute("DELETE FROM image_categories WHERE natural_key = ?", (natural_key,))
conn.execute("DELETE FROM image_embeddings WHERE natural_key = ?", (natural_key,))
conn.execute("DELETE FROM clip_image_embeddings WHERE natural_key = ?", (natural_key,))
conn.commit()
finally:
conn.close()
# MARK: - Search Methods
def search_by_category(self, query: str, limit: int = 20) -> List[Tuple[str, int, str, float]]:
"""
Search images by category name using FTS5 with fallback strategies.
Args:
query: Category search query (e.g., "person", "landscape")
limit: Max results to return
Returns:
List of tuples: [(natural_key, frame_number, category_name, score)]
"""
conn = self._get_db_connection()
results = []
# Strategy 1: Try normal FTS5 MATCH query first
try:
cursor = conn.execute("""
SELECT natural_key, frame_number, category_name, score
FROM image_categories
WHERE category_name MATCH ?
ORDER BY score DESC
LIMIT ?
""", (query, limit))
results = cursor.fetchall()
except Exception as e:
# Strategy 2: If normal query fails, try quote-escaped FTS5 query
try:
escaped_query = '"' + query.replace('"', '""') + '"'
cursor = conn.execute("""
SELECT natural_key, frame_number, category_name, score
FROM image_categories
WHERE category_name MATCH ?
ORDER BY score DESC
LIMIT ?
""", (escaped_query, limit))
results = cursor.fetchall()
except Exception as e2:
# Strategy 3: If FTS5 still fails, fall back to LIKE query
cursor = conn.execute("""
SELECT natural_key, frame_number, category_name, score
FROM image_categories
WHERE category_name LIKE ?
ORDER BY score DESC
LIMIT ?
""", (f"%{query}%", limit))
results = cursor.fetchall()
conn.close()
return results
def search_by_embedding(self, query_text: str, limit: int = 20, model: str = "siglip2") -> List[Tuple[str, int, float]]:
"""
Search images by text query using text-to-image similarity.
Args:
query_text: Text description of desired image (e.g., "a sunset over mountains")
limit: Max results to return
model: Model to use - "siglip2" or "clip" (default: "siglip2")
Returns:
List of tuples: [(natural_key, frame_number, similarity_score)]
"""
start_time = time.time()
# Generate text embedding using specified model
text_embedding = get_text_embedding(query_text, model=model)
text_embedding_json = json.dumps(text_embedding.tolist() if isinstance(text_embedding, np.ndarray) else text_embedding)
# Search database using vector similarity
conn = self._get_db_connection()
# Select table based on model type
table_name = "clip_image_embeddings" if model == "clip" else "image_embeddings"
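# vec0 KNN query: MATCH takes the query vector (as JSON) and k caps the number of nearest neighbours returned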
cursor = conn.execute(f"""
SELECT natural_key, frame_number, distance
FROM {table_name}
WHERE embedding MATCH ? AND k = ?
ORDER BY distance
""", (text_embedding_json, limit))
results = []
for natural_key, frame_number, distance in cursor.fetchall():
# Convert distance to a [0, 1] ranking score.
# vec0 returns Euclidean (L2) distance by default; for L2-normalized embeddings it lies in [0, 2]
# and decreases as cosine similarity increases, so 1 - d/2 is a monotonic similarity score
# (not the exact cosine value).
similarity = 1.0 - (distance / 2.0)
results.append((natural_key, frame_number, similarity))
conn.close()
end_time = time.time()
print(f"Time taken to search by {model} embedding: {end_time - start_time} seconds")
return results
def get_frame_categories(self, natural_key: str, frame_number: int) -> List[Dict[str, float]]:
"""
Get all categories for a specific frame.
Args:
natural_key: Video natural key
frame_number: Frame number
Returns:
List of dicts with 'label' and 'score' keys
"""
conn = self._get_db_connection()
cursor = conn.execute("""
SELECT category_name, score
FROM image_categories
WHERE natural_key = ? AND frame_number = ?
ORDER BY score DESC
""", (natural_key, frame_number))
results = [{'label': category_name, 'score': score}
for category_name, score in cursor.fetchall()]
conn.close()
return results
def get_top_categories(self, limit: int = 100) -> List[Tuple[str, int]]:
"""
Get the most common category labels from the database.
Args:
limit: Maximum number of categories to return (default: 100)
Returns:
List of tuples: [(category_name, count)]
"""
conn = self._get_db_connection()
cursor = conn.execute("""
SELECT category_name, COUNT(*) as count
FROM image_categories
GROUP BY category_name
ORDER BY count DESC
LIMIT ?
""", (limit,))
results = cursor.fetchall()
conn.close()
return results
# MARK: - Add Images and Videos
def add_image(self, image_path: str, model: str = "siglip2", include_categories: bool = True) -> bool:
"""
Add a single image to the database.
Uses the absolute file path as the natural key and frame_number = 0.
Args:
image_path: Path to the image file
model: Model to use - "siglip2" or "clip" (default: "siglip2")
include_categories: Whether to generate and save image categories (default: True)
Returns:
True if successful, False otherwise
"""
# Convert to absolute path
abs_path = os.path.abspath(image_path)
if not os.path.exists(abs_path):
print(f"Error: Image not found at {abs_path}")
return False
print(f"Processing image: {abs_path}")
try:
# Generate image embedding
embedding = generate_image_embedding(abs_path, model=model)
# Save embedding
success = self.save_image_embedding(abs_path, 0, embedding, model=model)
if not success:
print(f"Failed to save embedding for {abs_path}")
return False
# Generate and save categories if requested
if include_categories:
categories = classify_image(abs_path)
if categories:
self.save_image_categories(abs_path, 0, categories)
print(f"Saved {len(categories)} categories")
print(f"Successfully added image: {abs_path}")
return True
except Exception as e:
print(f"Error processing image {abs_path}: {e}")
return False
def add_video(self, video_path: str, model: str = "siglip2", interval_seconds: int = 5, include_categories: bool = True) -> bool:
"""
Add frames from a video to the database.
Extracts frames at regular intervals and processes each one.
Args:
video_path: Path to the video file
model: Model to use - "siglip2" or "clip" (default: "siglip2")
interval_seconds: Extract a frame every N seconds (default: 5)
include_categories: Whether to generate and save image categories (default: True)
Returns:
True if successful, False otherwise
"""
# Convert to absolute path
abs_path = os.path.abspath(video_path)
if not os.path.exists(abs_path):
print(f"Error: Video not found at {abs_path}")
return False
print(f"Processing video: {abs_path}")
try:
# Get video duration using ffprobe
duration_cmd = [
"ffprobe",
"-v", "error",
"-show_entries", "format=duration",
"-of", "default=noprint_wrappers=1:nokey=1",
abs_path
]
duration_result = subprocess.run(duration_cmd, capture_output=True, text=True)
if duration_result.returncode != 0:
print(f"Error getting video duration: {duration_result.stderr}")
return False
duration = float(duration_result.stdout.strip())
print(f"Video duration: {duration:.2f} seconds")
# Calculate number of frames to extract
num_frames = int(duration / interval_seconds)
print(f"Extracting {num_frames} frames at {interval_seconds}s intervals")
# Create temporary directory for frames
with tempfile.TemporaryDirectory() as temp_dir:
frame_count = 0
# Extract frames at intervals
for i in range(num_frames + 1):
timestamp = i * interval_seconds
# Skip if timestamp exceeds video duration
if timestamp > duration:
break
# Extract frame at this timestamp
frame_path = os.path.join(temp_dir, f"frame_{i:05d}.jpg")
extract_cmd = [
"ffmpeg",
"-ss", str(timestamp),
"-i", abs_path,
"-frames:v", "1",
"-q:v", "2",
"-y",
frame_path
]
result = subprocess.run(extract_cmd, capture_output=True, text=True)
if result.returncode != 0:
print(f"Warning: Failed to extract frame at {timestamp}s")
continue
# Process this frame
print(f"Processing frame {i+1}/{num_frames+1} (at {timestamp}s)")
# Generate embedding
embedding = generate_image_embedding(frame_path, model=model)
# Save embedding (frame_number = timestamp in seconds)
self.save_image_embedding(abs_path, int(timestamp), embedding, model=model)
# Generate and save categories if requested
if include_categories:
categories = classify_image(frame_path)
if categories:
self.save_image_categories(abs_path, int(timestamp), categories)
frame_count += 1
print(f"Successfully processed {frame_count} frames from video: {abs_path}")
return True
except Exception as e:
print(f"Error processing video {abs_path}: {e}")
import traceback
traceback.print_exc()
return False
# MARK: - Statistics
def get_stats(self) -> dict:
"""
Get database statistics.
Returns:
Dict with stats: {
total_categories: int,
total_siglip2_embeddings: int,
total_clip_embeddings: int,
unique_videos: int,
db_size_mb: float
}
"""
conn = self._get_db_connection()
# Count categories
cursor = conn.execute("SELECT COUNT(*) FROM image_categories")
total_categories = cursor.fetchone()[0]
# Count SigLIP2 embeddings
cursor = conn.execute("SELECT COUNT(*) FROM image_embeddings")
total_siglip2_embeddings = cursor.fetchone()[0]
# Count CLIP embeddings
cursor = conn.execute("SELECT COUNT(*) FROM clip_image_embeddings")
total_clip_embeddings = cursor.fetchone()[0]
# Count unique videos (from SigLIP2 table)
cursor = conn.execute("SELECT COUNT(DISTINCT natural_key) FROM image_embeddings")
unique_videos = cursor.fetchone()[0]
conn.close()
# Get database file size
db_size_mb = 0.0
if os.path.exists(self.db_path):
db_size_mb = os.path.getsize(self.db_path) / (1024 * 1024)
return {
"total_categories": total_categories,
"total_siglip2_embeddings": total_siglip2_embeddings,
"total_clip_embeddings": total_clip_embeddings,
"unique_videos": unique_videos,
"db_size_mb": round(db_size_mb, 2)
}
# MARK: - CLI Interface
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Image Search System")
subparsers = parser.add_subparsers(dest="command", help="Commands")
# Add image command
add_img_parser = subparsers.add_parser("add-image", help="Add a single image to the database")
add_img_parser.add_argument("image_path", help="Path to the image file")
add_img_parser.add_argument("--model", type=str, default="siglip2", choices=["siglip2", "clip"], help="Model to use (default: siglip2)")
add_img_parser.add_argument("--no-categories", action="store_true", help="Skip category generation")
# Add video command
add_vid_parser = subparsers.add_parser("add-video", help="Add frames from a video to the database")
add_vid_parser.add_argument("video_path", help="Path to the video file")
add_vid_parser.add_argument("--model", type=str, default="siglip2", choices=["siglip2", "clip"], help="Model to use (default: siglip2)")
add_vid_parser.add_argument("--interval", type=int, default=5, help="Extract a frame every N seconds (default: 5)")
add_vid_parser.add_argument("--no-categories", action="store_true", help="Skip category generation")
# Remove command
remove_parser = subparsers.add_parser("remove", help="Remove all data for a specific file from the database")
remove_parser.add_argument("file_path", help="Path to the file (will be converted to absolute path)")
# Search by category command
search_cat_parser = subparsers.add_parser("search-category", help="Search by category")
search_cat_parser.add_argument("query", help="Category query (e.g., 'person')")
search_cat_parser.add_argument("--limit", type=int, default=20, help="Max results")
# Search by embedding command
search_emb_parser = subparsers.add_parser("search-embedding", help="Search by text-to-image similarity")
search_emb_parser.add_argument("query", help="Text description (e.g., 'sunset over mountains')")
search_emb_parser.add_argument("--limit", type=int, default=20, help="Max results")
search_emb_parser.add_argument("--model", type=str, default="siglip2", choices=["siglip2", "clip"], help="Model to use (default: siglip2)")
# Stats command
stats_parser = subparsers.add_parser("stats", help="Get database statistics")
args = parser.parse_args()
if not args.command:
parser.print_help()
exit(1)
# Initialize search system
search = ImageSearch()
if args.command == "add-image":
success = search.add_image(
args.image_path,
model=args.model,
include_categories=not args.no_categories
)
exit(0 if success else 1)
elif args.command == "add-video":
success = search.add_video(
args.video_path,
model=args.model,
interval_seconds=args.interval,
include_categories=not args.no_categories
)
exit(0 if success else 1)
elif args.command == "remove":
# Convert to absolute path
abs_path = os.path.abspath(args.file_path)
print(f"Removing all data for: {abs_path}")
search.remove_image_data(abs_path)
print(f"Successfully removed all data for: {abs_path}")
elif args.command == "search-category":
results = search.search_by_category(args.query, args.limit)
print(f"Found {len(results)} results for category: {args.query}")
for natural_key, frame_number, category_name, score in results:
print(f" {natural_key} @ frame {frame_number}: {category_name} (score: {score:.3f})")
elif args.command == "search-embedding":
results = search.search_by_embedding(args.query, args.limit, model=args.model)
print(f"Found {len(results)} results for: {args.query} (using {args.model})")
for natural_key, frame_number, similarity in results:
print(f" {natural_key} @ frame {frame_number}: similarity {similarity:.3f}")
elif args.command == "stats":
stats = search.get_stats()
print("Image Search Database Statistics:")
print(f" Total categories: {stats['total_categories']}")
print(f" Total SigLIP2 embeddings: {stats['total_siglip2_embeddings']}")
print(f" Total CLIP embeddings: {stats['total_clip_embeddings']}")
print(f" Unique videos: {stats['unique_videos']}")
print(f" Database size: {stats['db_size_mb']} MB")
"""
Simple demo showing how to generate and compare text and image embeddings.
Uses the SigLIP 2 model (google/siglip2-so400m-patch16-naflex) for both text and image embeddings.
"""
import torch
from PIL import Image
from transformers import AutoProcessor, AutoModel
import numpy as np
import torch.nn.functional as F
# Image paths
IMAGE_PATHS = [
"./lake-smaller.jpg",
"./cats.jpg",
"./bmw-i7.jpg",
"./beech-tree.jpg",
]
# Global model instances (lazy loaded)
siglip_model = None
siglip_processor = None
def get_device():
"""Get the best available device (MPS for Mac, CUDA for NVIDIA, or CPU)."""
if torch.backends.mps.is_available():
return torch.device("mps")
elif torch.cuda.is_available():
return torch.device("cuda")
return torch.device("cpu")
def get_siglip_model():
"""Load and return the SigLIP2 model and processor."""
global siglip_model, siglip_processor
if siglip_model is None:
device = get_device()
print(f"Loading SigLIP2 model on {device}...")
siglip_model = AutoModel.from_pretrained("google/siglip2-so400m-patch16-naflex")
siglip_model = siglip_model.to(device)
siglip_processor = AutoProcessor.from_pretrained("google/siglip2-so400m-patch16-naflex")
print("Model loaded!\n")
return siglip_model, siglip_processor
def generate_image_embedding(image_path):
"""
Generate an L2-normalized image embedding.
Args:
image_path: Path to the image file
Returns:
Tuple of (filename, embedding vector as numpy array)
"""
device = get_device()
model, processor = get_siglip_model()
# Load and convert image to RGB
image = Image.open(image_path).convert("RGB")
# Generate embedding
with torch.no_grad():
inputs = processor(images=image, return_tensors="pt").to(device)
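# get_image_features returns a (1, 1152) tensor for siglip2-so400m-patch16-naflex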
features = model.get_image_features(**inputs)
# L2 normalize for proper cosine similarity
features = F.normalize(features, p=2, dim=-1)
embedding = features.cpu().numpy()[0]
# Extract just the filename from the path
filename = image_path.split("/")[-1]
return filename, embedding
def generate_text_embedding(query_text):
"""
Generate an L2-normalized text embedding.
Args:
query_text: Text query string
Returns:
Embedding vector as numpy array
"""
device = get_device()
model, processor = get_siglip_model()
with torch.no_grad():
# Important: Use padding="max_length", max_length=64 for consistent embeddings
inputs = processor(text=[query_text], return_tensors="pt", padding="max_length", max_length=64).to(device)
features = model.get_text_features(**inputs)
# L2 normalize
features = F.normalize(features, p=2, dim=-1)
embedding = features.cpu().numpy()[0]
return embedding
def cosine_similarity(vec1, vec2):
"""
Calculate cosine similarity between two vectors.
Since vectors are already L2-normalized, this is just the dot product.
Args:
vec1, vec2: Numpy arrays
Returns:
Cosine similarity score in [-1, 1] (higher means more similar)
"""
return np.dot(vec1, vec2)
def load_images():
"""
Load all images and generate their embeddings.
Returns:
List of tuples: [(filename, embedding), ...]
"""
print("Loading and embedding images...")
embeddings = []
for image_path in IMAGE_PATHS:
print(f" Processing {image_path.split('/')[-1]}...")
filename, embedding = generate_image_embedding(image_path)
embeddings.append((filename, embedding))
print(f"\nGenerated embeddings for {len(embeddings)} images\n")
return embeddings
def search(query_text, image_embeddings):
"""
Search for images matching the text query.
Args:
query_text: Text description to search for
image_embeddings: List of (filename, embedding) tuples
Returns:
List of (filename, similarity_score) tuples sorted by similarity (best first)
"""
print(f"Searching for: '{query_text}'")
# Generate text embedding
text_embedding = generate_text_embedding(query_text)
# Compare with all image embeddings
results = []
for filename, image_embedding in image_embeddings:
similarity = cosine_similarity(text_embedding, image_embedding)
results.append((filename, similarity))
# Sort by similarity (highest first)
results.sort(key=lambda x: x[1], reverse=True)
print("\nResults:")
print("-" * 50)
for filename, similarity in results:
print(f"{filename:30s} | Similarity: {similarity:.4f}")
print("-" * 50)
print()
return results
def main():
"""Main demo function."""
print("=" * 50)
print("Text and Image Embedding Demo")
print("=" * 50)
print()
# Load all images and generate embeddings
image_embeddings = load_images()
# Run some example searches
print("=" * 50)
print("Example Searches")
print("=" * 50)
print()
search("a beautiful lake with mountains", image_embeddings)
search("cats sitting together", image_embeddings)
search("luxury car", image_embeddings)
search("tree in nature", image_embeddings)
search("water and nature", image_embeddings)
if __name__ == "__main__":
main()