Created
May 2, 2018 11:21
-
-
Save mdouze/da4e7969c177afda3173192c3375e81a to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <cstdint>
#include <memory>
#include <stdexcept>
#include <vector>

#include <faiss/IndexIVF.h>
#include <faiss/AutoTune.h>
#include <VectorTransform.h>
| /* Returns the cluster the embeddings belong to. | |
| * | |
| * @param index Index, which should be an IVF index | |
| * (otherwise there are no clusters) | |
| * @param query_centroid_ids | |
| * centroid ids corresponding to the query vectors (size n) | |
| * @param result_centroid_ids | |
| * centroid ids corresponding to the results (size n * k) | |
| * other arguments are the same as the standard search function | |
| */ | |
| void search_and_retrun_centroids(faiss::Index *index, | |
| size_t n, | |
| const float* xin, | |
| long k, | |
| float *distances, | |
| int64_t* labels, | |
| int64_t* query_centroid_ids, | |
| int64_t* result_centroid_ids) | |
| { | |
| const float *x = xin; | |
| std::unique_ptr<float []> del; | |
| if (auto index_pre = dynamic_cast<faiss::IndexPreTransform*>(index)) { | |
| x = index_pre->apply_chain(n, x); | |
| del.reset((float*)x); | |
| index = index_pre->index; | |
| } | |
| faiss::IndexIVF* index_ivf = dynamic_cast<faiss::IndexIVF*>(index); | |
| assert(index_ivf); | |
| size_t nprobe = index_ivf->nprobe; | |
| std::vector<long> cent_nos (n * nprobe); | |
| std::vector<float> cent_dis (n * nprobe); | |
| index_ivf->quantizer->search( | |
| n, x, nprobe, cent_dis.data(), cent_nos.data()); | |
| if (query_centroid_ids) { | |
| for (size_t i = 0; i < n; i++) | |
| query_centroid_ids[i] = cent_nos[i * nprobe]; | |
| } | |
| index_ivf->search_preassigned (n, x, k, | |
| cent_nos.data(), cent_dis.data(), | |
| distances, labels, true); | |
| for (size_t i = 0; i < n * k; i++) { | |
| int64_t label = labels[i]; | |
| if (label < 0) { | |
| if (result_centroid_ids) | |
| result_centroid_ids[i] = -1; | |
| } else { | |
| long list_no = label >> 32; | |
| long list_index = label & 0xffffffff; | |
| if (result_centroid_ids) | |
| result_centroid_ids[i] = list_no; | |
| labels[i] = index_ivf->invlists->get_single_id(list_no, list_index); | |
| } | |
| } | |
| } | |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment