Created
October 19, 2017 16:29
-
-
Save mdouze/51187d2eb7d271e5a963a3ccb23ac444 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
--[[ Copyright 2004-present Facebook. All Rights Reserved.

Load the dynamic library involved in swig calls for Faiss
+ some code to convert input / output arguments.

The standard way of importing this is

  swigfaiss = require 'faiss_swig'

If you require 'swigfaiss' directly, you get the indexing
structures without the additions that make them easy to use from
Lua.
--]]
require 'torch' | |
local ffi = require 'ffi' | |
local swigfaiss | |
-- Intentional global: the caller may set it before requiring this file,
-- otherwise it falls back to the FORCE_FAISS_MAKEFILE environment variable.
-- Any non-nil value (including the string "0") counts as true.
g_force_faiss_Makefile = g_force_faiss_Makefile or
   os.getenv('FORCE_FAISS_MAKEFILE')

--[[
  Two global variables tell us what we should load:

  - g_swigfaiss_use_gpu: if true, load swigfaiss_gpu. The GPU version
    is a superset of the normal one, but has additional dependencies
    of course.

  - g_force_faiss_Makefile: if true, load the version compiled with
    the Makefile rather than the version compiled by the Facebook
    build system.
--]]

-- Name of the SWIG module; also reused further down to name the
-- AsyncIndexSearch torch class.
local module_name = g_swigfaiss_use_gpu and 'swigfaiss_gpu' or 'swigfaiss'

if g_force_faiss_Makefile then
   local my_path = '/fbsource/fbcode/deeplearning/projects/faiss'
   -- Fail with a readable message instead of "attempt to concatenate a
   -- nil value" when HOME is not defined.
   local home = assert (os.getenv ('HOME'),
                        'HOME is not set, cannot locate ' .. module_name .. '.so')
   local soname = home .. my_path .. '/lua/' .. module_name .. '.so'
   -- package.loadlib returns nil plus an error string on failure; keep the
   -- string so the assertion says what actually went wrong.
   local luaopen_swigfaiss, load_err = package.loadlib (
      soname, 'luaopen_' .. module_name)
   assert (luaopen_swigfaiss,
           'could not load .so, check link flags (' .. soname .. ': ' ..
              tostring (load_err) .. ')')
   swigfaiss = luaopen_swigfaiss ()
else
   -- works when compiled in fbcode
   swigfaiss = require (module_name)
end
--[[
  The train, add and search methods should accept Lua tensors as arguments.
  To do this we:

  - rename the swig-wrapped version of each function with a _c suffix
  - allocate and convert the arguments to C
  - call the _c version

  The code below is relatively primitive. It assumes that all classes
  whose name starts with Index or GpuIndex are the ones to be processed.
--]]
--- Swap a SWIG-wrapped method for a Lua implementation.
-- The wrapped function stays reachable under `name .. '_c'` so that the
-- replacement can delegate to it.
-- @param class       object whose metatable holds the '.instance' / '.fn' tables
-- @param name        name of the method to replace
-- @param f           Lua function installed under `name`
-- @param allow_fail  when truthy, do nothing if the class has no method `name`
local function replace_c_method (class, name, f, allow_fail)
   local methods = getmetatable (class)['.instance']['.fn']
   local wrapped = methods[name]
   if allow_fail and not wrapped then
      return
   end
   -- keep the C function under the _c suffix, then install the Lua version
   methods[name .. '_c'] = wrapped
   methods[name] = f
end
--- Install a tensor-based `search` on a class.
-- The installed method takes (index, xq, k, zero_based) and returns two
-- nq-by-k tensors: a FloatTensor and a LongTensor filled by search_c.
-- NOTE: the label tensor is shifted by one only when `zero_based` is nil;
-- any other value, including `false`, leaves the labels as search_c wrote them.
-- @param class  class to patch
local function replace_search_method (class)
   local function search (index, xq, k, zero_based)
      assert (xq:size (2) == index.d,
              'vectors have incorrect dimension for search')
      local nq = xq:size (1)
      local distances = torch.FloatTensor (nq, k)
      local labels = torch.LongTensor (nq, k)
      index:search_c (nq, swigfaiss.float_ptr (xq), k,
                      swigfaiss.float_ptr (distances),
                      swigfaiss.long_ptr (labels))
      if zero_based == nil then
         labels:add (1)
      end
      return distances, labels
   end
   replace_c_method (class, 'search', search)
end
--- Install a tensor-based `train` on a class.
-- The installed method checks that x:size(2) equals index.d, then forwards
-- the row count and the float pointer of x to train_c.
-- @param class  class to patch
local function replace_train_method (class)
   local function train (index, x)
      assert (x:size (2) == index.d,
              'vectors have incorrect dimension for train')
      local n = x:size (1)
      index:train_c (n, swigfaiss.float_ptr (x))
   end
   replace_c_method (class, 'train', train)
end
--- Install a tensor-based `add` on a class.
-- The installed method checks that x:size(2) equals index.d, then forwards
-- the row count and the float pointer of x to add_c.
-- @param class  class to patch
local function replace_add_method (class)
   local function add (index, x)
      assert (x:size (2) == index.d,
              'vectors have incorrect dimension for add')
      local n = x:size (1)
      index:add_c (n, swigfaiss.float_ptr (x))
   end
   replace_c_method (class, 'add', add)
end
--- Install a tensor-based `add_with_ids` on a class.
-- The installed method takes (index, x, ids). When `ids` is absent, nil is
-- passed to add_with_ids_c in place of the id pointer.
-- @param class       class to patch
-- @param allow_fail  forwarded to replace_c_method (skip classes lacking the method)
local function replace_add_with_ids_method (class, allow_fail)
   local function add_with_ids (index, x, ids)
      assert (x:size (2) == index.d)
      local n = x:size (1)
      if not ids then
         index:add_with_ids_c (n, swigfaiss.float_ptr (x), nil)
         return
      end
      -- one id per row of x
      assert (ids:size (1) == n)
      index:add_with_ids_c (n, swigfaiss.float_ptr (x),
                            swigfaiss.long_ptr (ids))
   end
   replace_c_method (class, 'add_with_ids', add_with_ids, allow_fail)
end
-- Make a few setters/getters that access Index variables directly.
-- We may move to more setters/getters in C++ as well.

--- Add `get_n` (reads index.ntotal) and `set_verbose` (writes index.verbose)
-- to the instance method table of a class.
-- @param class  class to patch
local function add_set_get (class)
   local methods = getmetatable (class)['.instance']['.fn']

   function methods.get_n (index)
      return index.ntotal
   end

   function methods.set_verbose (index, verbose)
      index.verbose = verbose
   end
end
--- Add `set_nprobe` (writes index.nprobe) to the instance method table of
-- a class.
-- @param class  class to patch
local function add_set_nprobe (class)
   local methods = getmetatable (class)['.instance']['.fn']

   function methods.set_nprobe (index, nprobe)
      index.nprobe = nprobe
   end
end
-- Go over all fields of the module and patch the Index classes.
-- NOTE(review): the patterns below are not anchored with '^', so any name
-- that merely contains "Index" is selected (which makes the 'GpuIndex.*'
-- test redundant) -- confirm whether "starts with" was intended.
for name, class in pairs (swigfaiss) do
   local looks_like_index = name:match ('Index.*') or name:match ('GpuIndex.*')
   -- entries without a metatable are skipped
   if looks_like_index and getmetatable (class) then
      replace_train_method (class)
      replace_add_method (class)
      replace_search_method (class)
      -- add_with_ids is optional: classes lacking it are left untouched
      replace_add_with_ids_method (class, true)
      add_set_get (class)
      if name:match ('IndexIVF.*') then
         add_set_nprobe (class)
      end
   end
end
-- A few additional method replacements

--- Install a tensor-based `range_search` on a class.
-- The installed method takes (index, xq, threshold) and returns
-- (lims, D, I): a LongTensor of nq + 1 entries, plus a FloatTensor and a
-- LongTensor of lims[nq + 1] entries each, copied out of a
-- RangeSearchResult.
-- NOTE: I and lims are shifted by one only when there is at least one
-- result; with no results lims is returned exactly as copied.
-- @param class  class to patch
local function replace_range_search_method (class)
   local function range_search (index, xq, threshold)
      local nq = xq:size (1)
      assert (xq:size(2) == index.d)

      -- TODO use an object that allocates tensors directly
      local result = swigfaiss.RangeSearchResult (nq)

      -- the actual range search
      index:range_search_c (nq, swigfaiss.float_ptr (xq), threshold, result)

      -- copy the results to Tensors; byte counts use 8 bytes per lims /
      -- label entry and 4 bytes per distance
      local lims = torch.LongTensor (nq + 1)
      swigfaiss.memcpy (swigfaiss.long_ptr (lims), result.lims, (nq + 1) * 8)

      local nres = lims[nq + 1]
      local D = torch.FloatTensor (nres)
      local I = torch.LongTensor (nres)
      if nres <= 0 then
         return lims, D, I
      end

      swigfaiss.memcpy (swigfaiss.float_ptr (D), result.distances, nres * 4)
      swigfaiss.memcpy (swigfaiss.uint64_t_ptr (I), result.labels, nres * 8)
      I:add (1) -- Lua indexing...
      lims:add (1)
      return lims, D, I
   end
   replace_c_method (class, 'range_search', range_search)
end

replace_range_search_method (swigfaiss.IndexIVFFlat)
replace_range_search_method (swigfaiss.IndexFlat)
--- Install tensor-based `apply`, `train` and `reverse_transform` on a
-- transform class. Each replacement is skipped when the class does not
-- define the corresponding method.
-- @param class  class to patch
local function replace_vector_transform_apply (class)
   -- x has vt.d_in columns; returns a new FloatTensor with vt.d_out columns
   local function apply (vt, x)
      assert (x:size (2) == vt.d_in)
      local n = x:size (1)
      local y = torch.FloatTensor (n, vt.d_out)
      vt:apply_noalloc (n, swigfaiss.float_ptr (x),
                        swigfaiss.float_ptr (y))
      return y
   end

   local function train (vt, x)
      assert (x:size (2) == vt.d_in, 'incorrect train data dim')
      vt:train_c (x:size (1), swigfaiss.float_ptr (x))
   end

   -- y has vt.d_out columns; returns a new FloatTensor with vt.d_in columns
   local function reverse_transform (vt, y)
      assert (y:size (2) == vt.d_out)
      local n = y:size (1)
      local x = torch.FloatTensor (n, vt.d_in)
      vt:reverse_transform_c (n, swigfaiss.float_ptr (y),
                              swigfaiss.float_ptr (x))
      return x
   end

   replace_c_method (class, 'apply', apply, true)
   replace_c_method (class, 'train', train, true)
   replace_c_method (class, 'reverse_transform', reverse_transform, true)
end

-- patch every transform class, in the same order as before
for _, class_name in ipairs {
   'VectorTransform',
   'LinearTransform',
   'ExternalTransform',
   'RemapDimensionsTransform',
   'OPQMatrix',
   'PCAMatrix',
} do
   replace_vector_transform_apply (swigfaiss[class_name])
end
--- Install tensor-based `encode` and `train` on a codec class. Each
-- replacement is skipped when the class does not define the method.
-- @param class  class to patch
local function replace_encode (class)
   -- returns a ByteTensor with one row per row of x and codec.code_size columns
   local function encode (codec, x)
      assert (x:size (2) == codec.d)
      local n = x:size (1)
      local codes = torch.ByteTensor (n, codec.code_size)
      codec:encode_c (n, swigfaiss.float_ptr (x),
                      swigfaiss.uint8_t_ptr (codes))
      return codes
   end

   local function train (codec, x)
      assert (x:size (2) == codec.d)
      codec:train_c (x:size (1), swigfaiss.float_ptr (x))
   end

   replace_c_method (class, 'encode', encode, true)
   replace_c_method (class, 'train', train, true)
end

replace_encode (swigfaiss.BinaryCode)
--- Install tensor-based `train`, `compute_codes` and `decode` on a codec
-- class. Each replacement is skipped when the class does not define the
-- method. Note that compute_codes_c and decode_c take the row count as
-- their last argument, unlike train_c.
-- @param class  class to patch
local function replace_codec_functions (class)
   local function train (codec, x)
      assert (x:size (2) == codec.d)
      codec:train_c (x:size (1), swigfaiss.float_ptr (x))
   end

   -- x has codec.d columns; returns a ByteTensor with codec.code_size columns
   local function compute_codes (codec, x)
      assert (x:size (2) == codec.d)
      local n = x:size (1)
      local codes = torch.ByteTensor (n, codec.code_size)
      codec:compute_codes_c (swigfaiss.float_ptr (x),
                             swigfaiss.uint8_t_ptr (codes),
                             n)
      return codes
   end

   -- x has codec.code_size columns; returns a FloatTensor with codec.d columns
   local function decode (codec, x)
      assert (x:size (2) == codec.code_size)
      local n = x:size (1)
      local vectors = torch.FloatTensor (n, codec.d)
      codec:decode_c (swigfaiss.uint8_t_ptr (x),
                      swigfaiss.float_ptr (vectors),
                      n)
      return vectors
   end

   replace_c_method (class, 'train', train, true)
   replace_c_method (class, 'compute_codes', compute_codes, true)
   replace_c_method (class, 'decode', decode, true)
end

replace_codec_functions (swigfaiss.ProductQuantizer)
--- Lua-side handle for a search started through swigfaiss.AsyncIndexSearchC.
-- Construct it to start the search, call :join() to obtain the results.
local AsyncIndexSearch = torch.class (module_name .. '.AsyncIndexSearch')

--- Start a search.
-- @param index       index to query; index.d must equal xq:size(2)
-- @param xq          query tensor, one query per row
-- @param k           number of results per query
-- @param zero_based  leave nil to get labels shifted by one in join();
--                    any non-nil value (including false) disables the shift
function AsyncIndexSearch:__init (index, xq, k, zero_based)
   assert (index.d == xq:size(2))
   local nq = xq:size (1)
   -- The C object only receives raw pointers. Keep the index and the query
   -- tensor referenced from this object so the garbage collector cannot
   -- free them while the search is still in flight (presumably it runs in
   -- the background until join -- confirm against AsyncIndexSearchC).
   self.index = index
   self.xq = xq
   self.D = torch.FloatTensor (nq, k)
   self.I = torch.LongTensor (nq, k)
   self.zero_based = zero_based
   self.c = swigfaiss.AsyncIndexSearchC (
      index, nq, swigfaiss.float_ptr (xq), k,
      swigfaiss.float_ptr (self.D), swigfaiss.long_ptr (self.I))
end

--- Join the underlying C search and return the result tensors.
-- @return D  FloatTensor (nq x k)
-- @return I  LongTensor (nq x k), shifted by one when zero_based was nil
function AsyncIndexSearch:join ()
   self.c:join ()
   if self.zero_based == nil then
      self.I:add (1)
   end
   return self.D, self.I
end
-- Keep the wrapped function reachable, then expose a tensor-based version.
swigfaiss.kmeans_clustering_c = swigfaiss.kmeans_clustering

--- Run k-means on the rows of x.
-- @param x  tensor with one point per row
-- @param k  number of centroids
-- @return centroids  FloatTensor (k x x:size(2)) filled by the C call
-- @return res        whatever kmeans_clustering_c returns
function swigfaiss.kmeans_clustering (x, k)
   local dim = x:size (2)
   local centroids = torch.FloatTensor (k, dim)
   local res = swigfaiss.kmeans_clustering_c (
      dim, x:size (1), k, swigfaiss.float_ptr (x),
      swigfaiss.float_ptr (centroids))
   return centroids, res
end
-- Keep the wrapped function reachable, then expose a tensor-based version.
swigfaiss.kmeans_clustering_gpu_c = swigfaiss.kmeans_clustering_gpu

--- Run k-means on the rows of x through kmeans_clustering_gpu_c.
-- @param ngpu        forwarded as the first argument of the C call
-- @param x           tensor with one point per row
-- @param k           number of centroids
-- @param useFloat16  optional flag, defaults to false when nil
-- @return centroids  FloatTensor (k x x:size(2)) filled by the C call
-- @return res        whatever kmeans_clustering_gpu_c returns
function swigfaiss.kmeans_clustering_gpu (ngpu, x, k, useFloat16)
   local dim = x:size (2)
   -- nil and false both end up as false; anything else passes through
   local use_float16 = useFloat16 or false
   local centroids = torch.FloatTensor (k, dim)
   local res = swigfaiss.kmeans_clustering_gpu_c (
      ngpu, dim, x:size (1), k, swigfaiss.float_ptr (x),
      swigfaiss.float_ptr (centroids),
      use_float16, false)
   return centroids, res
end
--[[
  A few low-level utility functions to manipulate swig objects
--]]

-- This is how SWIG stores the reference to any pointer. The declaration is
-- what lets owns_pointer read the `own` field through the LuaJIT FFI.
ffi.cdef ([[
typedef struct {
void *type;
int own; /* 1 if owned & must be destroyed */
void *ptr;
} swig_lua_userdata;
]])
--- Tell swig not to deallocate the pointer when Lua garbage collects the
-- object, by calling the `__disown` entry of the object's '.fn' table.
-- @param index  SWIG object to disown
function swigfaiss.disown_pointer (index)
   local methods = getmetatable (index)['.fn']
   methods.__disown (index)
end
--- Check whether the pointer is owned by Lua.
-- Reads the `own` field of the object viewed as a swig_lua_userdata.
-- @param index  SWIG object
-- @return true when the `own` field equals 1
function swigfaiss.owns_pointer (index)
   local userdata = ffi.cast ('swig_lua_userdata *', index)
   return userdata.own == 1
end
--- Make sure that object `ofrom`, which is included in `index`, does not
-- get deallocated before `index` is deleted: Lua gives up ownership of
-- `ofrom` and `index.own_fields` is set to true.
-- @param ofrom  object to disown on the Lua side
-- @param index  object whose own_fields flag is set
function swigfaiss.transfer_ownership (ofrom, index)
   index.own_fields = true
   swigfaiss.disown_pointer (ofrom)
end
--- Copy the contents of a swig float vector into a new FloatTensor.
-- @param v  vector object exposing size() and data()
-- @return a FloatTensor with v:size() elements (4 bytes copied per element)
function swigfaiss.vector_float_to_tensor (v)
   local n = v:size ()
   local t = torch.FloatTensor (n)
   -- Nothing to copy for an empty vector; skip the pointer conversion and
   -- memcpy entirely, like range_search does when it has no results.
   if n > 0 then
      swigfaiss.memcpy (swigfaiss.float_ptr (t), v:data (), n * 4)
   end
   return t
end
--- Copy the elements of a tensor into a swig float vector.
-- The vector is resized to t:nElement() and filled with a flat memcpy of
-- 4 bytes per element.
-- NOTE(review): a flat copy presumably requires t to be a contiguous
-- FloatTensor -- confirm what swigfaiss.float_ptr accepts.
-- @param t   source tensor
-- @param vf  destination vector exposing resize(), size() and data()
function swigfaiss.tensor_to_vector_float (t, vf)
   local n = t:nElement ()
   vf:resize (n)
   -- Nothing to copy for an empty tensor; skip the pointer conversion and
   -- memcpy entirely, like range_search does when it has no results.
   if n > 0 then
      swigfaiss.memcpy (vf:data (), swigfaiss.float_ptr (t),
                        vf:size () * 4)
   end
end

return swigfaiss
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment