@benwtrent
Created September 18, 2024 19:06
Tools for reading and testing vector files: a Lucene recall benchmark for the sandbox binary-quantized vectors format, plus a Python script that converts ground-truth nearest neighbors from JSON into the flat binary layout the benchmark reads.
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.sandbox.rbq;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import org.apache.lucene.codecs.Codec;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.lucene912.Lucene912BinaryQuantizedVectorsFormat;
import org.apache.lucene.codecs.lucene912.Lucene912Codec;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.FieldType;
import org.apache.lucene.document.KnnFloatVectorField;
import org.apache.lucene.document.StoredField;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.KnnFloatVectorQuery;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.store.FSDirectory;
import org.apache.lucene.store.IOContext;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.MMapDirectory;
import org.apache.lucene.util.SuppressForbidden;
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
/** Class for testing binary quantization */
@SuppressForbidden(reason = "Used for testing")
public class RecallTest {
  private static final double WRITER_BUFFER_MB = 1024;
  private static final String DATA = "dbpedia-entity-arctic";
  private static final String PATH = "/full/path/here/";
  private static final int NUM_FILES = 5;
  private static final int DIMS = 768;
  private static final VectorSimilarityFunction SIMILARITY_FUNCTION =
      VectorSimilarityFunction.DOT_PRODUCT;
  private static final int OFFSET_BYTES = Integer.BYTES;

  public static void index(
      Codec codec, Path indexPath, Path docsPath, Path[] fvecPaths, int[] numVectors)
      throws Exception {
    if (!indexPath.toFile().exists()) {
      indexPath.toFile().mkdirs();
    } else {
      // wipe any files from a previous run; close the walk stream to avoid a resource leak
      try (var walk = Files.walk(indexPath, 1)) {
        for (Path fp : walk.toList()) {
          fp.toFile().delete();
        }
      }
    }
    IndexWriterConfig iwc = new IndexWriterConfig().setOpenMode(IndexWriterConfig.OpenMode.CREATE);
    iwc.setCodec(codec);
    iwc.setRAMBufferSizeMB(WRITER_BUFFER_MB);
    iwc.setUseCompoundFile(false);
    FieldType fieldType = KnnFloatVectorField.createFieldType(DIMS, SIMILARITY_FUNCTION);
    long indexStart = System.currentTimeMillis();
    try (FSDirectory dir = FSDirectory.open(indexPath);
        MMapDirectory directory = new MMapDirectory(docsPath);
        IndexWriter iw = new IndexWriter(dir, iwc)) {
      int docID = 0;
      for (int i = 0; i < fvecPaths.length; i++) {
        Path fvecPath = fvecPaths[i];
        try (IndexInput vectorInput =
            directory.openInput(fvecPath.toString(), IOContext.DEFAULT)) {
          RandomAccessVectorValues.Floats vectorValues =
              new VectorsReaderWithOffset(vectorInput, numVectors[i], DIMS, OFFSET_BYTES);
          for (int j = 0; j < numVectors[i]; j++) {
            Document doc = new Document();
            float[] vector = vectorValues.vectorValue(j);
            doc.add(new KnnFloatVectorField("knn", vector, fieldType));
            doc.add(new StoredField("id", docID));
            iw.addDocument(doc);
            docID++;
          }
        }
      }
      iw.flush();
      iw.commit();
      long startForceMerge = System.currentTimeMillis();
      iw.forceMerge(1);
      System.out.println(
          "Done force merge in: " + (System.currentTimeMillis() - startForceMerge) + "ms");
    }
    System.out.println("Done indexing in " + (System.currentTimeMillis() - indexStart) + "ms");
  }

  public static void search(Path indexPath, Path docsPath, Path queryPath, Path truekNN)
      throws Exception {
    try (FSDirectory dir = FSDirectory.open(indexPath);
        MMapDirectory directory = new MMapDirectory(docsPath);
        DirectoryReader reader = DirectoryReader.open(dir);
        IndexInput queryVectorInput = directory.openInput(queryPath.toString(), IOContext.DEFAULT);
        IndexInput truekNNInput = directory.openInput(truekNN.toString(), IOContext.DEFAULT)) {
      RandomAccessVectorValues.Floats vectorValues =
          new VectorsReaderWithOffset(queryVectorInput, 50, DIMS, OFFSET_BYTES);
      TrueKnnReader trueKnnReader = new TrueKnnReader(truekNNInput, 50, 100);
      long searchTimeSum = 0;
      long searchTimeMax = 0;
      long searchTimeMin = Long.MAX_VALUE;
      float[] oversamples = new float[] {1f, 1.5f, 2f, 3f, 4f, 5f, 10f};
      int[] overlapSum = new int[oversamples.length];
      int[] overlapMax = new int[oversamples.length];
      int[] overlapMin = new int[oversamples.length];
      for (int k = 0; k < oversamples.length; k++) {
        overlapMin[k] = Integer.MAX_VALUE;
      }
      for (int i = 0; i < 50; i++) {
        long searchStart = System.currentTimeMillis();
        float[] queryVector = vectorValues.vectorValue(i);
        IndexSearcher searcher = new IndexSearcher(reader, null);
        // load this query's true nearest neighbors
        trueKnnReader.goToQuery(i);
        // retrieve 10x the target k so every oversample ratio can be evaluated from one search
        KnnFloatVectorQuery q = new KnnFloatVectorQuery("knn", queryVector, 100 * 10);
        TopDocs td = searcher.search(q, 100 * 10);
        long searchEnd = System.currentTimeMillis();
        long searchTime = searchEnd - searchStart;
        searchTimeSum += searchTime;
        searchTimeMax = Math.max(searchTimeMax, searchTime);
        searchTimeMin = Math.min(searchTimeMin, searchTime);
        int[] docIDs = new int[100 * 10];
        for (int j = 0; j < 100 * 10; j++) {
          Document doc = searcher.storedFields().document(td.scoreDocs[j].doc);
          docIDs[j] = Integer.parseInt(doc.get("id"));
        }
        // bucket the retrieved ids by oversample ratio
        List<Set<Integer>> docIDsList = new ArrayList<>(oversamples.length);
        for (int idx = 0; idx < oversamples.length; idx++) {
          Set<Integer> matchedDocIds = new HashSet<>();
          for (int j = 0; j < (int) (100 * oversamples[idx]); j++) {
            matchedDocIds.add(docIDs[j]);
          }
          docIDsList.add(matchedDocIds);
        }
        // calculate the overlap between the true kNN and the retrieved kNN
        for (int idx = 0; idx < oversamples.length; idx++) {
          int overlap = 0;
          for (int j = 0; j < 100; j++) {
            if (docIDsList.get(idx).contains(trueKnnReader.ids[j])) {
              overlap++;
            }
          }
          overlapSum[idx] += overlap;
          overlapMax[idx] = Math.max(overlapMax[idx], overlap);
          overlapMin[idx] = Math.min(overlapMin[idx], overlap);
        }
      }
      System.out.println(
          "search time mean: "
              + (searchTimeSum / 50)
              + "ms max: "
              + searchTimeMax
              + "ms min: "
              + searchTimeMin
              + "ms");
      for (int idx = 0; idx < oversamples.length; idx++) {
        System.out.println(
            "oversample: "
                + oversamples[idx]
                + " min: "
                + overlapMin[idx]
                + " max: "
                + overlapMax[idx]
                + " mean: "
                + (overlapSum[idx] / 50));
      }
    }
  }

  public static void main(String[] args) throws Exception {
    int[] centroidCounts = new int[] {128, 255};
    final Path truekNN = Paths.get(PATH, "nearest-neighbours-" + DATA + ".bin");
    final Path[] fvecPaths = new Path[NUM_FILES];
    final int[] numVectors = new int[NUM_FILES];
    int totalVectorsSum = 0;
    for (int i = 0; i < NUM_FILES; i++) {
      fvecPaths[i] = Paths.get(PATH, "corpus-" + DATA + "-" + i + ".fvec");
      // each record is OFFSET_BYTES of header followed by DIMS floats,
      // so the file size determines the number of vectors
      long fileSize = Files.size(fvecPaths[i]);
      if (fileSize % (DIMS * Float.BYTES + OFFSET_BYTES) != 0) {
        throw new IllegalArgumentException("File size is not a multiple of vector size");
      }
      numVectors[i] = (int) (fileSize / (DIMS * Float.BYTES + OFFSET_BYTES));
      totalVectorsSum += numVectors[i];
    }
    final int totalVectors = totalVectorsSum;
    final Path docsPath = Paths.get(PATH);
    final Path queryPath = Paths.get(PATH, "queries-" + DATA + "-0.fvec");
    System.out.println("Data set: " + DATA + " Total vectors: " + totalVectors);
    for (int centroidCount : centroidCounts) {
      System.out.println("Indexing with " + centroidCount + " centroids\n\n");
      Path indexPath =
          Paths.get(
              "/Users/benjamintrent/Projects/lucene-bench/util/recall-"
                  + DATA + "-" + centroidCount);
      // 50 queries with 100 nearest neighbors
      Codec codec =
          new Lucene912Codec() {
            @Override
            public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
              if (centroidCount == 1) {
                return new Lucene912BinaryQuantizedVectorsFormat();
              } else {
                return new Lucene912BinaryQuantizedVectorsFormat(totalVectors / centroidCount);
              }
            }
          };
      index(codec, indexPath, docsPath, fvecPaths, numVectors);
      search(indexPath, docsPath, queryPath, truekNN);
    }
  }

  /**
   * Reads precomputed true nearest neighbors from the flat binary format written by the
   * companion Python script: per query, an int query id, then {@code knn} int doc ids,
   * then {@code knn} float scores.
   */
  static class TrueKnnReader {
    final IndexInput in;
    final int numQueries;
    final int knn;
    int currentQuery = -1;
    final int[] ids;
    final float[] scores;
    final int byteSize;

    public TrueKnnReader(IndexInput in, int numQueries, int knn) {
      this.in = in;
      this.numQueries = numQueries;
      this.knn = knn;
      ids = new int[knn];
      scores = new float[knn];
      byteSize = knn * (Integer.BYTES + Float.BYTES) + Integer.BYTES;
    }

    public void goToQuery(int id) throws IOException {
      if (id == currentQuery) {
        return;
      }
      in.seek((long) id * byteSize);
      // query id
      int queryid = in.readInt();
      if (queryid != id) {
        throw new IllegalStateException("Expected query id " + id + " but got " + queryid);
      }
      for (int i = 0; i < knn; i++) {
        ids[i] = in.readInt();
      }
      in.readFloats(scores, 0, knn);
      // remember the loaded query so repeated calls for the same id are no-ops
      currentQuery = id;
    }
  }
}
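
The benchmark above depends on a VectorsReaderWithOffset helper that is not included in this gist. Purely as an illustration, here is a minimal sketch of what it could look like, assuming each .fvec record is an OFFSET_BYTES integer header followed by DIMS floats; the class name and constructor arguments come from the call sites above, everything else is an assumption.

// Hypothetical sketch of the missing VectorsReaderWithOffset helper.
// Assumes records of the form [int header][dims floats], read via IndexInput
// (little-endian in Lucene 9+), implementing the RandomAccessVectorValues.Floats
// methods (size, dimension, vectorValue, copy).
class VectorsReaderWithOffset implements RandomAccessVectorValues.Floats {
  private final IndexInput in;
  private final int size;
  private final int dims;
  private final int offsetBytes; // per-record header bytes to skip
  private final long recordBytes;
  private final float[] scratch; // reused buffer; callers must not hold onto it

  VectorsReaderWithOffset(IndexInput in, int size, int dims, int offsetBytes) {
    this.in = in;
    this.size = size;
    this.dims = dims;
    this.offsetBytes = offsetBytes;
    this.recordBytes = (long) dims * Float.BYTES + offsetBytes;
    this.scratch = new float[dims];
  }

  @Override
  public int size() {
    return size;
  }

  @Override
  public int dimension() {
    return dims;
  }

  @Override
  public float[] vectorValue(int ord) throws IOException {
    in.seek(ord * recordBytes + offsetBytes); // skip the record header
    in.readFloats(scratch, 0, dims);
    return scratch;
  }

  @Override
  public RandomAccessVectorValues.Floats copy() throws IOException {
    return new VectorsReaderWithOffset(in.clone(), size, dims, offsetBytes);
  }
}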
# read values from a nearest-neighbours-<dataset>.json file
# Of the form {"queryID": {"docID": score, "docID": score, ...}, ...}
# Example {"0": {"0": 0.5, "1": 0.4, ...}, ...}
# and transform it into the flat byte format `queryid id id ... score score ...`
import json
import struct


def read_json_file(filename):
    with open(filename, 'r') as f:
        return json.load(f)


def write_byte_file(filename, data):
    with open(filename, 'wb') as f:
        for query_id, doc_scores in data.items():
            f.write(struct.pack('i', int(query_id)))
            # first write all doc ids
            for doc_id in doc_scores.keys():
                f.write(struct.pack('i', int(doc_id)))
            # then write all scores
            for score in doc_scores.values():
                f.write(struct.pack('f', score))


def transform_nn(input_file, output_file):
    data = read_json_file(input_file)
    write_byte_file(output_file, data)


if __name__ == '__main__':
    transform_nn(
        "/full/path/here/nearest-neighbours-dbpedia-entity-arctic.json",
        "/full/path/here/nearest-neighbours-dbpedia-entity-arctic.bin")