Created June 2, 2020 20:55
-
-
Save mdouze/6733ad21bc856a334ec0ba2f4a859406 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| diff --git a/faiss/index_factory.cpp b/faiss/index_factory.cpp | |
| --- a/faiss/index_factory.cpp | |
| +++ b/faiss/index_factory.cpp | |
| @@ -81,6 +81,7 @@ | |
| int64_t ncentroids = -1; | |
| bool use_2layer = false; | |
| + int hnsw_M = -1; | |
| for (char *tok = strtok_r (description, " ,", &ptr); | |
| tok; | |
| @@ -180,6 +181,8 @@ | |
| del_coarse_quantizer.release (); | |
| index_ivf->own_fields = true; | |
| index_1 = index_ivf; | |
| + } else if (hnsw_M > 0) { | |
| + index_1 = new IndexHNSWFlat (d, hnsw_M, metric); | |
| } else { | |
| FAISS_THROW_IF_NOT_MSG (stok != "FlatDedup", | |
| "dedup supported only for IVFFlat"); | |
| @@ -203,6 +206,8 @@ | |
| del_coarse_quantizer.release (); | |
| index_ivf->own_fields = true; | |
| index_1 = index_ivf; | |
| + } else if (hnsw_M > 0) { | |
| + index_1 = new IndexHNSWSQ(d, qt, hnsw_M, metric); | |
| } else { | |
| index_1 = new IndexScalarQuantizer (d, qt, metric); | |
| } | |
| @@ -242,6 +247,11 @@ | |
| index_2l->q1.own_fields = true; | |
| index_1 = index_2l; | |
| } | |
| + } else if (hnsw_M > 0) { | |
| + IndexHNSWPQ *ipq = new IndexHNSWPQ(d, M, hnsw_M); | |
| + dynamic_cast<IndexPQ*>(ipq->storage)->do_polysemous_training = | |
| + do_polysemous_training; | |
| + index_1 = ipq; | |
| } else { | |
| IndexPQ *index_pq = new IndexPQ (d, M, nbit, metric); | |
| index_pq->do_polysemous_training = do_polysemous_training; | |
| @@ -266,13 +276,14 @@ | |
| } else if (!index && | |
| sscanf (tok, "HNSW%d_PQ%d", &M, &pq_m) == 2) { | |
| index_1 = new IndexHNSWPQ (d, pq_m, M); | |
| - } else if (!index && | |
| - sscanf (tok, "HNSW%d", &M) == 1) { | |
| - index_1 = new IndexHNSWFlat (d, M); | |
| } else if (!index && | |
| sscanf (tok, "HNSW%d_SQ%d", &M, &pq_m) == 2 && | |
| pq_m == 8) { | |
| index_1 = new IndexHNSWSQ (d, ScalarQuantizer::QT_8bit, M); | |
| + } else if (!index && | |
| + sscanf (tok, "HNSW%d", &M) == 1) { | |
| + hnsw_M = M; | |
| + // here it is unclear what we want: HNSW flat or HNSWx,Y ? | |
| } else if (!index && (stok == "LSH" || stok == "LSHr" || | |
| stok == "LSHrt" || stok == "LSHt")) { | |
| bool rotate_data = strstr(tok, "r") != nullptr; | |
| @@ -312,6 +323,11 @@ | |
| } | |
| } | |
| + if (!index && hnsw_M > 0) { | |
| + index = new IndexHNSWFlat (d, hnsw_M, metric); | |
| + del_index.set (index); | |
| + } | |
| + | |
| FAISS_THROW_IF_NOT_FMT(index, "description %s did not generate an index", | |
| description_in); | |
| diff --git a/faiss/tests/test_factory.py b/faiss/tests/test_factory.py | |
| --- a/faiss/tests/test_factory.py | |
| +++ b/faiss/tests/test_factory.py | |
| @@ -47,6 +47,26 @@ | |
| index = faiss.index_factory(12, "IVF10,FlatDedup") | |
| assert index.instances is not None | |
| + def test_factory_HNSW(self): | |
| + index = faiss.index_factory(12, "HNSW32") | |
| + assert index.storage.sa_code_size() == 12 * 4 | |
| + index = faiss.index_factory(12, "HNSW32_SQ8") | |
| + assert index.storage.sa_code_size() == 12 | |
| + index = faiss.index_factory(12, "HNSW32_PQ4") | |
| + assert index.storage.sa_code_size() == 4 | |
| + | |
| + def test_factory_HNSW_newstyle(self): | |
| + index = faiss.index_factory(12, "HNSW32,Flat") | |
| + assert index.storage.sa_code_size() == 12 * 4 | |
| + index = faiss.index_factory(12, "HNSW32,SQ8", faiss.METRIC_INNER_PRODUCT) | |
| + assert index.storage.sa_code_size() == 12 | |
| + assert index.metric_type == faiss.METRIC_INNER_PRODUCT | |
| + index = faiss.index_factory(12, "HNSW32,PQ4") | |
| + assert index.storage.sa_code_size() == 4 | |
| + index = faiss.index_factory(12, "HNSW32,PQ4np") | |
| + indexpq = faiss.downcast_index(index.storage) | |
| + assert not indexpq.do_polysemous_training | |
| + | |
| class TestCloneSize(unittest.TestCase): | |
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.