Created
December 3, 2018 16:25
-
-
Save matsui528/5683d0c9b2dc55c38b97e95879b94821 to your computer and use it in GitHub Desktop.
Hyper-parameter tuning for faiss using optuna
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Test for faiss with optuna using siftsmall data
#
# (1) Install libs:
# $ pip install optuna
# $ conda install faiss-cpu -c pytorch
#
# (2) Put the following util script in the same directory
# https://github.com/matsui528/rii/blob/master/examples/benchmark/util.py
#
# (3) Download siftsmall data
# $ wget ftp://ftp.irisa.fr/local/texmex/corpus/siftsmall.tar.gz -P data
# $ tar -zxvf data/siftsmall.tar.gz -C data
#
# (4) Run the script
# $ python run_optuna.py
import optuna | |
import faiss | |
import numpy as np | |
import time | |
import util | |
# Load the siftsmall splits via the util readers. The three .fvecs files are
# read in order (learn, base, query) and unpacked into Xt, Xb and Xq.
Xt, Xb, Xq = (
    util.fvecs_read(path)
    for path in (
        "./data/siftsmall/siftsmall_learn.fvecs",
        "./data/siftsmall/siftsmall_base.fvecs",
        "./data/siftsmall/siftsmall_query.fvecs",
    )
)
# Groundtruth for the queries is stored as .ivecs
gt = util.ivecs_read("./data/siftsmall/siftsmall_groundtruth.ivecs")
# Size of the second axis of the training array, passed to faiss as D
D = Xt.shape[1]
def run_search(index, Xq, gt, r):
    """
    Given a faiss index, run the search. Return the runtime and the accuracy.

    Args:
        index (faiss index): Faiss index for search
        Xq (np.array): Query vectors. shape=(Nq, D). dtype=np.float32
        gt (np.array): Groundtruth. shape=(Nq, ANY). dtype=np.int32
        r (int): Top R

    Returns:
        (float, float): Duration [sec/query] and recall@r over the queries
    """
    assert Xq.ndim == 2, "Xq must be a 2-D array of shape (Nq, D)"
    assert Xq.dtype == np.float32, "Xq must have dtype np.float32"
    Nq = Xq.shape[0]

    # Use perf_counter rather than time.time: it is monotonic and has the
    # highest available resolution, so short searches are measured reliably
    # and the result is unaffected by system clock adjustments.
    t0 = time.perf_counter()
    _, ids = index.search(x=Xq, k=r)
    t1 = time.perf_counter()

    duration = (t1 - t0) / Nq  # sec/query
    recall = util.recall_at_r(ids, gt, r)
    return duration, recall
def objective(trial):
    """
    Objective for optuna: build a faiss index from sampled parameters and score it.

    Args:
        trial (optuna trial): Trial from which 'M', 'nlist' and 'hnsw_m' are sampled

    Returns:
        float: Negated recall@1 over the queries (the sign is flipped so that
            a smaller value is better)
    """
    # Sample the parameters. 'M' is drawn from string choices, then cast to int.
    pq_m = int(trial.suggest_categorical('M', ['4', '8', '16']))
    n_lists = trial.suggest_int('nlist', 10, 1000)
    graph_m = trial.suggest_int('hnsw_m', 8, 64)

    # An HNSW index serves as the quantizer handed to the IVFPQ index
    hnsw_quantizer = faiss.IndexHNSWFlat(D, graph_m)
    ivfpq = faiss.IndexIVFPQ(hnsw_quantizer, D, n_lists, pq_m, 8)

    # Train on the learn set, then add the base vectors to be searched
    ivfpq.train(Xt)
    ivfpq.add(Xb)

    # Only the recall is used for the score; the measured duration is discarded
    _, recall = run_search(ivfpq, Xq, gt, 1)
    return -recall
# Create a study and evaluate the objective for 100 trials
study = optuna.create_study()
study.optimize(objective, n_trials=100)

# Report the best trial. The objective returns the negated recall, so the
# sign is flipped back before printing.
trial = study.best_trial
print("best recall:", -trial.value)
print("params:")
for param in trial.params.items():
    # Each item is a (name, value) pair
    print('{}: {}'.format(*param))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment