Created January 18, 2019 17:17
Save DelightRun/e8bc139bcd2c588bdaea6652da4ff4af to your computer and use it in GitHub Desktop.
Python scripts for accurate (exact) KNN graph construction using FAISS on the GPU
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse
import time

import faiss
import numpy as np

from utils import load_vecs, save_vecs


def _build_index(dim, ngpus):
    """Return an exact (brute-force) L2 FAISS index over ``dim``-dimensional vectors.

    When at least one GPU is visible the flat index is moved onto all of them
    with ``faiss.index_cpu_to_all_gpus``; otherwise it stays on the CPU so the
    script still runs (more slowly) on machines without a GPU.
    """
    cpu_index = faiss.IndexFlatL2(dim)
    if ngpus == 0:
        return cpu_index
    return faiss.index_cpu_to_all_gpus(cpu_index)


def _drop_self(graph):
    """Remove every node from its own neighbour list: ``(n, k + 1)`` -> ``(n, k)``.

    Row ``i`` usually has ``i`` in column 0, but when the base set contains
    exact duplicates a distance tie can put ``i`` in a later column or push it
    out of the list entirely. In the first case that column is removed; in the
    second the farthest candidate (last column) is dropped so every row still
    keeps exactly ``k`` genuine neighbours.
    """
    n, width = graph.shape
    is_self = graph == np.arange(n)[:, None]
    # argmax returns the first True column; rows without a self match fall back
    # to the last column.
    drop_col = np.where(is_self.any(axis=1), is_self.argmax(axis=1), width - 1)
    keep = np.ones(graph.shape, dtype=bool)
    keep[np.arange(n), drop_col] = False
    return graph[keep].reshape(n, width - 1)


def main():
    """Build exact KNN graphs for ``DATASET`` for each requested ``K``.

    Reads ``DATASET/DATASET_base.fvecs`` and searches the base set against
    itself with an exact L2 index. For each K it writes
    ``DATASET/DATASET_<K>_gt.knng``: the int32 ids of the K nearest neighbours
    of every base vector (itself excluded), in the ``*vecs`` layout produced by
    ``save_vecs``.
    """
    parser = argparse.ArgumentParser(description="Build KNN Graph using GPU")
    parser.add_argument('dataset', type=str, help='Dataset')
    parser.add_argument('k', type=int, nargs='+', help='K')
    args = parser.parse_args()
    if min(args.k) < 1:
        parser.error("every K must be a positive integer")
    dataset = args.dataset

    ngpus = faiss.get_num_gpus()
    print("Use %d GPUs" % ngpus)
    if ngpus == 0:
        print("No GPU found, falling back to CPU")

    print("Load Dataset %s" % dataset.upper())
    base = load_vecs('%s/%s_base.fvecs' % (dataset, dataset))
    n, dim = base.shape
    if n == 0:
        raise SystemExit("No vectors found in dataset %s" % dataset)
    too_large = [k for k in args.k if k >= n]
    if too_large:
        # FAISS pads missing neighbours with -1, which would corrupt the graph.
        raise SystemExit("K must be smaller than the number of base vectors "
                         "(%d); got %s" % (n, too_large))

    print("Create Index")
    index = _build_index(dim, ngpus)
    print("Add Base Data to Index")
    index.add(base)

    for k in args.k:
        print("Build %dNN Graph" % k)
        t0 = time.perf_counter()
        # Ask for k + 1 neighbours: each query point normally finds itself.
        _, graph = index.search(base, k + 1)
        print("Time: %.3f" % (time.perf_counter() - t0))
        print("Save Graph")
        save_vecs(_drop_self(graph).astype(np.int32),
                  '%s/%s_%d_gt.knng' % (dataset, dataset, k))


if __name__ == "__main__":
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import numpy as np | |
def load_vecs(filename, c_contiguous=True, use_mmap=False):
    """Load a ``*vecs`` file (TEXMEX format) into a 2-D NumPy array.

    Every record is an int32 dimension ``d`` (native byte order) followed by
    ``d`` components. The component type depends on the file extension:
    float32 for ``.fvecs``, uint8 for ``.bvecs``, and int32 for anything else
    (``.ivecs``, ``.knng``, ...). All records must share the same ``d``.

    Args:
        filename: Path of the file to read.
        c_contiguous: If True (default), return a C-contiguous in-memory copy;
            otherwise return a strided view that merely skips the header column.
        use_mmap: If True, memory-map the file read-only instead of reading it
            eagerly. Note that ``np.memmap`` cannot map an empty file.

    Returns:
        Array of shape ``(n, d)`` with the dtype above, or an empty ``(0, 0)``
        array of that dtype when the file is empty.

    Raises:
        IOError: If the header is truncated, the header dimension is not
            positive, the file is not a whole number of records, or the
            records have differing dimensions.
    """
    _, ext = os.path.splitext(filename)
    if ext == '.fvecs':
        dtype = np.float32
    elif ext == '.bvecs':
        dtype = np.uint8
    else:
        dtype = np.int32
    # Number of `dtype` elements occupied by the 4-byte int32 header:
    # 1 for the 4-byte float32/int32 formats, 4 for uint8 (.bvecs).
    header_len = np.dtype(np.int32).itemsize // np.dtype(dtype).itemsize

    if use_mmap:
        data = np.memmap(filename, dtype=dtype, mode='r')
    else:
        data = np.fromfile(filename, dtype=dtype)
    if data.size == 0:
        return np.zeros((0, 0), dtype=dtype)
    if data.size < header_len:
        raise IOError("Truncated header in " + filename)

    dim = int(data[:header_len].view(np.int32)[0])
    if dim <= 0:
        raise IOError("Invalid vector dimension %d in %s" % (dim, filename))
    record_len = header_len + dim
    if data.size % record_len != 0:
        raise IOError("Truncated file or non-uniform vector sizes in " + filename)
    data = data.reshape(-1, record_len)

    # Copy the header column(s) so they can be reinterpreted as one int32 per row.
    headers = np.ascontiguousarray(data[:, :header_len]).view(np.int32)[:, 0]
    if not np.all(headers == dim):
        raise IOError("Non-uniform vector sizes in " + filename)

    data = data[:, header_len:]
    if c_contiguous:
        data = data.copy()
    return data
def save_vecs(data, filename):
    """Write a 2-D array to ``filename`` in the ``*vecs`` (TEXMEX) layout.

    Each row is written as an int32 dimension header (native byte order)
    followed by the row's raw components, so float32 data gives ``.fvecs``,
    int32 gives ``.ivecs`` and uint8 gives ``.bvecs``.

    Args:
        data: Array-like of shape ``(n, d)`` whose dtype is 1 or 4 bytes wide.
        filename: Destination path; an existing file is overwritten.

    Raises:
        ValueError: If ``data`` is not 2-D or its itemsize is not 1 or 4 bytes
            (other widths have no ``*vecs`` counterpart).
    """
    data = np.ascontiguousarray(data)
    if data.ndim != 2:
        raise ValueError("save_vecs expects a 2-D array, got shape %s" % (data.shape,))
    if data.dtype.itemsize not in (1, 4):
        raise ValueError("Unsupported dtype %s for *vecs output" % data.dtype)
    n, dim = data.shape
    # Build every record at byte level so the 4-byte header works for any
    # component width (a uint8 row cannot hold an int32 in a single element).
    header = np.full((n, 1), dim, dtype=np.int32).view(np.uint8)
    payload = data.view(np.uint8)
    np.hstack([header, payload]).tofile(filename)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.