Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions lucene/CHANGES.txt
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,9 @@ Optimizations
* GITHUB#15868: Optimize DisjunctionMaxBulkScorer by reusing the inner LeafCollector across
sub-scorers and resetting windowScores inline during replay instead of Arrays.fill. (Prithvi S)

* GITHUB#15732: Prevent writing vectors twice during merging HNSW graphs by allowing deferred work
  to run after the merge of vectors is finished. (Ignacio Vera)

Bug Fixes
---------------------
* GITHUB#15754: Fix HTMLStripCharFilter to prevent tags from incorrectly consuming subsequent
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer;
import org.apache.lucene.util.quantization.OptimizedScalarQuantizer;
import org.apache.lucene.util.quantization.OptimizedScalarQuantizer.QuantizationResult;

Expand All @@ -54,7 +55,7 @@ public Lucene102BinaryFlatVectorsScorer(FlatVectorsScorer nonQuantizedDelegate)
public RandomVectorScorerSupplier getRandomVectorScorerSupplier(
VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues)
throws IOException {
throw new UnsupportedOperationException("Old codecs may only be used for reading");
return nonQuantizedDelegate.getRandomVectorScorerSupplier(similarityFunction, vectorValues);
}

@Override
Expand Down Expand Up @@ -98,6 +99,62 @@ public String toString() {
return "Lucene102BinaryFlatVectorsScorer(nonQuantizedDelegate=" + nonQuantizedDelegate + ")";
}

/**
 * Returns a supplier of scorers that score pre-quantized (4 bit) query vectors against the
 * binarized index vectors. Used during merge, where the query-quantized vectors have been
 * written to a temporary file (see
 * {@code Lucene102BinaryQuantizedVectorsReader#getRandomVectorScorerSupplierForMerge}).
 *
 * @param similarityFunction similarity used to combine the quantized score with corrections
 * @param scoringVectors off-heap access to the quantized query vectors and corrective terms
 * @param targetVectors binarized index-side vectors to score against
 */
RandomVectorScorerSupplier getRandomVectorScorerSupplier(
    VectorSimilarityFunction similarityFunction,
    Lucene102BinaryQuantizedVectorsReader.OffHeapBinarizedQueryVectorValues scoringVectors,
    BinarizedByteVectorValues targetVectors) {
  return new BinarizedRandomVectorScorerSupplier(
      scoringVectors, targetVectors, similarityFunction);
}

/** Vector scorer supplier over binarized vector values. */
static class BinarizedRandomVectorScorerSupplier implements RandomVectorScorerSupplier {
  private final Lucene102BinaryQuantizedVectorsReader.OffHeapBinarizedQueryVectorValues
      queryVectors;
  private final BinarizedByteVectorValues targetVectors;
  private final VectorSimilarityFunction similarityFunction;

  BinarizedRandomVectorScorerSupplier(
      Lucene102BinaryQuantizedVectorsReader.OffHeapBinarizedQueryVectorValues queryVectors,
      BinarizedByteVectorValues targetVectors,
      VectorSimilarityFunction similarityFunction) {
    this.queryVectors = queryVectors;
    this.targetVectors = targetVectors;
    this.similarityFunction = similarityFunction;
  }

  @Override
  public UpdateableRandomVectorScorer scorer() throws IOException {
    // Each scorer gets its own clones so concurrent scorers never share read positions.
    final Lucene102BinaryQuantizedVectorsReader.OffHeapBinarizedQueryVectorValues queries =
        queryVectors.copy();
    final BinarizedByteVectorValues targets = targetVectors.copy();
    return new UpdateableRandomVectorScorer.AbstractUpdateableRandomVectorScorer(targets) {
      // State of the currently selected query ordinal; null until setScoringOrdinal is called.
      private QuantizationResult corrections = null;
      private byte[] queryVector = null;

      @Override
      public void setScoringOrdinal(int node) throws IOException {
        corrections = queries.getCorrectiveTerms(node);
        queryVector = queries.vectorValue(node);
      }

      @Override
      public float score(int node) throws IOException {
        if (queryVector == null || corrections == null) {
          throw new IllegalStateException("setScoringOrdinal was not called");
        }
        return quantizedScore(queryVector, corrections, targets, node, similarityFunction);
      }
    };
  }

  @Override
  public RandomVectorScorerSupplier copy() throws IOException {
    return new BinarizedRandomVectorScorerSupplier(
        queryVectors.copy(), targetVectors.copy(), similarityFunction);
  }
}

static float quantizedScore(
byte[] quantizedQuery,
QuantizationResult queryCorrections,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,17 @@
*/
package org.apache.lucene.backward_codecs.lucene102;

import static org.apache.lucene.backward_codecs.lucene102.Lucene102BinaryQuantizedVectorsFormat.QUERY_BITS;
import static org.apache.lucene.backward_codecs.lucene102.Lucene102BinaryQuantizedVectorsFormat.VECTOR_DATA_EXTENSION;
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.readSimilarityFunction;
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.readVectorEncoding;
import static org.apache.lucene.index.VectorSimilarityFunction.COSINE;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import static org.apache.lucene.util.quantization.OptimizedScalarQuantizer.discretize;
import static org.apache.lucene.util.quantization.OptimizedScalarQuantizer.transposeHalfByte;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
Expand All @@ -31,11 +36,14 @@
import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.CorruptIndexException;
import org.apache.lucene.index.DocsWithFieldSet;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FieldInfos;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.KnnVectorValues;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.AcceptDocs;
Expand All @@ -47,15 +55,23 @@
import org.apache.lucene.store.FileTypeHint;
import org.apache.lucene.store.IOContext;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.IndexOutput;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.RamUsageEstimator;
import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.hnsw.CloseableRandomVectorScorerSupplier;
import org.apache.lucene.util.hnsw.OrdinalTranslatedKnnCollector;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
import org.apache.lucene.util.quantization.BaseQuantizedByteVectorValues;
import org.apache.lucene.util.quantization.OptimizedScalarQuantizer;
import org.apache.lucene.util.quantization.QuantizedVectorsReader;
import org.apache.lucene.util.quantization.ScalarQuantizer;

/** Reader for binary quantized vectors in the Lucene 10.2 format. */
public class Lucene102BinaryQuantizedVectorsReader extends FlatVectorsReader {
public class Lucene102BinaryQuantizedVectorsReader extends FlatVectorsReader
implements QuantizedVectorsReader {

private static final long SHALLOW_SIZE =
RamUsageEstimator.shallowSizeOfInstance(Lucene102BinaryQuantizedVectorsReader.class);
Expand Down Expand Up @@ -144,7 +160,7 @@ static void validateFieldEntry(FieldInfo info, FieldEntry fieldEntry) {
+ fieldEntry.dimension);
}

int binaryDims = discretize(dimension, 64) / 8;
int binaryDims = discretize(dimension, Long.SIZE) / Byte.SIZE;
long numQuantizedVectorBytes =
Math.multiplyExact((binaryDims + (Float.BYTES * 3) + Short.BYTES), (long) fieldEntry.size);
if (numQuantizedVectorBytes != fieldEntry.vectorDataLength) {
Expand Down Expand Up @@ -196,7 +212,7 @@ public void checkIntegrity() throws IOException {
}

@Override
public FloatVectorValues getFloatVectorValues(String field) throws IOException {
public BinarizedVectorValues getFloatVectorValues(String field) throws IOException {
FieldEntry fi = fields.get(field);
if (fi == null) {
return null;
Expand Down Expand Up @@ -342,6 +358,107 @@ private FieldEntry readField(IndexInput input, FieldInfo info) throws IOExceptio
return FieldEntry.create(input, vectorEncoding, info.getVectorSimilarityFunction());
}

@Override
public BaseQuantizedByteVectorValues getQuantizedVectorValues(String fieldName) {
  // This reader only exposes binarized vectors, not scalar-quantized byte vectors.
  // NOTE(review): assumes callers of QuantizedVectorsReader tolerate a null return — confirm.
  return null;
}

@Override
public ScalarQuantizer getQuantizationState(String fieldName) {
  // No ScalarQuantizer is involved in the binary quantization scheme of this format.
  // NOTE(review): assumes callers of QuantizedVectorsReader tolerate a null return — confirm.
  return null;
}

@Override
public CloseableRandomVectorScorerSupplier getRandomVectorScorerSupplierForMerge(
Comment thread
iverase marked this conversation as resolved.
FieldInfo fieldInfo, SegmentWriteState segmentWriteState) throws IOException {
assert fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32);
OptimizedScalarQuantizer quantizer =
new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction());
String tempScoreQuantizedVectorName = null;
float[] centroid = getCentroid(fieldInfo.name);
BinarizedVectorValues vectorValues = getFloatVectorValues(fieldInfo.name);
FloatVectorValues floatVectorValues = vectorValues.rawVectorValues;
if (fieldInfo.getVectorSimilarityFunction() == COSINE) {
floatVectorValues = new NormalizedFloatVectorValues(floatVectorValues);
}

DocsWithFieldSet docsWithField;
try (IndexOutput tempScoreQuantizedVector =
segmentWriteState.directory.createTempOutput(
segmentWriteState.segmentInfo.name, "queries", segmentWriteState.context)) {
tempScoreQuantizedVectorName = tempScoreQuantizedVector.getName();
docsWithField =
writeBinarizedQueryData(tempScoreQuantizedVector, floatVectorValues, centroid, quantizer);
CodecUtil.writeFooter(tempScoreQuantizedVector);
} catch (Throwable t) {
if (tempScoreQuantizedVectorName != null) {
IOUtils.deleteFilesSuppressingExceptions(
t, segmentWriteState.directory, tempScoreQuantizedVectorName);
}
throw t;
}
IndexInput quantizedScoreDataInput =
segmentWriteState.directory.openInput(
tempScoreQuantizedVectorName, segmentWriteState.context);
try {
RandomVectorScorerSupplier scorerSupplier =
vectorScorer.getRandomVectorScorerSupplier(
fieldInfo.getVectorSimilarityFunction(),
new OffHeapBinarizedQueryVectorValues(
quantizedScoreDataInput,
fieldInfo.getVectorDimension(),
docsWithField.cardinality()),
vectorValues.quantizedVectorValues);
final String finalTempScoreQuantizedVectorName = tempScoreQuantizedVectorName;
return CloseableRandomVectorScorerSupplier.create(
scorerSupplier,
vectorValues.size(),
() -> {
IOUtils.close(quantizedScoreDataInput);
IOUtils.deleteFilesIgnoringExceptions(
segmentWriteState.directory, finalTempScoreQuantizedVectorName);
});
} catch (Throwable t) {
IOUtils.closeWhileSuppressingExceptions(t, quantizedScoreDataInput);
IOUtils.deleteFilesSuppressingExceptions(
t, segmentWriteState.directory, tempScoreQuantizedVectorName);
throw t;
}
}

/**
 * Quantizes every raw vector with the 4-bit query scheme and writes, per vector: the packed
 * quantized bytes, three corrective floats (lower interval, upper interval, additional
 * correction) as int bits, and the quantized component sum as an unsigned short.
 *
 * @param binarizedQueryData output the per-vector records are appended to
 * @param floatVectorValues raw (possibly normalized) vectors to quantize
 * @param centroid centroid the quantizer centers vectors around
 * @param binaryQuantizer quantizer producing the 4-bit representation and corrections
 * @return the set of docs that have a vector
 * @throws IOException on write failure
 */
static DocsWithFieldSet writeBinarizedQueryData(
    IndexOutput binarizedQueryData,
    FloatVectorValues floatVectorValues,
    float[] centroid,
    OptimizedScalarQuantizer binaryQuantizer)
    throws IOException {
  // Long.SIZE / Byte.SIZE instead of bare 64 / 8, for consistency with validateFieldEntry.
  int discretizedDimension = discretize(floatVectorValues.dimension(), Long.SIZE);
  DocsWithFieldSet docsWithField = new DocsWithFieldSet();
  byte[] quantizationScratch = new byte[floatVectorValues.dimension()];
  byte[] toQuery = new byte[(discretizedDimension / Byte.SIZE) * QUERY_BITS];
  KnnVectorValues.DocIndexIterator iterator = floatVectorValues.iterator();
  for (int docV = iterator.nextDoc(); docV != NO_MORE_DOCS; docV = iterator.nextDoc()) {
    // Quantize the raw vector to QUERY_BITS bits per component.
    OptimizedScalarQuantizer.QuantizationResult r =
        binaryQuantizer.scalarQuantize(
            floatVectorValues.vectorValue(iterator.index()),
            quantizationScratch,
            QUERY_BITS,
            centroid);
    docsWithField.add(docV);

    // pack and store the 4bit query vector followed by its corrective terms
    transposeHalfByte(quantizationScratch, toQuery);
    binarizedQueryData.writeBytes(toQuery, toQuery.length);
    binarizedQueryData.writeInt(Float.floatToIntBits(r.lowerInterval()));
    binarizedQueryData.writeInt(Float.floatToIntBits(r.upperInterval()));
    binarizedQueryData.writeInt(Float.floatToIntBits(r.additionalCorrection()));
    assert r.quantizedComponentSum() >= 0 && r.quantizedComponentSum() <= 0xffff;
    binarizedQueryData.writeShort((short) r.quantizedComponentSum());
  }
  return docsWithField;
}

private record FieldEntry(
VectorSimilarityFunction similarityFunction,
VectorEncoding vectorEncoding,
Expand Down Expand Up @@ -448,4 +565,115 @@ BinarizedByteVectorValues getQuantizedVectorValues() throws IOException {
return quantizedVectorValues;
}
}

// Random access over the 4-bit quantized query vectors and their corrective terms, read
// off-heap from the temp file written by writeBinarizedQueryData.
// When accessing vectorValue method, targetOrd here means a row ordinal.
static class OffHeapBinarizedQueryVectorValues {
  private final IndexInput slice;
  private final int dimension;
  private final int size;
  protected final byte[] binaryValue;
  protected final ByteBuffer byteBuffer;
  private final int byteSize;
  protected final float[] correctiveValues;
  // Ordinal whose record is currently cached in binaryValue/correctiveValues; -1 = none.
  private int lastOrd = -1;
  private int quantizedComponentSum;

  OffHeapBinarizedQueryVectorValues(IndexInput data, int dimension, int size) {
    this.slice = data;
    this.dimension = dimension;
    this.size = size;
    // 4x the quantized binary dimensions
    int binaryDimensions = (discretize(dimension, Long.SIZE) / Byte.SIZE) * QUERY_BITS;
    this.byteBuffer = ByteBuffer.allocate(binaryDimensions);
    this.binaryValue = byteBuffer.array();
    // lower interval, upper interval, additional correction; the quantized component sum is
    // stored separately (as a short on disk, see the byteSize computation below)
    this.correctiveValues = new float[3];
    this.byteSize = binaryDimensions + Float.BYTES * 3 + Short.BYTES;
  }

  // Returns the corrective terms for targetOrd, reading its record from disk if it is not
  // the cached one.
  public OptimizedScalarQuantizer.QuantizationResult getCorrectiveTerms(int targetOrd)
      throws IOException {
    if (lastOrd == targetOrd) {
      return new OptimizedScalarQuantizer.QuantizationResult(
          correctiveValues[0], correctiveValues[1], correctiveValues[2], quantizedComponentSum);
    }
    // vectorValue seeks, reads and caches the full record (vector + corrections) for targetOrd.
    vectorValue(targetOrd);
    return new OptimizedScalarQuantizer.QuantizationResult(
        correctiveValues[0], correctiveValues[1], correctiveValues[2], quantizedComponentSum);
  }

  public int size() {
    return size;
  }

  public int quantizedLength() {
    return binaryValue.length;
  }

  public int dimension() {
    return dimension;
  }

  // Clone for use from another consumer; the clone starts with an empty cache (lastOrd == -1).
  public OffHeapBinarizedQueryVectorValues copy() throws IOException {
    return new OffHeapBinarizedQueryVectorValues(slice.clone(), dimension, size);
  }

  public IndexInput getSlice() {
    return slice;
  }

  // Returns the packed quantized bytes for targetOrd; the returned array is reused across
  // calls. Also refreshes correctiveValues and quantizedComponentSum as a side effect.
  public byte[] vectorValue(int targetOrd) throws IOException {
    if (lastOrd == targetOrd) {
      return binaryValue;
    }
    // Fixed-width records: vector bytes, then 3 floats, then an unsigned short.
    slice.seek((long) targetOrd * byteSize);
    slice.readBytes(binaryValue, 0, binaryValue.length);
    slice.readFloats(correctiveValues, 0, 3);
    quantizedComponentSum = Short.toUnsignedInt(slice.readShort());
    lastOrd = targetOrd;
    return binaryValue;
  }
}

/**
 * View over a delegate {@link FloatVectorValues} that l2-normalizes every vector on read.
 * The array returned by {@link #vectorValue(int)} is reused across calls.
 */
static final class NormalizedFloatVectorValues extends FloatVectorValues {
  private final FloatVectorValues delegate;
  private final float[] scratch;

  NormalizedFloatVectorValues(FloatVectorValues values) {
    this.delegate = values;
    this.scratch = new float[values.dimension()];
  }

  @Override
  public int dimension() {
    return delegate.dimension();
  }

  @Override
  public int size() {
    return delegate.size();
  }

  @Override
  public int ordToDoc(int ord) {
    return delegate.ordToDoc(ord);
  }

  @Override
  public float[] vectorValue(int ord) throws IOException {
    // Copy into the scratch buffer, then normalize in place.
    float[] raw = delegate.vectorValue(ord);
    System.arraycopy(raw, 0, scratch, 0, scratch.length);
    VectorUtil.l2normalize(scratch);
    return scratch;
  }

  @Override
  public DocIndexIterator iterator() {
    return delegate.iterator();
  }

  @Override
  public NormalizedFloatVectorValues copy() throws IOException {
    return new NormalizedFloatVectorValues(delegate.copy());
  }
}
}
Loading
Loading