Skip to content

Instantly share code, notes, and snippets.

@geobabbler
Created July 19, 2024 11:05
Show Gist options
  • Save geobabbler/a2ce0adc46b3b0a04b8f937541e02585 to your computer and use it in GitHub Desktop.
Save geobabbler/a2ce0adc46b3b0a04b8f937541e02585 to your computer and use it in GitHub Desktop.
Script to read folder of images, generate embeddings, and write them to pgvector.
import os
import psycopg2
from PIL import Image
import torch
from torchvision import models, transforms
# Database configuration.
# Each value may be overridden through the standard libpq environment
# variables so real credentials never have to be committed to source control;
# the literals below are placeholder fallbacks used when a variable is unset.
DB_NAME = os.environ.get("PGDATABASE", "database")
DB_USER = os.environ.get("PGUSER", "user")
DB_PASSWORD = os.environ.get("PGPASSWORD", "password")
DB_HOST = os.environ.get("PGHOST", "host")
DB_PORT = os.environ.get("PGPORT", "port")

# Folder containing the images to embed (override with IMAGE_FOLDER).
IMAGE_FOLDER = os.environ.get("IMAGE_FOLDER", "/path/to/sample_images")
# Load a ResNet-50 pre-trained on ImageNet. torchvision >= 0.13 deprecated the
# ``pretrained=True`` flag in favour of the ``weights`` enum; IMAGENET1K_V1 is
# the weight set ``pretrained=True`` selected. Older torchvision releases have
# no ``ResNet50_Weights`` enum, so fall back to the legacy flag there.
_resnet_weights = getattr(models, "ResNet50_Weights", None)
if _resnet_weights is not None:
    model = models.resnet50(weights=_resnet_weights.IMAGENET1K_V1)
else:
    model = models.resnet50(pretrained=True)

# Drop the last child module (the final classification layer) so the network
# outputs pooled features instead of class scores.
model = torch.nn.Sequential(*list(model.children())[:-1])
# Evaluation mode, applied after wrapping so the outer Sequential is switched
# as well as every child module.
model.eval()

# Image preprocessing: resize the short side to 256, center-crop 224x224,
# convert to a tensor, then normalize each RGB channel with the given
# mean/std (the usual ImageNet statistics).
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
def get_image_embedding(image_path):
    """Return the feature vector the truncated ResNet produces for one image.

    The image is converted to RGB, passed through ``preprocess`` and ``model``
    with gradient tracking disabled (inference only).

    Args:
        image_path: Path to an image file that Pillow can open.

    Returns:
        A NumPy array holding the model output with all size-1 dimensions
        squeezed away.
    """
    # Context manager releases the file handle even if decoding/conversion
    # fails; convert() returns a new, fully loaded image that outlives it.
    with Image.open(image_path) as img:
        rgb = img.convert("RGB")
    # Prepend a batch dimension of 1 because the model consumes batches.
    batch = preprocess(rgb).unsqueeze(0)
    with torch.no_grad():
        features = model(batch)
    # Drop the batch dim and any trailing size-1 spatial dims.
    return features.squeeze().numpy()
def save_embedding_to_db(image_name, embedding, *, conn=None):
    """Insert one ``(image_path, embedding)`` row into ``image_embeddings``.

    Args:
        image_name: Value stored in the ``image_path`` column.
        embedding: Object exposing ``tolist()`` (e.g. a NumPy array); the
            resulting list is passed as the ``embedding`` parameter.
        conn: Optional open DB-API connection to reuse. When omitted, a new
            psycopg2 connection is opened from the module's DB_* settings and
            always closed before returning, even on error. A caller-supplied
            connection is left open, and on error its rollback is left to
            the caller.

    Raises:
        Whatever the driver raises from connect/execute/commit; nothing is
        committed in that case.
    """
    owns_conn = conn is None
    if owns_conn:
        conn = psycopg2.connect(
            dbname=DB_NAME, user=DB_USER, password=DB_PASSWORD, host=DB_HOST, port=DB_PORT
        )
    try:
        # The cursor context manager closes the cursor on every exit path.
        with conn.cursor() as cursor:
            cursor.execute(
                "INSERT INTO image_embeddings (image_path, embedding) VALUES (%s, %s)",
                (image_name, embedding.tolist()),
            )
        conn.commit()
    finally:
        if owns_conn:
            conn.close()
def process_images(image_folder, *, extensions=(".png", ".jpg", ".jpeg"),
                   embed=None, save=None):
    """Embed and store every matching image file directly inside a folder.

    Entries are visited in sorted name order so repeated runs insert rows in
    a reproducible order (``os.listdir`` order is arbitrary). Subdirectories
    are skipped, not descended into.

    Args:
        image_folder: Directory to scan.
        extensions: Filename suffix or tuple of suffixes to accept, compared
            case-insensitively.
        embed: Callable ``path -> embedding``; defaults to
            ``get_image_embedding``.
        save: Callable ``(path, embedding)``; defaults to
            ``save_embedding_to_db``.

    Returns:
        The number of images that were embedded and saved.
    """
    if embed is None:
        embed = get_image_embedding
    if save is None:
        save = save_embedding_to_db
    if isinstance(extensions, str):
        # A bare string would otherwise be iterated character by character.
        extensions = (extensions,)
    suffixes = tuple(ext.lower() for ext in extensions)

    count = 0
    for image_name in sorted(os.listdir(image_folder)):
        image_path = os.path.join(image_folder, image_name)
        if not os.path.isfile(image_path) or not image_name.lower().endswith(suffixes):
            continue
        # The full path (not the bare file name) is what gets stored.
        save(image_path, embed(image_path))
        count += 1
    return count
if __name__ == "__main__":
    import sys

    # An optional first command-line argument overrides IMAGE_FOLDER.
    process_images(sys.argv[1] if len(sys.argv) > 1 else IMAGE_FOLDER)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment