Created
December 25, 2021 03:17
-
-
Save Jotschi/cea21a72412bcba80c46b967e9c52b0f to your computer and use it in GitHub Desktop.
HnswGraphTest
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package io.metaloom.video4j.lucene; | |
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import static org.junit.Assert.assertEquals;

import java.io.File;
import java.io.IOException;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.SplittableRandom;

import org.apache.commons.io.FileUtils;
import org.apache.commons.lang3.RandomUtils;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.lucene90.Lucene90Codec;
import org.apache.lucene.codecs.lucene90.Lucene90HnswVectorsFormat;
import org.apache.lucene.codecs.lucene90.Lucene90HnswVectorsReader;
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.KnnVectorField;
import org.apache.lucene.document.StoredField;
import org.apache.lucene.index.CodecReader;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.KnnGraphValues;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.RandomAccessVectorValues;
import org.apache.lucene.index.RandomAccessVectorValuesProducer;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.store.MMapDirectory;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.hnsw.HnswGraph;
import org.apache.lucene.util.hnsw.HnswGraphBuilder;
import org.apache.lucene.util.hnsw.NeighborQueue;
import org.junit.Before;
import org.junit.Test;
/** Minimal immutable 2D float vector used as test data. */
class Vector2D {

    final float a;
    final float b;

    public Vector2D(float a, float b) {
        this.a = a;
        this.b = b;
    }

    /** Returns a new {@code [a, b]} array; each call allocates, so callers may modify it freely. */
    public float[] toArray() {
        return new float[] { a, b };
    }

    /**
     * Prints this vector to stdout as {@code "<ord> => [a|b]"} with two decimals.
     *
     * <p>Formats with {@link Locale#ROOT} so the decimal separator is always '.', regardless of
     * the JVM's default locale (e.g. a German locale would otherwise print "0,99").
     *
     * @param ord ordinal / document id printed in front of the vector
     */
    public void print(int ord) {
        System.out.println(String.format(Locale.ROOT, "%d => [%.2f|%.2f]", ord, a, b));
    }
}
public class HnswGraphTest { | |
public static final Path indexPath = Paths.get("target/index"); | |
public static final int dim = 2; | |
public static final VectorSimilarityFunction similarityFunction = VectorSimilarityFunction.DOT_PRODUCT; | |
public static final int maxConn = 16; | |
public static final int beamWidth = 10; | |
public static final long seed = RandomUtils.nextLong(); | |
@Before | |
public void setupIndexDir() throws IOException { | |
File file = indexPath.toFile(); | |
if (file.exists()) { | |
FileUtils.deleteDirectory(file); | |
} | |
} | |
@Test | |
public void testSearch() throws IOException { | |
// Prepare the test data (10 entries) | |
List<Vector2D> vectorData = createVectorData(10); | |
// Add a custom vector which is very close to our target | |
vectorData.add(new Vector2D(0.99f, 0.01f)); | |
// Create the provider which will feed the vectors for the graph | |
VectorProvider vectors = new VectorProvider(vectorData); | |
VectorProvider v2 = vectors.copy(); | |
HnswGraphBuilder builder = new HnswGraphBuilder(vectors, similarityFunction, maxConn, beamWidth, seed); | |
HnswGraph hnsw = builder.build(vectors); | |
// Run a search | |
NeighborQueue nn = HnswGraph.search( | |
new float[] { 1, 0 }, | |
10, | |
10, | |
vectors.randomAccess(), // ? Why do I need to specify the graph values again? | |
similarityFunction, // ? Why can I specify a different similarityFunction for search. Should that not be the same that was used for graph creation? | |
hnsw, | |
null, | |
new SplittableRandom(RandomUtils.nextLong())); | |
// Print the results | |
System.out.println(); | |
System.out.println("Searching for NN of 1:0"); | |
System.out.println("Results: " + nn.size()); | |
System.out.println("Top:" + nn.topNode()); | |
Vector2D topVec = vectorData.get(nn.topNode()); | |
topVec.print(nn.topNode()); | |
for (int i = 0; i < nn.size(); i++) { | |
int id = nn.pop(); | |
System.out.println("ID: " + id); | |
} | |
// Persist and read the data | |
try (MMapDirectory dir = new MMapDirectory(indexPath)) { | |
IndexWriterConfig iwc = new IndexWriterConfig() | |
.setCodec( | |
new Lucene90Codec() { | |
@Override | |
public KnnVectorsFormat getKnnVectorsFormatForField(String field) { | |
return new Lucene90HnswVectorsFormat(maxConn, beamWidth); | |
} | |
}); | |
// Write index | |
int nVec = 0, indexedDoc = 0; | |
try (IndexWriter iw = new IndexWriter(dir, iwc)) { | |
while (v2.nextDoc() != NO_MORE_DOCS) { | |
while (indexedDoc < v2.docID()) { | |
// increment docId in the index by adding empty documents | |
iw.addDocument(new Document()); | |
indexedDoc++; | |
} | |
Document doc = new Document(); | |
doc.add(new KnnVectorField("field", v2.vectorValue(), similarityFunction)); | |
doc.add(new StoredField("id", v2.docID())); | |
iw.addDocument(doc); | |
nVec++; | |
indexedDoc++; | |
} | |
} | |
// Read index | |
try (IndexReader reader = DirectoryReader.open(dir)) { | |
for (LeafReaderContext ctx : reader.leaves()) { | |
VectorValues values = ctx.reader().getVectorValues("field"); | |
assertEquals(dim, values.dimension()); | |
assertEquals(nVec, values.size()); | |
assertEquals(vectorData.size(), ctx.reader().maxDoc()); | |
assertEquals(vectorData.size(), ctx.reader().numDocs()); | |
KnnGraphValues graphValues = ((Lucene90HnswVectorsReader) ((PerFieldKnnVectorsFormat.FieldsReader) ((CodecReader) ctx.reader()) | |
.getVectorReader()) | |
.getFieldReader("field")) | |
.getGraphValues("field"); | |
} | |
} | |
} | |
} | |
private List<Vector2D> createVectorData(int len) { | |
// Just using a list for now to make it easier to matchup with document ids later on | |
List<Vector2D> set = new ArrayList<>(); | |
for (int i = 0; i < len; i++) { | |
/* | |
* double piRadians = i / (double) len; float a = (float) Math.cos(Math.PI * piRadians); float b = (float) Math.sin(Math.PI * piRadians); | |
*/ | |
float a = (float) Math.random(); | |
float b = (float) Math.random(); | |
Vector2D vec = new Vector2D(a, b); | |
vec.print(i); | |
set.add(vec); | |
} | |
return set; | |
} | |
} | |
class VectorProvider extends VectorValues implements RandomAccessVectorValues, RandomAccessVectorValuesProducer { | |
int doc = -1; | |
private final List<Vector2D> data; | |
public VectorProvider(List<Vector2D> data) { | |
this.data = data; | |
} | |
@Override | |
public RandomAccessVectorValues randomAccess() { | |
return new VectorProvider(data); | |
} | |
@Override | |
public float[] vectorValue(int ord) throws IOException { | |
Vector2D entry = data.get(ord); | |
return entry.toArray(); | |
} | |
@Override | |
public BytesRef binaryValue(int targetOrd) throws IOException { | |
return null; | |
} | |
@Override | |
public int dimension() { | |
return 2; | |
} | |
@Override | |
public int size() { | |
return data.size(); | |
} | |
@Override | |
public float[] vectorValue() throws IOException { | |
return vectorValue(doc); | |
} | |
@Override | |
public int docID() { | |
return doc; | |
} | |
@Override | |
public int nextDoc() throws IOException { | |
return advance(doc + 1); | |
} | |
@Override | |
public int advance(int target) throws IOException { | |
if (target >= 0 && target < data.size()) { | |
doc = target; | |
} else { | |
doc = NO_MORE_DOCS; | |
} | |
return doc; | |
} | |
@Override | |
public long cost() { | |
return data.size(); | |
} | |
public VectorProvider copy() { | |
return new VectorProvider(data); | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment