Skip to content

Instantly share code, notes, and snippets.

@companje
Created February 24, 2025 12:02
Show Gist options
  • Save companje/8d31cbd5aaa9f8e24cc48eee5426ef97 to your computer and use it in GitHub Desktop.
Save companje/8d31cbd5aaa9f8e24cc48eee5426ef97 to your computer and use it in GitHub Desktop.
FAISS
import faiss
import numpy as np
import torch
import torchvision.transforms as transforms
from torchvision import models
from PIL import Image
import os
# --- 1. Load Pre-trained Model (ResNet50) ---
class FeatureExtractor:
    """Turn image files into feature vectors using a pretrained ResNet50.

    The network's last child module (the classification head) is removed, so
    the model output is the pooled feature map rather than class logits; it is
    returned as a flat 1-D NumPy array.
    """

    def __init__(self):
        # ``pretrained=True`` is deprecated since torchvision 0.13 in favour of
        # the ``weights=`` enum. IMAGENET1K_V1 is exactly the checkpoint that
        # ``pretrained=True`` selected, so results are unchanged. Older
        # torchvision releases lack the enum, hence the fallback.
        weights_enum = getattr(models, "ResNet50_Weights", None)
        if weights_enum is not None:
            backbone = models.resnet50(weights=weights_enum.IMAGENET1K_V1)
        else:
            backbone = models.resnet50(pretrained=True)
        # Drop the final child (the fully-connected classifier) so the network
        # yields feature vectors instead of class scores.
        self.model = torch.nn.Sequential(*list(backbone.children())[:-1])
        # Switch layers such as BatchNorm to inference behaviour.
        self.model.eval()
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            # Per-channel normalisation statistics used with the pretrained weights.
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def extract(self, img_path):
        """Return the feature vector of the image stored at *img_path*.

        Args:
            img_path: Path of an image file that PIL can open.

        Returns:
            A 1-D ``numpy.ndarray`` containing the model output for the image.
        """
        # Close the file handle deterministically; ``convert`` returns a new,
        # fully loaded image, so it stays usable after the file is closed.
        with Image.open(img_path) as img:
            rgb = img.convert("RGB")
        batch = self.transform(rgb).unsqueeze(0)  # add a batch dimension of size 1
        with torch.no_grad():  # no autograd bookkeeping needed for inference
            feature = self.model(batch)
        return feature.squeeze().numpy().flatten()
# --- 2. Load Images & Extract Features ---
def build_feature_database(image_folder, extensions=(".jpg", ".jpeg", ".png")):
    """Embed every image in *image_folder* and index the vectors with FAISS.

    Args:
        image_folder: Directory scanned (non-recursively) for image files.
        extensions: Filename suffixes treated as images. Matching is
            case-insensitive, so ``photo.JPG`` is picked up as well.

    Returns:
        ``(index, image_paths)`` where ``index`` is a ``faiss.IndexFlatL2`` and
        vector ``i`` in the index was computed from ``image_paths[i]``.

    Raises:
        ValueError: If the folder contains no matching image files.
    """
    suffixes = tuple(ext.lower() for ext in extensions)
    # Sort so index ids are reproducible across runs (os.listdir order is
    # arbitrary), and skip directories whose names happen to look like images.
    image_paths = [
        os.path.join(image_folder, name)
        for name in sorted(os.listdir(image_folder))
        if name.lower().endswith(suffixes)
        and os.path.isfile(os.path.join(image_folder, name))
    ]
    if not image_paths:
        # Fail fast (before loading the model) with a clear message instead of
        # an IndexError on ``features.shape[1]`` below.
        raise ValueError(
            f"No image files with extensions {suffixes} found in {image_folder!r}"
        )

    extractor = FeatureExtractor()
    # FAISS expects a float32 matrix with one row per vector.
    features = np.asarray(
        [extractor.extract(path) for path in image_paths], dtype="float32"
    )

    # --- 3. Build FAISS Index ---
    # Exact (brute-force) L2 index; needs no training step.
    index = faiss.IndexFlatL2(features.shape[1])  # feature dimension
    index.add(features)
    return index, image_paths
# --- 4. Search for Similar Images ---
def search_similar_image(query_img, index, image_paths, top_k=5, extractor=None):
    """Return up to *top_k* indexed images that are closest to *query_img*.

    Args:
        query_img: Path of the query image.
        index: FAISS index whose vector ids line up with *image_paths*.
        image_paths: Image paths, position ``i`` matching index id ``i``.
        top_k: Maximum number of matches to return.
        extractor: Optional object with an ``extract(path)`` method returning a
            1-D feature array. Pass an existing ``FeatureExtractor`` to avoid
            reloading the model; when omitted a new one is created.

    Returns:
        List of ``(path, distance)`` tuples in the order FAISS returns them
        (closest first). For ``IndexFlatL2`` the distance is the *squared*
        Euclidean distance. Fewer than *top_k* tuples are returned when the
        index holds fewer vectors.
    """
    if extractor is None:
        extractor = FeatureExtractor()
    # FAISS wants a float32 batch of shape (n_queries, dim).
    query_vector = extractor.extract(query_img).astype('float32').reshape(1, -1)
    distances, indices = index.search(query_vector, top_k)
    # FAISS pads missing results with id -1; without this filter
    # image_paths[-1] would be reported as a spurious match.
    return [
        (image_paths[i], dist)
        for i, dist in zip(indices[0], distances[0])
        if i >= 0
    ]
# --- Example Usage ---
# Folder holding the image collection to index (point this at your own dataset).
image_folder = "beeldbank"
index, image_paths = build_feature_database(image_folder)

# Image whose closest matches in the collection are listed below.
query_img = "halfjaartjewerk/X225308 - 1003100.jpg"
results = search_similar_image(query_img, index, image_paths)

# One line per match: the matched file and its distance to the query.
for img_path, dist in results:
    print(f"Match: {img_path}, Distance: {dist}")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment