# Created January 4, 2022 06:37
# Save mayukh18/f2aea99e0fb04e7a075bf5fd95f0786b to your computer and use it in GitHub Desktop.
# This file contains bidirectional Unicode text that may be interpreted or compiled
# differently than what appears below. To review, open the file in an editor that
# reveals hidden Unicode characters. Learn more about bidirectional Unicode characters.
# Build an exact L2 FAISS index over the embeddings of every file matched by
# PATH_TRAIN/*/*, and remember which file each stored vector came from.
#
# Relies on names defined earlier in the notebook (not shown here): torch,
# glob, os, np, Image (PIL), model, val_transforms, PATH_TRAIN.
# NOTE(review): `model` is presumably already on the GPU and in eval() mode --
# confirm, otherwise dropout / batch-norm make the stored embeddings noisy.

#!pip install faiss-gpu
import faiss

# Width of the vectors produced by `model`; IndexFlatL2 rejects vectors of any
# other size. 1000 presumably matches the model's output head -- TODO confirm.
EMBED_DIM = 1000
faiss_index = faiss.IndexFlatL2(EMBED_DIM)  # build the (brute-force) index

# im_indices[i] is the path of the i-th vector added to faiss_index, so the ids
# returned by faiss_index.search() can be mapped back to file names.
im_indices = []
with torch.no_grad():
    # sorted() makes the id -> path mapping reproducible; glob order depends
    # on the filesystem.
    for f in sorted(glob.glob(os.path.join(PATH_TRAIN, '*/*'))):
        # `with` closes the underlying file handle. convert('RGB') gives
        # grayscale / RGBA / palette images the 3 channels the transforms expect.
        with Image.open(f) as img:
            im = img.convert('RGB').resize((224, 224))
        # Add the batch dimension directly instead of round-tripping the tensor
        # through a Python list of numpy arrays.
        im = val_transforms(im).unsqueeze(0).cuda()
        preds = model(im)
        # FAISS expects a C-contiguous float32 matrix of shape (n, d); [:1]
        # keeps the first (only) row while preserving the batch dimension.
        preds = np.ascontiguousarray(preds[:1].cpu().numpy(), dtype=np.float32)
        faiss_index.add(preds)  # add the representation to the index
        im_indices.append(f)    # store the image path to find it later on
# Retrieval with a query image: embed every file in PATH_TEST and print the
# nearest indexed image for each one.
#
# Relies on names defined earlier (not shown here): torch, os, np, Image (PIL),
# model, val_transforms, PATH_TEST, faiss_index, im_indices.
TOP_K = 5  # neighbours requested from FAISS per query; only the best is printed
with torch.no_grad():
    # sorted() gives a reproducible query order; os.listdir order is arbitrary.
    for f in sorted(os.listdir(PATH_TEST)):
        # query/test image: `with` closes the file; convert('RGB') guarantees
        # the 3 channels the transforms expect.
        with Image.open(os.path.join(PATH_TEST, f)) as img:
            im = img.convert('RGB').resize((224, 224))
        im = val_transforms(im).unsqueeze(0).cuda()
        # FAISS expects a C-contiguous float32 (n, d) query matrix.
        test_embed = np.ascontiguousarray(model(im).cpu().numpy(), dtype=np.float32)
        # I has shape (1, TOP_K), ordered by increasing L2 distance; FAISS pads
        # with -1 when the index holds fewer than TOP_K vectors.
        _, I = faiss_index.search(test_embed, TOP_K)
        best = I[0][0]
        if best < 0:
            # Empty index: -1 would otherwise silently select im_indices[-1].
            print("No indexed image to retrieve for query {}".format(f))
            continue
        print("Query Image: {} | Retrieved Image: {}".format(f, im_indices[best]))
# Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.