Created
February 24, 2025 12:02
-
-
Save companje/8d31cbd5aaa9f8e24cc48eee5426ef97 to your computer and use it in GitHub Desktop.
FAISS
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os

import faiss
import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image
from torchvision import models
# --- 1. Load Pre-trained Model (ResNet50) ---
class FeatureExtractor:
    """Turn an image file into a feature vector using a pre-trained ResNet50.

    The network is loaded once per instance with its last child module
    removed, so a forward pass yields features rather than class scores.
    """

    def __init__(self):
        # ``pretrained=True`` is deprecated since torchvision 0.13. Prefer the
        # ``weights=`` API when this torchvision provides it, and fall back to
        # the legacy flag on older releases. IMAGENET1K_V1 is the checkpoint
        # the legacy flag selected, so the embeddings stay the same.
        weights_enum = getattr(models, "ResNet50_Weights", None)
        if weights_enum is not None:
            backbone = models.resnet50(weights=weights_enum.IMAGENET1K_V1)
        else:
            backbone = models.resnet50(pretrained=True)

        # Remove the last layer so the model outputs features.
        self.model = torch.nn.Sequential(*list(backbone.children())[:-1])
        # Inference mode: freezes batch-norm statistics.
        self.model.eval()

        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ])

    def extract(self, img_path):
        """Return the flattened feature vector for the image at *img_path*.

        Args:
            img_path: Path of an image file readable by PIL.

        Returns:
            A 1-D NumPy array holding the model output for that image.
        """
        # The context manager closes the underlying file handle right away;
        # ``convert`` returns an independent in-memory image.
        with Image.open(img_path) as raw_image:
            rgb_image = raw_image.convert("RGB")

        # Add a leading batch dimension: the model is fed a batch of one.
        batch = self.transform(rgb_image).unsqueeze(0)
        with torch.no_grad():
            feature = self.model(batch)
        return feature.squeeze().numpy().flatten()
# --- 2. Load Images & Extract Features ---
def build_feature_database(image_folder):
    """Embed every image in *image_folder* and index the vectors with FAISS.

    Args:
        image_folder: Directory scanned (non-recursively) for files whose
            name ends in ``.jpg``, ``.jpeg`` or ``.png``, ignoring case.

    Returns:
        A ``(index, image_paths)`` tuple: a ``faiss.IndexFlatL2`` holding one
        float32 vector per image, and the image paths in the same order, so
        row ``i`` of the index belongs to ``image_paths[i]``.

    Raises:
        ValueError: If the folder contains no matching image files.
    """
    image_suffixes = (".jpg", ".jpeg", ".png")
    # Sort the names: os.listdir order is arbitrary, and a stable order makes
    # the index reproducible between runs.
    image_paths = [
        os.path.join(image_folder, name)
        for name in sorted(os.listdir(image_folder))
        if name.lower().endswith(image_suffixes)
    ]
    if not image_paths:
        # Fail early with a clear message instead of an IndexError on
        # ``features.shape[1]`` further down.
        raise ValueError(f"No .jpg/.jpeg/.png images found in {image_folder!r}")

    # Load the model only once we know there is something to embed.
    extractor = FeatureExtractor()
    features = np.stack(
        [extractor.extract(img_path) for img_path in image_paths]
    ).astype("float32")

    # --- 3. Build FAISS Index ---
    d = features.shape[1]  # Feature dimension
    index = faiss.IndexFlatL2(d)
    index.add(features)
    return index, image_paths
# --- 4. Search for Similar Images ---
def search_similar_image(query_img, index, image_paths, top_k=5, extractor=None):
    """Find the indexed images closest to *query_img*.

    Args:
        query_img: Path of the query image.
        index: Search index exposing ``search(vectors, k)`` and returning
            ``(distances, indices)``, e.g. the FAISS index built by
            ``build_feature_database``.
        image_paths: Paths aligned with the index rows.
        top_k: Maximum number of matches to return.
        extractor: Optional object with an ``extract(path)`` method. Pass a
            shared ``FeatureExtractor`` to avoid reloading the model on every
            query; when omitted a new one is created.

    Returns:
        A list of ``(image_path, distance)`` tuples in the order the index
        returned them. It may hold fewer than ``top_k`` entries.
    """
    if extractor is None:
        extractor = FeatureExtractor()

    # The index expects a 2-D float32 batch; this is a batch of one query.
    query_vector = extractor.extract(query_img).astype("float32").reshape(1, -1)
    distances, indices = index.search(query_vector, top_k)

    # FAISS pads missing results with label -1 when the index holds fewer
    # than ``top_k`` vectors. Skip those: ``image_paths[-1]`` would otherwise
    # report the last image as a bogus match.
    return [
        (image_paths[i], dist)
        for i, dist in zip(indices[0], distances[0])
        if i >= 0
    ]
# --- Example Usage ---
# Inputs for the demo run.
image_folder = "beeldbank"  # Folder containing your dataset
query_img = "halfjaartjewerk/X225308 - 1003100.jpg"  # Picture to compare against

# Index the whole folder, then look up the nearest neighbours of the query.
index, image_paths = build_feature_database(image_folder)
results = search_similar_image(query_img, index, image_paths)

# Report every match with its distance to the query.
for img_path, dist in results:
    print(f"Match: {img_path}, Distance: {dist}")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment