Last active
April 3, 2021 07:11
-
-
Save mocobeta/5c174ee9fc6408470057a9e7d2020c45 to your computer and use it in GitHub Desktop.
Proof-of-concept (PoC) example of approximate nearest-neighbor vector search in Lucene
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/** | |
* PoC Indexing/Querying example code for LUCENE-9004 | |
* @see https://github.com/mocobeta/lucene-solr-mirror/tree/jira/LUCENE-9004-aknn | |
*/ | |
public class VectorValuesFieldExample { | |
public static void main(String[] args) { | |
String indexDir = "/tmp/vector-search"; | |
String vectorField = "vector"; | |
int maxDoc = 100_000; | |
int numDims = 100; | |
VectorDocValues.DistanceFunction distFunc = VectorDocValues.DistanceFunction.MANHATTAN; | |
VectorValuesFieldExample test = new VectorValuesFieldExample(indexDir, vectorField, maxDoc, numDims, distFunc); | |
try { | |
test.indexVectors(); | |
float[] queryVector = test.generateRandomVector(numDims); | |
test.searchVectors(queryVector, 5); | |
} catch (IOException e) { | |
e.printStackTrace(); | |
} | |
} | |
String indexDir; | |
String vectorField; | |
int maxDoc; | |
int numDims; | |
VectorDocValues.DistanceFunction distFunc; | |
VectorValuesFieldExample(String indexDir, String vectorField, int maxDoc, int numDims, VectorDocValues.DistanceFunction distFunc) { | |
this.indexDir = indexDir; | |
this.vectorField = vectorField; | |
this.maxDoc = maxDoc; | |
this.numDims = numDims; | |
this.distFunc = distFunc; | |
} | |
private void indexVectors() throws IOException { | |
List<float[]> vectors = new ArrayList<>(); | |
for ( int i = 0; i < maxDoc; i++) { | |
vectors.add(generateRandomVector(numDims)); | |
} | |
try (Directory dir = FSDirectory.open(Paths.get(indexDir))) { | |
IndexWriterConfig config = new IndexWriterConfig(new StandardAnalyzer()); | |
config.setUseCompoundFile(false); | |
config.setCodec(Codec.forName("Lucene90")); | |
IndexWriter writer = new IndexWriter(dir, config); | |
VectorFieldType type = VectorFieldType.getType(numDims, distFunc); | |
long _start = System.currentTimeMillis(); | |
for (int i = 0; i < maxDoc; i++) { | |
Document doc = new Document(); | |
doc.add(new VectorValuesField(vectorField, vectors.get(i), type)); | |
writer.addDocument(doc); | |
} | |
long _end = System.currentTimeMillis(); | |
System.out.println("Elapsed (indexing " + maxDoc + " docs): " + (_end - _start) + " msec"); | |
writer.commit(); | |
} | |
} | |
private void searchVectors(float[] queryVector, int n) throws IOException { | |
try (Directory dir = FSDirectory.open(Paths.get(indexDir))) { | |
IndexReader reader = DirectoryReader.open(dir); | |
IndexSearcher searcher = new IndexSearcher(reader); | |
System.out.println("Query: " + Arrays.toString(queryVector)); | |
Query query = VectorValuesField.newKnnGraphQuery(vectorField, queryVector, distFunc, 100, reader); | |
long _start = System.currentTimeMillis(); | |
TopDocs result = searcher.search(query, n); | |
long _end = System.currentTimeMillis(); | |
System.out.println("Elapsed (searching top " + n + "docs): " + (_end - _start) + " msec"); | |
System.out.println("Total hits: " + result.totalHits); | |
int rank = 0; | |
for (ScoreDoc hit : result.scoreDocs) { | |
int doc = hit.doc; | |
LeafReaderContext ctx = reader.leaves().get(0); | |
VectorDocValues values = ctx.reader().getVectorDocValues(vectorField); | |
float[] value = values.retrieve(doc); | |
if (value != null) { | |
rank++; | |
float dist = VectorDocValues.distance(queryVector, value, distFunc); | |
System.out.println("Rank " + rank + ": doc " + doc + ", score " + hit.score + ", distance " + dist + " (value=" + Arrays.toString(value) + ")"); | |
} | |
} | |
reader.close(); | |
} | |
} | |
private float[] generateRandomVector(int numDims) { | |
ThreadLocalRandom random = ThreadLocalRandom.current(); | |
float[] vector = new float[numDims]; | |
for (int i = 0; i < numDims; i++) { | |
vector[i] = random.nextFloat(); | |
} | |
return vector; | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment