Skip to content

Instantly share code, notes, and snippets.

@mdouze
Created June 2, 2020 20:55
Show Gist options
  • Select an option

  • Save mdouze/6733ad21bc856a334ec0ba2f4a859406 to your computer and use it in GitHub Desktop.

Select an option

Save mdouze/6733ad21bc856a334ec0ba2f4a859406 to your computer and use it in GitHub Desktop.
diff --git a/faiss/index_factory.cpp b/faiss/index_factory.cpp
--- a/faiss/index_factory.cpp
+++ b/faiss/index_factory.cpp
@@ -81,6 +81,7 @@
int64_t ncentroids = -1;
bool use_2layer = false;
+ int hnsw_M = -1;
for (char *tok = strtok_r (description, " ,", &ptr);
tok;
@@ -180,6 +181,8 @@
del_coarse_quantizer.release ();
index_ivf->own_fields = true;
index_1 = index_ivf;
+ } else if (hnsw_M > 0) {
+ index_1 = new IndexHNSWFlat (d, hnsw_M, metric);
} else {
FAISS_THROW_IF_NOT_MSG (stok != "FlatDedup",
"dedup supported only for IVFFlat");
@@ -203,6 +206,8 @@
del_coarse_quantizer.release ();
index_ivf->own_fields = true;
index_1 = index_ivf;
+ } else if (hnsw_M > 0) {
+ index_1 = new IndexHNSWSQ(d, qt, hnsw_M, metric);
} else {
index_1 = new IndexScalarQuantizer (d, qt, metric);
}
@@ -242,6 +247,11 @@
index_2l->q1.own_fields = true;
index_1 = index_2l;
}
+ } else if (hnsw_M > 0) {
+ IndexHNSWPQ *ipq = new IndexHNSWPQ(d, M, hnsw_M);
+ dynamic_cast<IndexPQ*>(ipq->storage)->do_polysemous_training =
+ do_polysemous_training;
+ index_1 = ipq;
} else {
IndexPQ *index_pq = new IndexPQ (d, M, nbit, metric);
index_pq->do_polysemous_training = do_polysemous_training;
@@ -266,13 +276,14 @@
} else if (!index &&
sscanf (tok, "HNSW%d_PQ%d", &M, &pq_m) == 2) {
index_1 = new IndexHNSWPQ (d, pq_m, M);
- } else if (!index &&
- sscanf (tok, "HNSW%d", &M) == 1) {
- index_1 = new IndexHNSWFlat (d, M);
} else if (!index &&
sscanf (tok, "HNSW%d_SQ%d", &M, &pq_m) == 2 &&
pq_m == 8) {
index_1 = new IndexHNSWSQ (d, ScalarQuantizer::QT_8bit, M);
+ } else if (!index &&
+ sscanf (tok, "HNSW%d", &M) == 1) {
+ hnsw_M = M;
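+ // Ambiguous at this point: a bare "HNSWx" means HNSW over flat storage, while "HNSWx,Y" selects storage Y; only hnsw_M is recorded here and the storage is chosen later.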
} else if (!index && (stok == "LSH" || stok == "LSHr" ||
stok == "LSHrt" || stok == "LSHt")) {
bool rotate_data = strstr(tok, "r") != nullptr;
@@ -312,6 +323,11 @@
}
}
+ if (!index && hnsw_M > 0) {
+ index = new IndexHNSWFlat (d, hnsw_M, metric);
+ del_index.set (index);
+ }
+
FAISS_THROW_IF_NOT_FMT(index, "description %s did not generate an index",
description_in);
diff --git a/faiss/tests/test_factory.py b/faiss/tests/test_factory.py
--- a/faiss/tests/test_factory.py
+++ b/faiss/tests/test_factory.py
@@ -47,6 +47,26 @@
index = faiss.index_factory(12, "IVF10,FlatDedup")
assert index.instances is not None
+ def test_factory_HNSW(self):  # single-token HNSW descriptions: "HNSW32", "HNSW32_SQ8", "HNSW32_PQ4", all built with d=12
+ index = faiss.index_factory(12, "HNSW32")  # bare HNSW token, no storage suffix
+ assert index.storage.sa_code_size() == 12 * 4  # presumably 12 float32 components at 4 bytes each -- confirm
+ index = faiss.index_factory(12, "HNSW32_SQ8")
+ assert index.storage.sa_code_size() == 12  # presumably one byte per dimension for 8-bit scalar quantization -- confirm
+ index = faiss.index_factory(12, "HNSW32_PQ4")
+ assert index.storage.sa_code_size() == 4  # presumably one byte per PQ sub-quantizer (4 of them) -- confirm
+
+ def test_factory_HNSW_newstyle(self):  # comma-separated descriptions: the HNSW token followed by a separate storage token
+ index = faiss.index_factory(12, "HNSW32,Flat")
+ assert index.storage.sa_code_size() == 12 * 4  # presumably 12 float32 components at 4 bytes each -- confirm
+ index = faiss.index_factory(12, "HNSW32,SQ8", faiss.METRIC_INNER_PRODUCT)  # also passes a non-default metric
+ assert index.storage.sa_code_size() == 12
+ assert index.metric_type == faiss.METRIC_INNER_PRODUCT  # the requested metric must be reported by the resulting index
+ index = faiss.index_factory(12, "HNSW32,PQ4")
+ assert index.storage.sa_code_size() == 4
+ index = faiss.index_factory(12, "HNSW32,PQ4np")  # "np" suffix: presumably "no polysemous training" -- confirm against the factory's PQ parsing
+ indexpq = faiss.downcast_index(index.storage)  # presumably needed to reach the PQ-specific attribute read below -- confirm
+ assert not indexpq.do_polysemous_training  # the flag must be off for the "np" variant
+
class TestCloneSize(unittest.TestCase):
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment