Created
July 19, 2024 11:09
-
-
Save geobabbler/922feeaf74c169fa2ea4b47fa669a816 to your computer and use it in GitHub Desktop.
Script to perform image similarity search using pgvector
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os
import sys

import numpy as np
import psycopg2
import torch
from PIL import Image
from scipy.spatial.distance import cosine
from torchvision import models, transforms
# Database configuration.
# Each setting can be overridden through an environment variable of the same
# name, so real credentials never have to be committed to source control. The
# fallback values are the original placeholders.
DB_NAME = os.environ.get("DB_NAME", "database")
DB_USER = os.environ.get("DB_USER", "user")
DB_PASSWORD = os.environ.get("DB_PASSWORD", "password")
DB_HOST = os.environ.get("DB_HOST", "host")
DB_PORT = os.environ.get("DB_PORT", "port")
# Load a ResNet-50 pre-trained on ImageNet.
# ``pretrained=True`` is deprecated in torchvision (0.13+); IMAGENET1K_V1 is
# the weight set that flag selected, so embeddings stay comparable with
# vectors computed through the older call.
model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
# Remove the final classification layer so the network returns the pooled
# feature vector (the embedding) instead of class scores.
model = torch.nn.Sequential(*list(model.children())[:-1])
# Put the final wrapped module, and every layer inside it, in evaluation mode.
model.eval()
# Image preprocessing: resize, center-crop to 224, convert to a tensor, then
# normalize each channel with the mean / std values listed here.
_NORMALIZE_MEAN = [0.485, 0.456, 0.406]
_NORMALIZE_STD = [0.229, 0.224, 0.225]
_preprocess_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORMALIZE_MEAN, std=_NORMALIZE_STD),
]
preprocess = transforms.Compose(_preprocess_steps)
def get_image_embedding(image_path):
    """Compute the embedding of the image stored at ``image_path``.

    The image is opened, converted to RGB, passed through the module-level
    ``preprocess`` pipeline and then through ``model`` with gradient
    tracking disabled.

    Args:
        image_path: Path (or file object) accepted by ``PIL.Image.open``.

    Returns:
        A NumPy array holding the model output with all size-1 dimensions
        squeezed out.
    """
    # The context manager releases the underlying file handle
    # deterministically; ``convert`` returns an independent, fully loaded
    # image that remains usable after the source is closed.
    with Image.open(image_path) as source_image:
        image = source_image.convert("RGB")
    # Add a leading batch dimension, since the model is called on a batch.
    image_tensor = preprocess(image).unsqueeze(0)
    with torch.no_grad():
        embedding = model(image_tensor)
    return embedding.squeeze().numpy()
def find_most_similar_images(query_embedding, top_n=3):
    """Return the stored images closest to ``query_embedding``.

    Rows of the ``image_embeddings`` table are ranked with pgvector's ``<->``
    distance operator against the query vector, nearest first.

    Args:
        query_embedding: Array-like object exposing ``tolist()`` (e.g. a
            NumPy array) that holds the query vector.
        top_n: Maximum number of rows to return.

    Returns:
        A list of ``(image_path, distance)`` tuples ordered by ascending
        distance.

    Raises:
        psycopg2.Error: If connecting to the database or running the query
            fails.
    """
    # str() of a Python list gives "[x1, x2, ...]", the text form the
    # original query embedded as a literal and that pgvector parses.
    vector_literal = str(query_embedding.tolist())

    # Values are passed as bound parameters instead of being interpolated
    # into the SQL text, so neither the embedding nor ``top_n`` can inject
    # SQL. The ORDER BY repeats the distance expression, as in the original
    # statement.
    qry = (
        "SELECT image_path, embedding <-> %(vec)s::vector AS distance "
        "FROM image_embeddings "
        "ORDER BY embedding <-> %(vec)s::vector "
        "LIMIT %(limit)s;"
    )

    conn = psycopg2.connect(
        dbname=DB_NAME, user=DB_USER, password=DB_PASSWORD, host=DB_HOST, port=DB_PORT
    )
    try:
        # The cursor context manager closes the cursor even when the query
        # raises.
        with conn.cursor() as cursor:
            cursor.execute(qry, {"vec": vector_literal, "limit": top_n})
            return cursor.fetchall()
    finally:
        # ``with conn`` would only end the transaction, so the connection is
        # closed explicitly.
        conn.close()
def main(query_image_path):
    """Print the stored images most similar to the given query image.

    Embeds the image at ``query_image_path``, looks up its nearest
    neighbours in the database and prints one line per match with the
    stored path and its distance to the query.
    """
    matches = find_most_similar_images(get_image_embedding(query_image_path))
    print("Most similar images:")
    for path, distance in matches:
        print(f"Image: {path}, Distance: {distance}")
if __name__ == "__main__":
    # Use the first command-line argument as the query image when one is
    # given; otherwise fall back to the placeholder path, which must be
    # replaced with the path to a real query image.
    query_image_path = (
        sys.argv[1] if len(sys.argv) > 1 else "path/to/your/query/image.jpg"
    )
    main(query_image_path)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment