Tools for reading and testing vector files: a Lucene RecallTest that indexes .fvec corpora with binary quantization and measures kNN recall against a ground-truth file, plus a Python script that converts JSON nearest-neighbour ground truth into the flat binary format the test reads.
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.lucene.sandbox.rbq;

import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import org.apache.lucene.codecs.Codec;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.lucene912.Lucene912BinaryQuantizedVectorsFormat;
import org.apache.lucene.codecs.lucene912.Lucene912Codec;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.FieldType;
import org.apache.lucene.document.KnnFloatVectorField;
import org.apache.lucene.document.StoredField;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.KnnFloatVectorQuery;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.store.FSDirectory;
import org.apache.lucene.store.IOContext;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.MMapDirectory;
import org.apache.lucene.util.SuppressForbidden;
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;

/** Class for testing binary quantization */
@SuppressForbidden(reason = "Used for testing")
public class RecallTest {
  private static final double WRITER_BUFFER_MB = 1024;
  private static final String DATA = "dbpedia-entity-arctic";
  private static final String PATH = "/full/path/here/";
  private static final int NUM_FILES = 5;
  private static final int DIMS = 768;
  private static final VectorSimilarityFunction SIMILARITY_FUNCTION =
      VectorSimilarityFunction.DOT_PRODUCT;
  private static final int OFFSET_BYTES = Integer.BYTES;
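
  // Layout assumed for each .fvec record (it backs the size check in main()): an OFFSET_BYTES
  // prefix followed by DIMS floats; VectorsReaderWithOffset is expected to skip that prefix.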

  public static void index(
      Codec codec, Path indexPath, Path docsPath, Path[] fvecPaths, int[] numVectors)
      throws Exception {
    if (!indexPath.toFile().exists()) {
      indexPath.toFile().mkdirs();
    } else {
      for (Path fp : Files.walk(indexPath, 1).toList()) {
        fp.toFile().delete();
      }
    }
    IndexWriterConfig iwc = new IndexWriterConfig().setOpenMode(IndexWriterConfig.OpenMode.CREATE);
    iwc.setCodec(codec);
    iwc.setRAMBufferSizeMB(WRITER_BUFFER_MB);
    iwc.setUseCompoundFile(false);
    FieldType fieldType = KnnFloatVectorField.createFieldType(DIMS, SIMILARITY_FUNCTION);
    long indexTime = System.currentTimeMillis();
    try (FSDirectory dir = FSDirectory.open(indexPath);
        MMapDirectory directory = new MMapDirectory(docsPath);
        IndexWriter iw = new IndexWriter(dir, iwc)) {
      int docID = 0;
      for (int i = 0; i < fvecPaths.length; i++) {
        Path fvecPath = fvecPaths[i];
        try (IndexInput vectorInput = directory.openInput(fvecPath.toString(), IOContext.DEFAULT)) {
          RandomAccessVectorValues.Floats vectorValues =
              new VectorsReaderWithOffset(vectorInput, numVectors[i], DIMS, OFFSET_BYTES);
          for (int j = 0; j < numVectors[i]; j++) {
            Document doc = new Document();
            float[] vector = vectorValues.vectorValue(j);
            doc.add(new KnnFloatVectorField("knn", vector, fieldType));
            doc.add(new StoredField("id", docID));
            iw.addDocument(doc);
            docID++;
          }
        }
      }
      iw.flush();
      iw.commit();
      long startForceMerge = System.currentTimeMillis();
      iw.forceMerge(1);
      System.out.println(
          "Done force merge in: " + (System.currentTimeMillis() - startForceMerge) + "ms");
    }
    System.out.println("Done indexing in " + (System.currentTimeMillis() - indexTime) + "ms");
  }

  public static void search(Path indexPath, Path docsPath, Path queryPath, Path truekNN)
      throws Exception {
    try (FSDirectory dir = FSDirectory.open(indexPath);
        MMapDirectory directory = new MMapDirectory(docsPath);
        DirectoryReader reader = DirectoryReader.open(dir);
        IndexInput queryVectorInput = directory.openInput(queryPath.toString(), IOContext.DEFAULT);
        IndexInput truekNNInput = directory.openInput(truekNN.toString(), IOContext.DEFAULT)) {
      RandomAccessVectorValues.Floats vectorValues =
          new VectorsReaderWithOffset(queryVectorInput, 50, DIMS, OFFSET_BYTES);
      TrueKnnReader trueKnnReader = new TrueKnnReader(truekNNInput, 50, 100);
      long searchTimeSum = 0;
      long searchTimeMax = 0;
      long searchTimeMin = Long.MAX_VALUE;
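
      // For each of the 50 queries, retrieve 100 * 10 candidates once, then measure recall@100
      // when keeping only the top (100 * oversample) hits for each oversampling factor below.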
      float[] oversamples = new float[] {1f, 1.5f, 2f, 3f, 4f, 5f, 10f};
      int[] overlapSum = new int[oversamples.length];
      int[] overlapMax = new int[oversamples.length];
      int[] overlapMin = new int[oversamples.length];
      for (int k = 0; k < oversamples.length; k++) {
        overlapMin[k] = Integer.MAX_VALUE;
      }
      for (int i = 0; i < 50; i++) {
        long searchStart = System.currentTimeMillis();
        float[] queryVector = vectorValues.vectorValue(i);
        IndexSearcher searcher = new IndexSearcher(reader, null);
        // calculate the overlap between the true kNN and the retrieved kNN
        trueKnnReader.goToQuery(i);
        KnnFloatVectorQuery q = new KnnFloatVectorQuery("knn", queryVector, 100 * 10);
        TopDocs td = searcher.search(q, 100 * 10);
        long searchEnd = System.currentTimeMillis();
        long searchTime = searchEnd - searchStart;
        searchTimeSum += searchTime;
        searchTimeMax = Math.max(searchTimeMax, searchTime);
        searchTimeMin = Math.min(searchTimeMin, searchTime);
        int[] docIDs = new int[100 * 10];
        for (int j = 0; j < 100 * 10; j++) {
          Document doc = searcher.storedFields().document(td.scoreDocs[j].doc);
          docIDs[j] = Integer.parseInt(doc.get("id"));
        }
        List<Set<Integer>> docIDsList = new ArrayList<>(oversamples.length);
        for (int idx = 0; idx < oversamples.length; idx++) {
          Set<Integer> matchedDocIds = new HashSet<>();
          for (int j = 0; j < (int) (100 * oversamples[idx]); j++) {
            matchedDocIds.add(docIDs[j]);
          }
          docIDsList.add(matchedDocIds);
        }
        for (int idx = 0; idx < oversamples.length; idx++) {
          int overlap = 0;
          for (int j = 0; j < 100; j++) {
            if (docIDsList.get(idx).contains(trueKnnReader.ids[j])) {
              overlap++;
            }
          }
          overlapSum[idx] += overlap;
          overlapMax[idx] = Math.max(overlapMax[idx], overlap);
          overlapMin[idx] = Math.min(overlapMin[idx], overlap);
        }
      }
      System.out.println(
          "search time mean: " + (searchTimeSum / 50) + "ms max: " + searchTimeMax + "ms min: " + searchTimeMin + "ms");
      for (int idx = 0; idx < oversamples.length; idx++) {
        System.out.println(
            "oversample: " + oversamples[idx] + " min: " + overlapMin[idx] + " max: " + overlapMax[idx] + " mean: " + (overlapSum[idx] / 50));
      }
    }
  }
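
  // main() sweeps centroidCounts = {128, 255}: for each value it passes totalVectors / centroidCount
  // to Lucene912BinaryQuantizedVectorsFormat (presumably the per-cluster size used for quantization),
  // then indexes the corpus and measures recall against the true-kNN file.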

  public static void main(String[] args) throws Exception {
    int[] centroidCounts = new int[] {128, 255};
    final Path truekNN = Paths.get(PATH, "nearest-neighbours-" + DATA + ".bin");
    final Path[] fvecPaths = new Path[NUM_FILES];
    final int[] numVectors = new int[NUM_FILES];
    int totalVectorsSum = 0;
    for (int i = 0; i < NUM_FILES; i++) {
      fvecPaths[i] = Paths.get(PATH, "corpus-" + DATA + "-" + i + ".fvec");
      // get file size in bytes
      long fileSize = Files.size(fvecPaths[i]);
      // given the file size & offset bytes, calculate the number of vectors
      if (fileSize % (DIMS * Float.BYTES + OFFSET_BYTES) != 0) {
        throw new IllegalArgumentException("File size is not a multiple of vector size");
      }
      numVectors[i] = (int) (fileSize / (DIMS * Float.BYTES + OFFSET_BYTES));
      totalVectorsSum += numVectors[i];
    }
    final int totalVectors = totalVectorsSum;
    final Path docsPath = Paths.get(PATH);
    final Path queryPath = Paths.get(PATH, "queries-" + DATA + "-0.fvec");
    System.out.println("Data set: " + DATA + " Total vectors: " + totalVectors);
    for (int centroidCount : centroidCounts) {
      System.out.println("Indexing with " + centroidCount + " centroids\n\n");
      Path indexPath =
          Paths.get(
              "/Users/benjamintrent/Projects/lucene-bench/util/recall-" + DATA + "-" + centroidCount);
      // 50 queries with 100 nearest neighbors
      Codec codec =
          new Lucene912Codec() {
            @Override
            public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
              if (centroidCount == 1) {
                return new Lucene912BinaryQuantizedVectorsFormat();
              } else {
                return new Lucene912BinaryQuantizedVectorsFormat(totalVectors / centroidCount);
              }
            }
          };
      index(codec, indexPath, docsPath, fvecPaths, numVectors);
      search(indexPath, docsPath, queryPath, truekNN);
    }
  }
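
  /**
   * Reads the ground-truth nearest-neighbours file written by the Python converter below: per
   * query, one int query id followed by {@code knn} int doc ids and {@code knn} float scores.
   */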
  static class TrueKnnReader {
    final IndexInput in;
    final int numQueries;
    final int knn;
    int currentQuery = -1;
    final int[] ids;
    final float[] scores;
    final int byteSize;

    public TrueKnnReader(IndexInput in, int numQueries, int knn) {
      this.in = in;
      this.numQueries = numQueries;
      this.knn = knn;
      ids = new int[knn];
      scores = new float[knn];
      byteSize = knn * (Integer.BYTES + Float.BYTES) + Integer.BYTES;
    }

    public void goToQuery(int id) throws IOException {
      if (id == currentQuery) {
        return;
      }
      in.seek((long) id * byteSize);
      // query id
      int queryid = in.readInt();
      if (queryid != id) {
        throw new IllegalStateException("Expected query id " + id + " but got " + queryid);
      }
      for (int i = 0; i < knn; i++) {
        ids[i] = in.readInt();
      }
      in.readFloats(scores, 0, knn);
      // remember which query is loaded so repeated calls for the same id skip the re-read above
      currentQuery = id;
    }
  }
}
# read values from nearest-neighbours-dbpedia-entity-E5-small.json
# Of the form {"queryID": {"docID": score, "docID": score, ...}, ...}
# Example {"0": {"0": 0.5, "1": 0.4, ...}, ...}
# and transform it into flat byte format `queryid id id ... score score ...`
# (one int query id, then all doc ids as ints, then all scores as floats), matching TrueKnnReader above

import json
import struct
import sys
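
# Note: struct's 'i'/'f' use native byte order; on typical little-endian hardware this matches the
# little-endian ints/floats the Java TrueKnnReader reads via Lucene's IndexInput ('<i'/'<f' would make it explicit)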


def read_json_file(filename):
    with open(filename, 'r') as f:
        return json.load(f)


def write_byte_file(filename, data):
    with open(filename, 'wb') as f:
        for query_id, doc_scores in data.items():
            f.write(struct.pack('i', int(query_id)))
            # first write all doc_ids
            for doc_id in doc_scores.keys():
                f.write(struct.pack('i', int(doc_id)))
            # then write all scores
            for score in doc_scores.values():
                f.write(struct.pack('f', score))


def transform_nn(input_file, output_file):
    data = read_json_file(input_file)
    write_byte_file(output_file, data)


if __name__ == '__main__':
    transform_nn(
        "/full/path/here/nearest-neighbours-dbpedia-entity-arctic.json",
        "/full/path/here/nearest-neighbours-dbpedia-entity-arctic.bin")