Skip to content

Instantly share code, notes, and snippets.

@geobabbler
Created July 19, 2024 11:09
Show Gist options
  • Save geobabbler/922feeaf74c169fa2ea4b47fa669a816 to your computer and use it in GitHub Desktop.
Save geobabbler/922feeaf74c169fa2ea4b47fa669a816 to your computer and use it in GitHub Desktop.
Script to perform image similarity search using pgvector
import os
import psycopg2
import numpy as np
from PIL import Image
from scipy.spatial.distance import cosine
import torch
from torchvision import models, transforms
# Database configuration.
# Each setting can be supplied through an environment variable of the same
# name so real credentials need not be committed to source; when a variable
# is unset, the placeholder default below is used (replace it for local use).
DB_NAME = os.environ.get("DB_NAME", "database")
DB_USER = os.environ.get("DB_USER", "user")
DB_PASSWORD = os.environ.get("DB_PASSWORD", "password")
DB_HOST = os.environ.get("DB_HOST", "host")
DB_PORT = os.environ.get("DB_PORT", "port")
# Load ResNet-50 with ImageNet weights via the `weights=` API (the legacy
# `pretrained=True` flag is deprecated). IMAGENET1K_V1 is the checkpoint that
# `pretrained=True` selected, so embeddings stay comparable with any vectors
# already stored in the database; the newer DEFAULT weights would differ.
model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
# Remove the final classification layer so the network outputs the features
# that fed it, which are used as the image embedding.
model = torch.nn.Sequential(*list(model.children())[:-1])
# Put the wrapped network in evaluation mode (e.g. BatchNorm uses its running
# statistics) since it is only used for inference.
model.eval()
# Image preprocessing pipeline applied to every image before it reaches the
# model: resize, centre-crop to 224x224, convert to a tensor, then normalise
# each channel. The mean/std values are presumably the standard ImageNet
# statistics that match the pretrained weights — keep them in sync if the
# model changes.
preprocess = transforms.Compose(
    [
        transforms.Resize(size=256),
        transforms.CenterCrop(size=224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
def get_image_embedding(image_path):
    """Compute the embedding of a single image with the module-level ``model``.

    The image is converted to RGB, passed through ``preprocess`` and the
    truncated ResNet, and the network output is flattened into one vector.

    Args:
        image_path: Path (or file object) accepted by ``PIL.Image.open``.

    Returns:
        numpy.ndarray: 1-D feature vector (presumably 2048 values for the
        ResNet-50 backbone — confirm against the ``image_embeddings`` column).
    """
    # Use a context manager so the underlying file is always closed; Pillow
    # keeps the handle open after loading multi-frame formats such as GIF.
    with Image.open(image_path) as img:
        image = img.convert("RGB")
    # Add a leading batch dimension of size 1 for the model.
    image_tensor = preprocess(image).unsqueeze(0)
    with torch.no_grad():  # inference only; no gradients needed
        embedding = model(image_tensor)
    # Flatten the batched output so the result is always a 1-D vector.
    return embedding.flatten().numpy()
def find_most_similar_images(query_embedding, top_n=3):
    """Return the ``top_n`` stored images closest to ``query_embedding``.

    Runs a pgvector ``<->`` (distance) query against the ``image_embeddings``
    table, ordered ascending so the closest rows come first.

    Args:
        query_embedding: 1-D array-like of floats, e.g. the output of
            ``get_image_embedding``. NumPy arrays and plain lists are accepted.
        top_n: Maximum number of rows to return.

    Returns:
        list of ``(image_path, distance)`` tuples as returned by ``fetchall()``.
    """
    # pgvector's text input format is "[x1,x2,...]". Build it explicitly and
    # send it as a bound parameter rather than splicing it into the SQL text.
    values = np.asarray(query_embedding, dtype=float).ravel().tolist()
    vector_literal = "[" + ",".join(map(str, values)) + "]"

    # Named parameters let the same vector be used in both the SELECT list and
    # the ORDER BY without formatting the values into the query string.
    query = (
        "SELECT image_path, embedding <-> %(vec)s::vector AS distance "
        "FROM image_embeddings "
        "ORDER BY embedding <-> %(vec)s::vector "
        "LIMIT %(limit)s;"
    )
    params = {"vec": vector_literal, "limit": int(top_n)}

    conn = psycopg2.connect(
        dbname=DB_NAME, user=DB_USER, password=DB_PASSWORD, host=DB_HOST, port=DB_PORT
    )
    try:
        # The cursor context manager closes the cursor; the finally block
        # guarantees the connection is closed even if the query fails.
        with conn.cursor() as cursor:
            cursor.execute(query, params)
            return cursor.fetchall()
    finally:
        conn.close()
def main(query_image_path, top_n=3):
    """Embed ``query_image_path`` and print its ``top_n`` nearest stored images.

    The value printed for each match is the distance returned by
    ``find_most_similar_images``; results are listed closest first, so a
    smaller distance means a more similar image.
    """
    query_embedding = get_image_embedding(query_image_path)
    similar_images = find_most_similar_images(query_embedding, top_n=top_n)
    print("Most similar images:")
    for image_path, distance in similar_images:
        print(f"Image: {image_path}, Distance: {distance}")
if __name__ == "__main__":
    import sys

    # Take the query image path from the first command-line argument when
    # given; otherwise fall back to the placeholder (replace with a real path).
    query_image_path = (
        sys.argv[1] if len(sys.argv) > 1 else "path/to/your/query/image.jpg"
    )
    main(query_image_path)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment