Skip to content

Instantly share code, notes, and snippets.

@Jotschi
Created December 25, 2021 03:17
Show Gist options
  • Save Jotschi/cea21a72412bcba80c46b967e9c52b0f to your computer and use it in GitHub Desktop.
Save Jotschi/cea21a72412bcba80c46b967e9c52b0f to your computer and use it in GitHub Desktop.
HnswGraphTest
package io.metaloom.video4j.lucene;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import static org.junit.Assert.assertEquals;

import java.io.File;
import java.io.IOException;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.SplittableRandom;

import org.apache.commons.io.FileUtils;
import org.apache.commons.lang3.RandomUtils;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.lucene90.Lucene90Codec;
import org.apache.lucene.codecs.lucene90.Lucene90HnswVectorsFormat;
import org.apache.lucene.codecs.lucene90.Lucene90HnswVectorsReader;
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.KnnVectorField;
import org.apache.lucene.document.StoredField;
import org.apache.lucene.index.CodecReader;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.KnnGraphValues;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.RandomAccessVectorValues;
import org.apache.lucene.index.RandomAccessVectorValuesProducer;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.store.MMapDirectory;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.hnsw.HnswGraph;
import org.apache.lucene.util.hnsw.HnswGraphBuilder;
import org.apache.lucene.util.hnsw.NeighborQueue;
import org.junit.Before;
import org.junit.Test;
/**
 * Immutable two-component vector used as input data for the HNSW test.
 */
class Vector2D {

    final float a;

    final float b;

    public Vector2D(float a, float b) {
        this.a = a;
        this.b = b;
    }

    /** Returns a newly allocated array {@code {a, b}}; callers may modify it freely. */
    public float[] toArray() {
        return new float[] { a, b };
    }

    /**
     * Prints {@code "<ord> => [<a>|<b>]"} to standard out, each component rounded to two decimals.
     *
     * @param ord the ordinal/index to print in front of the vector
     */
    public void print(int ord) {
        // Locale.ROOT keeps '.' as the decimal separator regardless of the JVM's default locale.
        System.out.println(ord + " => [" + String.format(Locale.ROOT, "%.02f|%.02f", a, b) + "]");
    }
}
/**
 * Exploratory test of Lucene's HNSW support. It builds an {@link HnswGraph} in memory over a
 * handful of 2-D vectors and searches it. It then writes the same vectors to an on-disk index
 * using {@link Lucene90HnswVectorsFormat} and reads them back.
 */
public class HnswGraphTest {

    public static final Path indexPath = Paths.get("target/index");

    public static final int dim = 2;

    // Lucene documents DOT_PRODUCT as requiring unit-length document and query vectors, so every
    // vector this test creates is normalized.
    public static final VectorSimilarityFunction similarityFunction = VectorSimilarityFunction.DOT_PRODUCT;

    public static final int maxConn = 16;

    public static final int beamWidth = 10;

    // Single source of randomness for the graph build, the test data and the search. testSearch
    // prints it, so a failing run can be reproduced by hard-coding the printed value here.
    public static final long seed = RandomUtils.nextLong();

    /** Deletes the index directory left over from a previous run, if there is one. */
    @Before
    public void setupIndexDir() throws IOException {
        File file = indexPath.toFile();
        if (file.exists()) {
            FileUtils.deleteDirectory(file);
        }
    }

    /**
     * Builds the graph, searches for the neighbours of [1, 0] and prints the hits. Then indexes
     * the same vectors and checks the vector count, dimension and document counts read back.
     */
    @Test
    public void testSearch() throws IOException {
        System.out.println("Seed: " + seed);
        SplittableRandom random = new SplittableRandom(seed);

        // Prepare the test data (10 random unit vectors)
        List<Vector2D> vectorData = createVectorData(10, random);
        // Add a custom unit vector which is very close to our target
        vectorData.add(unitVector(0.99f, 0.01f));

        // Create the provider which will feed the vectors for the graph. v2 is a second provider
        // over the same list with its own iterator position; it is used for indexing below.
        VectorProvider vectors = new VectorProvider(vectorData);
        VectorProvider v2 = vectors.copy();
        HnswGraphBuilder builder = new HnswGraphBuilder(vectors, similarityFunction, maxConn, beamWidth, seed);
        HnswGraph hnsw = builder.build(vectors);

        // Run a search
        NeighborQueue nn = HnswGraph.search(
                new float[] { 1, 0 },
                10,
                10,
                // search() takes the vectors separately from the graph. Presumably the graph holds
                // only neighbour links, so scoring needs the raw vector values.
                vectors.randomAccess(),
                // Should be the same function the graph was built with; the signature does not
                // enforce that.
                similarityFunction,
                hnsw,
                null,
                random.split());

        // Print the results
        System.out.println();
        System.out.println("Searching for NN of 1:0");
        System.out.println("Results: " + nn.size());
        // NOTE(review): if the result queue is a min-heap on similarity (as Lucene's HNSW result
        // queues appear to be), topNode() is the weakest retained hit, not the best one. Verify.
        System.out.println("Top:" + nn.topNode());
        Vector2D topVec = vectorData.get(nn.topNode());
        topVec.print(nn.topNode());
        // pop() shrinks size(), so drain until empty instead of counting up to a moving bound.
        while (nn.size() > 0) {
            int id = nn.pop();
            System.out.println("ID: " + id);
        }

        // Persist and read the data
        try (MMapDirectory dir = new MMapDirectory(indexPath)) {
            IndexWriterConfig iwc = new IndexWriterConfig()
                    .setCodec(
                            new Lucene90Codec() {
                                @Override
                                public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
                                    return new Lucene90HnswVectorsFormat(maxConn, beamWidth);
                                }
                            });

            // Write index
            int nVec = 0;
            int indexedDoc = 0;
            try (IndexWriter iw = new IndexWriter(dir, iwc)) {
                while (v2.nextDoc() != NO_MORE_DOCS) {
                    while (indexedDoc < v2.docID()) {
                        // increment docId in the index by adding empty documents
                        iw.addDocument(new Document());
                        indexedDoc++;
                    }
                    Document doc = new Document();
                    doc.add(new KnnVectorField("field", v2.vectorValue(), similarityFunction));
                    doc.add(new StoredField("id", v2.docID()));
                    iw.addDocument(doc);
                    nVec++;
                    indexedDoc++;
                }
            }

            // Read index
            try (IndexReader reader = DirectoryReader.open(dir)) {
                for (LeafReaderContext ctx : reader.leaves()) {
                    VectorValues values = ctx.reader().getVectorValues("field");
                    assertEquals(dim, values.dimension());
                    // The count checks below compare per-leaf numbers with index-wide totals,
                    // so they assume the whole index was written as a single segment.
                    assertEquals(nVec, values.size());
                    assertEquals(vectorData.size(), ctx.reader().maxDoc());
                    assertEquals(vectorData.size(), ctx.reader().numDocs());
                    // Not asserted on yet. This shows how to reach the persisted graph of "field"
                    // through the codec reader internals.
                    KnnGraphValues graphValues = ((Lucene90HnswVectorsReader) ((PerFieldKnnVectorsFormat.FieldsReader) ((CodecReader) ctx.reader())
                            .getVectorReader())
                            .getFieldReader("field"))
                            .getGraphValues("field");
                }
            }
        }
    }

    /**
     * Creates {@code len} random unit vectors in the first quadrant and prints each one. A list is
     * used so that list index, vector ordinal and document id all line up.
     *
     * @param len    number of vectors to create
     * @param random source of randomness (seeded by the caller for reproducibility)
     * @return a mutable list of {@code len} vectors
     */
    private List<Vector2D> createVectorData(int len, SplittableRandom random) {
        List<Vector2D> vectors = new ArrayList<>();
        for (int i = 0; i < len; i++) {
            // Random angle in [0, PI/2) gives the same quadrant as independent [0, 1) components,
            // but yields unit-length vectors as DOT_PRODUCT requires.
            double angle = random.nextDouble() * (Math.PI / 2);
            Vector2D vec = new Vector2D((float) Math.cos(angle), (float) Math.sin(angle));
            vec.print(i);
            vectors.add(vec);
        }
        return vectors;
    }

    /**
     * Returns the direction {@code (a, b)} scaled to unit length.
     *
     * @throws IllegalArgumentException if {@code (a, b)} is the zero vector
     */
    private static Vector2D unitVector(float a, float b) {
        double norm = Math.hypot(a, b);
        if (norm == 0) {
            throw new IllegalArgumentException("Cannot normalize the zero vector");
        }
        return new Vector2D((float) (a / norm), (float) (b / norm));
    }
}
/**
 * Adapts a {@code List<Vector2D>} to Lucene's vector APIs. It can be iterated as
 * {@link VectorValues} and read by ordinal as {@link RandomAccessVectorValues}. The list index
 * serves as both document id and vector ordinal.
 */
class VectorProvider extends VectorValues implements RandomAccessVectorValues, RandomAccessVectorValuesProducer {

    // Current iterator position: -1 before the first nextDoc()/advance(), NO_MORE_DOCS once exhausted.
    private int doc = -1;

    // Shared, not copied: every provider created from this one reads the same list.
    private final List<Vector2D> data;

    public VectorProvider(List<Vector2D> data) {
        this.data = data;
    }

    /** Returns a new provider over the same list, with its own independent iterator position. */
    @Override
    public RandomAccessVectorValues randomAccess() {
        return new VectorProvider(data);
    }

    /** Returns a newly allocated array holding the components of the vector at {@code ord}. */
    @Override
    public float[] vectorValue(int ord) throws IOException {
        Vector2D entry = data.get(ord);
        return entry.toArray();
    }

    /** This implementation does not provide byte-encoded vectors; it always returns {@code null}. */
    @Override
    public BytesRef binaryValue(int targetOrd) throws IOException {
        return null;
    }

    /** Always 2, because every {@link Vector2D} has exactly two components. */
    @Override
    public int dimension() {
        return 2;
    }

    /** Number of vectors in the underlying list. */
    @Override
    public int size() {
        return data.size();
    }

    /** Returns the vector at the current iterator position. */
    @Override
    public float[] vectorValue() throws IOException {
        return vectorValue(doc);
    }

    @Override
    public int docID() {
        return doc;
    }

    @Override
    public int nextDoc() throws IOException {
        return advance(doc + 1);
    }

    /** Moves to {@code target} if it is a valid list index; otherwise moves to NO_MORE_DOCS. */
    @Override
    public int advance(int target) throws IOException {
        if (target >= 0 && target < data.size()) {
            doc = target;
        } else {
            doc = NO_MORE_DOCS;
        }
        return doc;
    }

    @Override
    public long cost() {
        return data.size();
    }

    /** Like {@link #randomAccess()}, but typed as VectorProvider so the copy can also be iterated. */
    public VectorProvider copy() {
        return new VectorProvider(data);
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment