Skip to content

Instantly share code, notes, and snippets.

@mocobeta
Last active April 3, 2021 07:11
Show Gist options
  • Save mocobeta/5c174ee9fc6408470057a9e7d2020c45 to your computer and use it in GitHub Desktop.
Save mocobeta/5c174ee9fc6408470057a9e7d2020c45 to your computer and use it in GitHub Desktop.
PoC example for approximate vector search for Lucene
/**
 * PoC indexing/querying example code for LUCENE-9004 (approximate k-NN vector search).
 *
 * <p>The example writes {@code maxDoc} documents, each carrying one random
 * {@code numDims}-dimensional vector, to an on-disk index and then runs a
 * single k-NN graph query against it, printing the top hits.
 *
 * @see <a href="https://github.com/mocobeta/lucene-solr-mirror/tree/jira/LUCENE-9004-aknn">LUCENE-9004 PoC branch</a>
 */
public class VectorValuesFieldExample {

  /** Indexes random vectors under {@code /tmp/vector-search} and runs one top-5 query. */
  public static void main(String[] args) {
    String indexDir = "/tmp/vector-search";
    String vectorField = "vector";
    int maxDoc = 100_000;
    int numDims = 100;
    VectorDocValues.DistanceFunction distFunc = VectorDocValues.DistanceFunction.MANHATTAN;
    VectorValuesFieldExample test =
        new VectorValuesFieldExample(indexDir, vectorField, maxDoc, numDims, distFunc);
    try {
      test.indexVectors();
      float[] queryVector = test.generateRandomVector(numDims);
      test.searchVectors(queryVector, 5);
    } catch (IOException e) {
      e.printStackTrace();
    }
  }

  /** File-system path of the index directory. */
  private final String indexDir;
  /** Name of the field that stores the vector of each document. */
  private final String vectorField;
  /** Number of documents to index. */
  private final int maxDoc;
  /** Dimensionality of every indexed and queried vector. */
  private final int numDims;
  /** Distance function used for both indexing and querying. */
  private final VectorDocValues.DistanceFunction distFunc;

  /**
   * Creates an example bound to one index location and vector configuration.
   *
   * @param indexDir    file-system path of the index directory
   * @param vectorField name of the vector field
   * @param maxDoc      number of documents to index
   * @param numDims     number of dimensions of each vector
   * @param distFunc    distance function used for indexing and querying
   */
  VectorValuesFieldExample(String indexDir, String vectorField, int maxDoc, int numDims,
                           VectorDocValues.DistanceFunction distFunc) {
    this.indexDir = indexDir;
    this.vectorField = vectorField;
    this.maxDoc = maxDoc;
    this.numDims = numDims;
    this.distFunc = distFunc;
  }

  /**
   * Generates {@code maxDoc} random vectors and adds one document per vector to the index.
   *
   * <p>The vectors are generated up front so that the reported elapsed time
   * covers only the {@code addDocument} calls.
   *
   * @throws IOException if the index cannot be opened or written
   */
  private void indexVectors() throws IOException {
    List<float[]> vectors = new ArrayList<>(maxDoc);
    for (int i = 0; i < maxDoc; i++) {
      vectors.add(generateRandomVector(numDims));
    }

    IndexWriterConfig config = new IndexWriterConfig(new StandardAnalyzer());
    config.setUseCompoundFile(false);
    config.setCodec(Codec.forName("Lucene90"));

    // The writer is declared after the directory so it is closed first; this
    // releases the writer (and its lock on the directory) even when indexing fails.
    try (Directory dir = FSDirectory.open(Paths.get(indexDir));
         IndexWriter writer = new IndexWriter(dir, config)) {
      VectorFieldType type = VectorFieldType.getType(numDims, distFunc);

      long startNanos = System.nanoTime();
      for (float[] vector : vectors) {
        Document doc = new Document();
        doc.add(new VectorValuesField(vectorField, vector, type));
        writer.addDocument(doc);
      }
      long elapsedMillis = (System.nanoTime() - startNanos) / 1_000_000L;
      System.out.println("Elapsed (indexing " + maxDoc + " docs): " + elapsedMillis + " msec");

      writer.commit();
    }
  }

  /**
   * Runs a k-NN graph query for {@code queryVector} and prints the top {@code n} hits
   * together with their stored vectors and the distance to the query.
   *
   * @param queryVector the query vector
   * @param n           number of top hits to request
   * @throws IOException if the index cannot be opened or read
   */
  private void searchVectors(float[] queryVector, int n) throws IOException {
    // try-with-resources closes the reader before the directory, also on failure.
    try (Directory dir = FSDirectory.open(Paths.get(indexDir));
         IndexReader reader = DirectoryReader.open(dir)) {
      IndexSearcher searcher = new IndexSearcher(reader);
      System.out.println("Query: " + Arrays.toString(queryVector));
      Query query = VectorValuesField.newKnnGraphQuery(vectorField, queryVector, distFunc, 100, reader);

      long startNanos = System.nanoTime();
      TopDocs result = searcher.search(query, n);
      long elapsedMillis = (System.nanoTime() - startNanos) / 1_000_000L;
      System.out.println("Elapsed (searching top " + n + " docs): " + elapsedMillis + " msec");
      System.out.println("Total hits: " + result.totalHits);

      List<LeafReaderContext> leaves = reader.leaves();
      int rank = 0;
      for (ScoreDoc hit : result.scoreDocs) {
        int doc = hit.doc;
        // hit.doc is a top-level doc id; resolve the segment that owns it and
        // convert to a segment-local id before reading per-leaf values.
        // NOTE(review): assumes VectorDocValues.retrieve takes a segment-local id,
        // as per-leaf Lucene APIs normally do -- confirm against the PoC branch.
        LeafReaderContext ctx = leafFor(leaves, doc);
        VectorDocValues values = ctx.reader().getVectorDocValues(vectorField);
        float[] value = values.retrieve(doc - ctx.docBase);
        if (value != null) {
          rank++;
          float dist = VectorDocValues.distance(queryVector, value, distFunc);
          System.out.println("Rank " + rank + ": doc " + doc + ", score " + hit.score
              + ", distance " + dist + " (value=" + Arrays.toString(value) + ")");
        }
      }
    }
  }

  /**
   * Returns the leaf whose doc-id range contains the given top-level doc id.
   *
   * @param leaves leaves of the top-level reader, ordered by ascending {@code docBase}
   * @param doc    top-level doc id
   * @return the leaf context that owns {@code doc}
   * @throws IllegalArgumentException if no leaf starts at or before {@code doc}
   */
  private static LeafReaderContext leafFor(List<LeafReaderContext> leaves, int doc) {
    // Scan from the last leaf: the owner is the last one whose docBase <= doc.
    for (int i = leaves.size() - 1; i >= 0; i--) {
      LeafReaderContext ctx = leaves.get(i);
      if (ctx.docBase <= doc) {
        return ctx;
      }
    }
    throw new IllegalArgumentException("No leaf contains doc " + doc);
  }

  /**
   * Creates a vector whose components are drawn from {@link ThreadLocalRandom#nextFloat()}.
   *
   * @param numDims number of dimensions of the vector
   * @return a new array of length {@code numDims}
   */
  private float[] generateRandomVector(int numDims) {
    ThreadLocalRandom random = ThreadLocalRandom.current();
    float[] vector = new float[numDims];
    for (int i = 0; i < numDims; i++) {
      vector[i] = random.nextFloat();
    }
    return vector;
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment