Created
July 19, 2017 08:39
-
-
Save mdouze/e393931abc9f8ed93e2f63516db5e4f4 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import torch | |
import faiss | |
# Build a small GPU inner-product index over random vectors.
d = 16   # vector dimensionality
nq = 5   # number of query vectors
nb = 20  # number of database vectors

# Random query and database matrices; the second argument to faiss.randn
# differs between the two (presumably a seed), so they are not identical.
xq = faiss.randn(nq * d, 1234).reshape(nq, d)
xb = faiss.randn(nb * d, 1235).reshape(nb, d)

res = faiss.StandardGpuResources()
index = faiss.GpuIndexFlatIP(res, d)
index.add(xb)

# Reference result: search the GPU index with the original numpy queries.
Dref, Iref = index.search(xq, 5)
# print() with a single argument behaves the same on Python 2 and Python 3.
print(Iref)
def torch_tensor_to_numpy(x):
    """Make a numpy array that views the same data as a torch tensor.

    The tensor's raw data pointer is reinterpreted as a ``float *`` and
    wrapped by faiss into a flat array of ``x.nelement()`` elements, which
    is then reshaped to the tensor's shape.

    Args:
        x: a contiguous float32 torch tensor (CPU or CUDA).

    Returns:
        An array with the same shape as ``x`` wrapping the tensor's memory.
        For a CUDA tensor the wrapped address is not a CPU pointer, so the
        array must only be handed to code that expects device memory; do
        not read or print it from Python.

    Raises:
        AssertionError: if ``x`` is not contiguous or is not a float tensor.

    NOTE(review): the returned array presumably does not keep ``x`` alive;
    hold a reference to the tensor for as long as the array is in use.
    """
    assert x.is_contiguous(), "tensor must be contiguous"
    # The memory is reinterpreted as float32 below, so any other element
    # type would silently produce garbage.
    assert x.type().endswith("FloatTensor"), "tensor must be float32"
    # Use the tensor's own data pointer rather than the storage's: the
    # storage pointer ignores storage_offset(), so a contiguous view such
    # as t[2:] would otherwise alias the wrong memory.
    raw_ptr = x.data_ptr()
    float_ptr = faiss.cast_integer_to_float_ptr(raw_ptr)
    flat = faiss.rev_swig_ptr(float_ptr, x.nelement())
    return flat.reshape(tuple(x.size()))
# Test with a CPU torch tensor: wrap it as a numpy view and search with it.
xq_torch = torch.FloatTensor(xq)
xq_2 = torch_tensor_to_numpy(xq_torch)
# Search with the converted array (xq_2), not the original xq; otherwise
# the CPU-tensor conversion path is never exercised.
D2, I2 = index.search(xq_2, 5)
print(I2)

# Test with a GPU torch tensor.
xq_torch = xq_torch.cuda()
xq_3 = torch_tensor_to_numpy(xq_torch)
# do not try to print out xq_3 because the data pointer is not a CPU pointer
D3, I3 = index.search(xq_3, 5)
print(I3)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment