Use faiss to calculate a KNN graph on data
import gc
import faiss
import bcolz
import os, sys
import numpy as np
from tqdm import tqdm
# open the stored bcolz array
# note that these vectors have to be 280-dimensional
# to be compatible with the faiss GPU indexing used below
# https://github.com/facebookresearch/faiss/wiki/Troubleshooting#gpu-precomputed-table-error
bc_path = 'your_vectors.bc'
bc_vectors = bcolz.open(rootdir=bc_path)
# load the full array into RAM for training and adding to the index
vectors = bc_vectors[:, :]
# create a bcolz array for the knn graph
knn_bc_path = 'knn.bc'
knn_bc = bcolz.carray(rootdir=knn_bc_path, mode='w')
knn_bc.flush()
# create a bcolz array for the distances
knn_dist_bc_path = 'distances.bc'
knn_dist_bc = bcolz.carray(rootdir=knn_dist_bc_path, mode='w')
knn_dist_bc.flush()
res = faiss.StandardGpuResources()
index = faiss.index_factory(vectors.shape[1], "IVF4096,PQ56")
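# "IVF4096,PQ56" = an inverted-file index with 4096 coarse clusters,
# followed by a product quantizer with 56 sub-quantizers of 8 bits each,
# i.e. each vector is stored as a 56-byte code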
co = faiss.GpuClonerOptions()
# https://github.com/facebookresearch/faiss/tree/master/benchs
# since we are using a 56-byte PQ, we must set the lookup tables to
# 16-bit float (this is due to the limited temporary GPU memory)
co.useFloat16 = True
index = faiss.index_cpu_to_gpu(res, 0, index, co)
print("Train the index") | |
index.train(vectors) | |
print ('Add vectors to the index') | |
index.add(vectors) | |
del vectors | |
gc.collect() | |
nprobe = 1 << 8
index.setNumProbes(nprobe)
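# nprobe is the number of inverted lists (out of 4096) scanned per query;
# 1 << 8 = 256 lists, trading some search speed for higher recall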
batch_size = int(16384 / 2)
l = list(range(0, len(bc_vectors)))
batches = [l[i:i + batch_size] for i in range(0, len(bc_vectors), batch_size)]
# check that the batches cover all vector indices
assert set([item for sublist in batches for item in sublist]) == set(range(len(bc_vectors)))
processed_batches = []
with tqdm(total=len(batches)) as pbar:
    for batch in batches:
        processed_batches.append(batch)
        b_array = np.asarray(batch)
        # search the index for the 100 nearest neighbours of each vector in the batch
        D, I = index.search(bc_vectors[b_array], 100)
        knn_bc.append(I)
        knn_bc.flush()
        knn_dist_bc.append(D)
        knn_dist_bc.flush()
        pbar.update(1)
# check that all vectors were processed
assert set([item for sublist in processed_batches for item in sublist]) == set(range(len(bc_vectors)))
assert len(knn_bc) == len(bc_vectors)
assert len(knn_dist_bc) == len(bc_vectors)
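A minimal sketch of reading the result back (assuming the script above completed and wrote both arrays): row i of knn.bc holds the ids of the 100 approximate nearest neighbours of vector i, and distances.bc holds the matching distances.

# re-open the arrays written above and inspect the neighbours of vector 0
knn = bcolz.open(rootdir='knn.bc', mode='r')
dist = bcolz.open(rootdir='distances.bc', mode='r')
neighbours, distances = knn[0], dist[0]  # each has shape (100,)
# since every vector was also added to the index, the closest match
# is usually the vector's own id
print(neighbours[:10])
print(distances[:10])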