Created
November 6, 2020 09:31
-
-
Save rom1504/5a6a733dfd4772e497ea159e0395d0f9 to your computer and use it in GitHub Desktop.
faiss_benchmark
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import faiss | |
import numpy as np | |
import re | |
import time | |
def index_factory(d: int, index_key: str, metric_type: int):
    """
    Build a faiss index for ``index_key``, assembling some inner-product
    indexes by hand instead of relying on ``faiss.index_factory``.

    When ``metric_type`` is ``faiss.METRIC_INNER_PRODUCT``, two key families
    are recognised and built explicitly:

    * ``OPQ<M>_<d2>,IVF<n>_HNSW<m>,PQ<cs>[x<nbits>]``
    * ``Pad<d2>,IVF<n>_HNSW<m>,PQ<cs>[x<nbits>]``

    Any other key with the inner product metric raises ``ValueError`` (after
    the stock factory has been called once). For every other metric the call
    is forwarded to ``faiss.index_factory`` unchanged.

    Args:
        d: dimension of the input vectors.
        index_key: faiss factory string describing the index.
        metric_type: faiss metric constant.

    Returns:
        The constructed faiss index.

    Raises:
        ValueError: inner product metric with a key outside the two
            hand-built families.
    """
    # Anything other than inner product goes straight to the stock factory.
    if metric_type != faiss.METRIC_INNER_PRODUCT:
        return faiss.index_factory(d, index_key, metric_type)

    if re.search(r"OPQ\d+_\d+,IVF\d+_HNSW\d+,PQ\d+", index_key) is not None:
        # Every integer in the key, in order of appearance:
        # OPQ sub-quantizers, reduced dim, IVF lists, HNSW M, PQ code size
        # (in bytes when nbits is 8), then optionally the PQ nbits.
        numbers = list(map(int, re.findall(r"\d+", index_key)))
        opq_m, reduced_d, nlist, hnsw_m, code_size = numbers[:5]
        bits_per_code = numbers[5] if len(numbers) == 6 else 8  # 8 when omitted

        coarse = faiss.IndexHNSWFlat(reduced_d, hnsw_m, metric_type)
        assert coarse.metric_type == metric_type
        ivfpq = faiss.IndexIVFPQ(
            coarse, reduced_d, nlist, code_size, bits_per_code, metric_type
        )
        assert ivfpq.metric_type == metric_type
        # Mark the IVF index as owning its quantizer and release the Python
        # wrapper's ownership (presumably so it is not freed twice).
        ivfpq.own_fields = True
        coarse.this.disown()  # pylint: disable = no-member
        # OPQ training iterations are left at the library default.
        rotation = faiss.OPQMatrix(d, M=opq_m, d2=reduced_d)
        built = faiss.IndexPreTransform(rotation, ivfpq)
    elif re.search(r"Pad\d+,IVF\d+_HNSW\d+,PQ\d+", index_key) is not None:
        # Integers in order: target dim, IVF lists, HNSW M, PQ code size
        # (in bytes when nbits is 8), then optionally the PQ nbits.
        numbers = list(map(int, re.findall(r"\d+", index_key)))
        padded_d, nlist, hnsw_m, code_size = numbers[:4]
        bits_per_code = numbers[4] if len(numbers) == 5 else 8  # 8 when omitted

        remapper = faiss.RemapDimensionsTransform(d, padded_d, True)
        coarse = faiss.IndexHNSWFlat(padded_d, hnsw_m, metric_type)
        ivfpq = faiss.IndexIVFPQ(
            coarse, padded_d, nlist, code_size, bits_per_code, metric_type
        )
        # Same ownership hand-over as in the OPQ branch.
        ivfpq.own_fields = True
        coarse.this.disown()  # pylint: disable = no-member
        built = faiss.IndexPreTransform(remapper, ivfpq)
    else:
        # The stock factory is still invoked so that deleting the raise
        # below is enough to fall back to it.
        built = faiss.index_factory(d, index_key, metric_type)
        raise ValueError(
            "Be careful, faiss might not create what you expect when using the "
            "inner product similarity metric, remove this line to try it anyway."
        )
    return built
# Benchmark script: builds an OPQ + IVF/HNSW + PQ inner-product index over
# random vectors, then reports training time and the mean per-call time of
# reconstruct() alone and of reconstruct() followed by search().

# --- Configuration ----------------------------------------------------------
dimension = 8
nbClusters = 8
nb_vectors = 10000
nb_passes = 10  # full passes over every stored id in each timed benchmark
nb_neighbors = 20  # results requested per search query


def _seconds_per_call(operation, nb_ids, nb_repeats):
    """Return the mean wall-clock seconds taken by one ``operation(i)`` call.

    ``operation`` is invoked for every ``i`` in ``range(nb_ids)``, and the
    whole sweep is repeated ``nb_repeats`` times; the elapsed time is divided
    by the total number of calls.
    """
    start = time.perf_counter()
    for _ in range(nb_repeats):
        for i in range(nb_ids):
            operation(i)
    elapsed = time.perf_counter() - start
    return elapsed / (nb_repeats * nb_ids)


def _reconstruct_then_search(i):
    """Reconstruct stored vector ``i`` and search the index with it as query."""
    vec = index.reconstruct(i)
    # search() is given a 2-D batch, hence the added leading axis.
    index.search(np.expand_dims(vec, 0), nb_neighbors)


# --- Index construction -----------------------------------------------------
index = index_factory(dimension, f"OPQ4_8,IVF{nbClusters}_HNSW32,PQ4x8", faiss.METRIC_INNER_PRODUCT)
# Enable an array direct map on the underlying IVF index.
# NOTE(review): reconstruct(i) on an IVF index is expected to need this id ->
# list-entry map; confirm against the installed faiss version.
faiss.extract_index_ivf(index).set_direct_map_type(faiss.DirectMap.Array)
params = faiss.ParameterSpace()
# NOTE(review): nprobe=10 exceeds the nbClusters=8 lists configured above;
# presumably faiss clamps it -- confirm.
param_str = "nprobe=10,efSearch=20,ht=512"
params.set_index_parameters(index, param_str)

# --- Data -------------------------------------------------------------------
vectors = np.random.rand(nb_vectors, dimension).astype('float32')
faiss.omp_set_num_threads(1)

# --- Training ---------------------------------------------------------------
print("starting training")
start_time = time.perf_counter()
index.train(vectors)
elapsed_time = time.perf_counter() - start_time
print(f"-> Trained in {elapsed_time:.3} seconds")
index.add(vectors)

# --- Benchmark 1: reconstruct only -------------------------------------------
print("starting reconstruct")
reconstruct_time = _seconds_per_call(index.reconstruct, nb_vectors, nb_passes)
print(f"-> reconstruct in {reconstruct_time*1000} ms")

# --- Benchmark 2: reconstruct followed by a search ---------------------------
print("starting reconstruct and search")
reconstruct_search_time = _seconds_per_call(_reconstruct_then_search, nb_vectors, nb_passes)
print(f"-> reconstruct and search in {reconstruct_search_time*1000} ms")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment