Created
July 19, 2024 11:05
-
-
Save geobabbler/a2ce0adc46b3b0a04b8f937541e02585 to your computer and use it in GitHub Desktop.
Script to read folder of images, generate embeddings, and write them to pgvector.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import psycopg2 | |
from PIL import Image | |
import torch | |
from torchvision import models, transforms | |
# Database configuration. Each setting can be overridden with the standard
# libpq environment variable of the same meaning (PGDATABASE, PGUSER,
# PGPASSWORD, PGHOST, PGPORT) so real credentials do not have to live in
# source code; the literals below are placeholders used only when the
# corresponding variable is unset.
DB_NAME = os.environ.get("PGDATABASE", "database")
DB_USER = os.environ.get("PGUSER", "user")
DB_PASSWORD = os.environ.get("PGPASSWORD", "password")
DB_HOST = os.environ.get("PGHOST", "host")
DB_PORT = os.environ.get("PGPORT", "port")

# Folder containing the images to embed (override with IMAGE_FOLDER).
IMAGE_FOLDER = os.environ.get("IMAGE_FOLDER", "/path/to/sample_images")
# Load a ResNet-50 with ImageNet weights. torchvision >= 0.13 deprecates the
# `pretrained=True` flag in favour of weight enums; IMAGENET1K_V1 is the
# checkpoint that `pretrained=True` loaded, so embeddings are unchanged.
# Older torchvision releases only provide the legacy flag.
if hasattr(models, "ResNet50_Weights"):
    _backbone = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
else:
    _backbone = models.resnet50(pretrained=True)

# Drop the final fully connected classification layer so the network returns
# the globally pooled feature map (shape (N, 2048, 1, 1) for ResNet-50)
# instead of class logits.
model = torch.nn.Sequential(*list(_backbone.children())[:-1])
del _backbone  # release the reference to the now-unused classification head

# Evaluation mode: batch-norm layers use their running statistics, so the same
# image always yields the same embedding. Called on the wrapper that is
# actually used for inference.
model.eval()

# Standard ImageNet evaluation preprocessing: resize the short side to 256,
# center-crop 224x224, convert to a float tensor, then normalize each channel
# with the ImageNet mean / standard deviation.
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
def get_image_embedding(image_path):
    """Return the ResNet-50 embedding of a single image file.

    The image is converted to RGB (so grayscale, palette and RGBA files are
    accepted), run through the module-level ``preprocess`` transform and the
    headless ``model``, and the pooled feature map is flattened to 1-D.

    Args:
        image_path: Path of the image file to read.

    Returns:
        numpy.ndarray: 1-D float32 feature vector (2048 values for the
        ResNet-50 backbone configured in this module).

    Raises:
        OSError: if the file cannot be opened or is not a recognizable image
            (Pillow's ``UnidentifiedImageError`` subclasses ``OSError``).
    """
    # Use a context manager so the underlying file handle is released at once
    # rather than whenever the Image object is garbage-collected. convert()
    # fully loads the pixels into a new image, so it outlives the handle.
    with Image.open(image_path) as img:
        rgb_image = img.convert("RGB")
    batch = preprocess(rgb_image).unsqueeze(0)  # add batch dim -> (1, C, H, W)
    with torch.no_grad():  # inference only: skip autograd bookkeeping
        features = model(batch)
    # (1, 2048, 1, 1) -> (2048,). Unlike squeeze(), this always yields a 1-D
    # vector for the single image, whatever other dimensions equal 1.
    return torch.flatten(features, start_dim=1)[0].numpy()
def _connect():
    """Open a new psycopg2 connection from the module's DB_* settings."""
    return psycopg2.connect(
        dbname=DB_NAME, user=DB_USER, password=DB_PASSWORD, host=DB_HOST, port=DB_PORT
    )


def save_embedding_to_db(image_name, embedding, conn=None):
    """Insert one row into the ``image_embeddings`` table and commit it.

    Args:
        image_name: Value stored in the ``image_path`` column (this script
            passes the image's full path).
        embedding: Object with a ``tolist()`` method, e.g. the NumPy array
            returned by ``get_image_embedding``; its list form is bound as the
            ``embedding`` column value.
        conn: Optional open psycopg2 connection to reuse. When omitted, a
            connection is opened for this call and always closed afterwards.

    The insert (together with any work already pending on a reused
    connection) is committed on success and rolled back if it raises, so a
    failure never leaves an open transaction or, for an owned connection, an
    open connection behind.
    """
    owns_conn = conn is None
    if owns_conn:
        conn = _connect()
    try:
        # psycopg2's connection context manager commits on normal exit and
        # rolls back on exception; it does NOT close the connection, hence
        # the explicit close in `finally`.
        with conn, conn.cursor() as cursor:
            cursor.execute(
                "INSERT INTO image_embeddings (image_path, embedding) VALUES (%s, %s)",
                (image_name, embedding.tolist()),
            )
    finally:
        if owns_conn:
            conn.close()


def process_images(image_folder, extensions=('.png', '.jpg', '.jpeg')):
    """Embed every image file in *image_folder* and store each in the database.

    Only regular files directly inside the folder (no recursion) whose name
    ends with one of *extensions* (case-insensitive) are processed, in sorted
    filename order so repeated runs insert rows in the same order. One
    database connection is opened when the first image is stored and reused
    for all inserts; each row is still committed individually.

    Args:
        image_folder: Directory to scan.
        extensions: Filename suffixes to accept; defaults to PNG and JPEG.
    """
    suffixes = tuple(ext.lower() for ext in extensions)  # hoisted out of the loop
    conn = None
    try:
        for image_name in sorted(os.listdir(image_folder)):
            image_path = os.path.join(image_folder, image_name)
            if not (os.path.isfile(image_path)
                    and image_name.lower().endswith(suffixes)):
                continue
            embedding = get_image_embedding(image_path)
            if conn is None:
                # Connect lazily so a folder without images needs no database.
                conn = _connect()
            save_embedding_to_db(image_path, embedding, conn=conn)
    finally:
        if conn is not None:
            conn.close()
# Script entry point: embed every image in the configured folder.
if __name__ == "__main__":
    process_images(image_folder=IMAGE_FOLDER)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment