@@ -351,7 +351,7 @@ public NeighborWithShortEdges(Neighbors neighbors, double shortEdges) {
}
}

-private static class NeighborIterator implements NodesIterator {
+public static class NeighborIterator implements NodesIterator {
private final NodeArray neighbors;
private int i;

@@ -374,5 +374,9 @@ public boolean hasNext() {
public int nextInt() {
return neighbors.getNode(i++);
}

public NodeArray merge(NodeArray other) {
return NodeArray.merge(neighbors, other);
}
}
}
@@ -20,6 +20,8 @@
import io.github.jbellis.jvector.disk.RandomAccessReader;
import io.github.jbellis.jvector.graph.ImmutableGraphIndex.NodeAtLevel;
import io.github.jbellis.jvector.graph.SearchResult.NodeScore;
import io.github.jbellis.jvector.graph.disk.NeighborsScoreCache;
import io.github.jbellis.jvector.graph.disk.OnDiskGraphIndex;
import io.github.jbellis.jvector.graph.diversity.VamanaDiversityProvider;
import io.github.jbellis.jvector.graph.similarity.BuildScoreProvider;
import io.github.jbellis.jvector.graph.similarity.ScoreFunction;
@@ -30,6 +32,7 @@
import io.github.jbellis.jvector.util.PhysicalCoreExecutor;
import io.github.jbellis.jvector.vector.VectorSimilarityFunction;
import io.github.jbellis.jvector.vector.types.VectorFloat;
import org.agrona.collections.IntArrayList;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@@ -324,7 +327,7 @@ public GraphIndexBuilder(BuildScoreProvider scoreProvider,
this.simdExecutor = simdExecutor;
this.parallelExecutor = parallelExecutor;

-this.graph = new OnHeapGraphIndex(maxDegrees, neighborOverflow, new VamanaDiversityProvider(scoreProvider, alpha));
+this.graph = new OnHeapGraphIndex(dimension, maxDegrees, neighborOverflow, new VamanaDiversityProvider(scoreProvider, alpha));
this.searchers = ExplicitThreadLocal.withInitial(() -> {
var gs = new GraphSearcher(graph);
gs.usePruning(false);
@@ -338,6 +341,50 @@ public GraphIndexBuilder(BuildScoreProvider scoreProvider,
this.rng = new Random(0);
}

/**
* Creates this builder from an existing {@link io.github.jbellis.jvector.graph.disk.OnDiskGraphIndex}. This is useful when a graph has just been loaded from disk:
* it is copied into an {@link OnHeapGraphIndex} and can then be mutated, avoiding the overhead of recreating the mutable {@link OnHeapGraphIndex} used by the new GraphIndexBuilder from scratch.
*
* @param buildScoreProvider the provider responsible for calculating build scores.
* @param immutableGraphIndex the on-disk representation of the graph index to be processed and converted.
* @param perLevelNeighborsScoreCache the cache containing pre-computed neighbor scores,
* organized by levels and nodes.
* @param beamWidth the width of the beam used during the graph building process.
* @param neighborOverflow the factor determining how many additional neighbors are allowed beyond the configured limit.
* @param alpha the weight factor for balancing score computations.
* @param addHierarchy whether to add hierarchical structures while building the graph.
* @param refineFinalGraph whether to perform a refinement step on the final graph structure.
* @param simdExecutor the ForkJoinPool executor used for SIMD tasks during graph building.
* @param parallelExecutor the ForkJoinPool executor used for general parallelization during graph building.
*
* @throws IOException if an I/O error occurs during the graph loading or conversion process.
*/
public GraphIndexBuilder(BuildScoreProvider buildScoreProvider, ImmutableGraphIndex immutableGraphIndex, NeighborsScoreCache perLevelNeighborsScoreCache, int beamWidth, float neighborOverflow, float alpha, boolean addHierarchy, boolean refineFinalGraph, ForkJoinPool simdExecutor, ForkJoinPool parallelExecutor) throws IOException {
this.scoreProvider = buildScoreProvider;
this.neighborOverflow = neighborOverflow;
this.dimension = immutableGraphIndex.getDimension();
this.alpha = alpha;
this.addHierarchy = addHierarchy;
this.refineFinalGraph = refineFinalGraph;
this.beamWidth = beamWidth;
this.simdExecutor = simdExecutor;
this.parallelExecutor = parallelExecutor;

this.graph = OnHeapGraphIndex.convertToHeap(immutableGraphIndex, perLevelNeighborsScoreCache, buildScoreProvider, neighborOverflow, alpha);

this.searchers = ExplicitThreadLocal.withInitial(() -> {
var gs = new GraphSearcher(graph);
gs.usePruning(false);
return gs;
});

// in scratch, we store candidates in reverse order: worse candidates are first
this.naturalScratch = ExplicitThreadLocal.withInitial(() -> new NodeArray(max(beamWidth, graph.maxDegree() + 1)));
this.concurrentScratch = ExplicitThreadLocal.withInitial(() -> new NodeArray(max(beamWidth, graph.maxDegree() + 1)));

this.rng = new Random(0);
}

// used by Cassandra when it fine-tunes the PQ codebook
public static GraphIndexBuilder rescore(GraphIndexBuilder other, BuildScoreProvider newProvider) {
var newBuilder = new GraphIndexBuilder(newProvider,
@@ -750,7 +797,7 @@ public synchronized long removeDeletedNodes() {
return memorySize;
}

-private void updateNeighbors(int level, int nodeId, NodeArray natural, NodeArray concurrent) {
+private void updateNeighbors(int layer, int nodeId, NodeArray natural, NodeArray concurrent) {
// if either natural or concurrent is empty, skip the merge
NodeArray toMerge;
if (concurrent.size() == 0) {
@@ -761,7 +808,7 @@ private void updateNeighbors(int level, int nodeId, NodeArray natural, NodeArray
toMerge = NodeArray.merge(natural, concurrent);
}
// toMerge may be approximate-scored, but insertDiverse will compute exact scores for the diverse ones
-graph.addEdges(level, nodeId, toMerge, neighborOverflow);
+graph.addEdges(layer, nodeId, toMerge, neighborOverflow);
}

private static NodeArray toScratchCandidates(NodeScore[] candidates, NodeArray scratch) {
@@ -876,6 +923,7 @@ private void loadV4(RandomAccessReader in) throws IOException {
graph.updateEntryNode(new NodeAtLevel(graph.getMaxLevel(), entryNode));
}


@Deprecated
private void loadV3(RandomAccessReader in, int size) throws IOException {
if (graph.size() != 0) {
@@ -909,4 +957,58 @@ private void loadV3(RandomAccessReader in, int size) throws IOException {
graph.updateEntryNode(new NodeAtLevel(0, entryNode));
graph.setDegrees(List.of(maxDegree));
}

/**
* Convenience method to build a new graph from an existing one, with the addition of new nodes.
* This is useful when we want to merge a new set of vectors into an existing graph that is already on disk.
*
* @param immutableGraphIndex the immutable (usually on-disk) representation of the graph index to be processed and converted.
* @param perLevelNeighborsScoreCache the cache containing pre-computed neighbor scores, organized by level and node.
* @param newVectors a superset RAVV containing both the new vectors to be added to the graph and the vectors already in it
* @param buildScoreProvider the provider responsible for calculating build scores.
* @param startingNodeOffset the offset in the newVectors RAVV where the new vectors start
* @param graphToRavvOrdMap a mapping from the old graph's node ids to the newVectors RAVV node ids
* @param beamWidth the width of the beam used during the graph building process.
* @param overflowRatio the ratio of extra neighbors to allow temporarily when inserting a node.
* @param alpha the weight factor for balancing score computations.
* @param addHierarchy whether to add hierarchical structures while building the graph.
*
* @return the in-memory representation of the graph index.
* @throws IOException if an I/O error occurs during the graph loading or conversion process.
*/
public static ImmutableGraphIndex buildAndMergeNewNodes(ImmutableGraphIndex immutableGraphIndex,
NeighborsScoreCache perLevelNeighborsScoreCache,
RandomAccessVectorValues newVectors,
BuildScoreProvider buildScoreProvider,
int startingNodeOffset,
int[] graphToRavvOrdMap,
int beamWidth,
float overflowRatio,
float alpha,
boolean addHierarchy) throws IOException {

try (GraphIndexBuilder builder = new GraphIndexBuilder(buildScoreProvider,
immutableGraphIndex,
perLevelNeighborsScoreCache,
beamWidth,
overflowRatio,
alpha,
addHierarchy,
true,
PhysicalCoreExecutor.pool(),
ForkJoinPool.commonPool())) {

var vv = newVectors.threadLocalSupplier();

// parallel graph construction over the merged document ids
PhysicalCoreExecutor.pool().submit(() -> IntStream.range(startingNodeOffset, newVectors.size()).parallel().forEach(ord -> {
builder.addGraphNode(ord, vv.get().getVector(graphToRavvOrdMap[ord]));
})).join();

builder.cleanup();
return builder.getGraph();
}
}
}
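
To make the intended workflow concrete, here is a minimal usage sketch of the new on-disk-to-builder path above (the constructor taking an ImmutableGraphIndex plus buildAndMergeNewNodes). It assumes the caller has already reopened the index as an ImmutableGraphIndex, reconstructed the matching NeighborsScoreCache, and prepared the superset RandomAccessVectorValues and the graph-to-RAVV ordinal map; the helper class and method names, any import paths not shown in this diff, and the beamWidth/overflow/alpha values are illustrative assumptions, not part of this change.

import io.github.jbellis.jvector.graph.GraphIndexBuilder;
import io.github.jbellis.jvector.graph.ImmutableGraphIndex;
import io.github.jbellis.jvector.graph.RandomAccessVectorValues;
import io.github.jbellis.jvector.graph.disk.NeighborsScoreCache;
import io.github.jbellis.jvector.graph.similarity.BuildScoreProvider;

import java.io.IOException;

class MergeSketch {
    // Illustrative helper (not part of this PR): wires pieces the caller already has
    // into the new buildAndMergeNewNodes entry point.
    static ImmutableGraphIndex mergeNewVectors(ImmutableGraphIndex existingIndex,
                                               NeighborsScoreCache scoreCache,
                                               RandomAccessVectorValues allVectors,
                                               BuildScoreProvider bsp,
                                               int firstNewOrdinal,
                                               int[] graphToRavvOrdMap) throws IOException {
        // Ordinals below firstNewOrdinal are assumed to already be in existingIndex;
        // ordinals from firstNewOrdinal onward are inserted into the heap copy in parallel.
        return GraphIndexBuilder.buildAndMergeNewNodes(existingIndex,
                                                       scoreCache,
                                                       allVectors,
                                                       bsp,
                                                       firstNewOrdinal,
                                                       graphToRavvOrdMap,
                                                       100,    // beamWidth (illustrative)
                                                       1.2f,   // overflowRatio (illustrative)
                                                       1.2f,   // alpha (illustrative)
                                                       true);  // addHierarchy
    }
}

The heavy lifting (copying the existing adjacency into OnHeapGraphIndex from the cached scores, then inserting only the new ordinals) happens inside buildAndMergeNewNodes, so callers never pay to rebuild the part of the graph that already exists.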
@@ -61,6 +61,9 @@ default int size() {
*/
NodesIterator getNodes(int level);

/** Return the dimension of the vectors in the graph */
int getDimension();

/**
* Return a View with which to navigate the graph. Views are not threadsafe -- that is,
* only one search at a time should be run per View.
@@ -25,6 +25,8 @@
package io.github.jbellis.jvector.graph;

import io.github.jbellis.jvector.graph.ConcurrentNeighborMap.Neighbors;
import io.github.jbellis.jvector.graph.disk.NeighborsScoreCache;
import io.github.jbellis.jvector.graph.disk.OnDiskGraphIndex;
import io.github.jbellis.jvector.graph.diversity.DiversityProvider;
import io.github.jbellis.jvector.util.Accountable;
import io.github.jbellis.jvector.util.BitSet;
@@ -33,14 +35,22 @@
import io.github.jbellis.jvector.util.RamUsageEstimator;
import io.github.jbellis.jvector.util.SparseIntMap;
import io.github.jbellis.jvector.util.ThreadSafeGrowableBitSet;
import io.github.jbellis.jvector.graph.diversity.VamanaDiversityProvider;
import io.github.jbellis.jvector.graph.similarity.BuildScoreProvider;
import io.github.jbellis.jvector.util.*;
import io.github.jbellis.jvector.vector.VectorSimilarityFunction;
import io.github.jbellis.jvector.vector.types.VectorFloat;
import org.agrona.collections.IntArrayList;

import java.io.DataOutput;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicIntegerArray;
import java.util.concurrent.atomic.AtomicReference;
@@ -67,6 +77,7 @@ public class OnHeapGraphIndex implements MutableGraphIndex {
private final CompletionTracker completions;
private final ThreadSafeGrowableBitSet deletedNodes = new ThreadSafeGrowableBitSet(0);
private final AtomicInteger maxNodeId = new AtomicInteger(-1);
private final int dimension;

// Maximum number of neighbors (edges) per node per layer
final List<Integer> maxDegrees;
@@ -76,9 +87,10 @@ public class OnHeapGraphIndex implements MutableGraphIndex {

private volatile boolean allMutationsCompleted = false;

-OnHeapGraphIndex(List<Integer> maxDegrees, double overflowRatio, DiversityProvider diversityProvider) {
+OnHeapGraphIndex(int dimension, List<Integer> maxDegrees, double overflowRatio, DiversityProvider diversityProvider) {
this.overflowRatio = overflowRatio;
this.maxDegrees = new IntArrayList();
this.dimension = dimension;
setDegrees(maxDegrees);
entryPoint = new AtomicReference<>();
this.completions = new CompletionTracker(1024);
@@ -225,6 +237,11 @@ public NodesIterator getNodes(int level) {
layers.get(level).size());
}

@Override
public int getDimension() {
return dimension;
}

@Override
public IntStream nodeStream(int level) {
var layer = layers.get(level);
@@ -293,6 +310,10 @@ public View getView() {
}
}

public FrozenView getFrozenView() {
return new FrozenView();
}

public ThreadSafeGrowableBitSet getDeletedNodes() {
return deletedNodes;
}
@@ -441,7 +462,7 @@ public boolean hasNext() {
}
}

-private class FrozenView implements View {
+public class FrozenView implements View {
@Override
public NodesIterator getNeighborsIterator(int level, int node) {
return OnHeapGraphIndex.this.getNeighborsIterator(level, node);
@@ -598,4 +619,69 @@ private void ensureCapacity(int node) {
}
}
}

/**
* Converts an OnDiskGraphIndex to an OnHeapGraphIndex by copying all nodes, their levels, and neighbors,
* along with other configuration details, from disk-based storage to heap-based storage.
*
* @param immutableGraphIndex the disk-based index to be converted
* @param perLevelNeighborsScoreCache the cache containing pre-computed neighbor scores,
* organized by levels and nodes.
* @param bsp the build score provider used to construct the diversity provider for the heap index
* @param overflowRatio the ratio of extra neighbors to allow temporarily when inserting a node (usually 1.2f)
* @param alpha the weight factor for balancing score computations (usually 1.2f)
* @return an OnHeapGraphIndex that is equivalent to the provided OnDiskGraphIndex but operates in heap memory
* @throws IOException if an I/O error occurs during the conversion process
*/
public static OnHeapGraphIndex convertToHeap(ImmutableGraphIndex immutableGraphIndex,
NeighborsScoreCache perLevelNeighborsScoreCache,
BuildScoreProvider bsp,
float overflowRatio,
float alpha) throws IOException {

// Create a new OnHeapGraphIndex with the appropriate configuration
List<Integer> maxDegrees = new ArrayList<>();
for (int level = 0; level <= immutableGraphIndex.getMaxLevel(); level++) {
maxDegrees.add(immutableGraphIndex.getDegree(level));
}

OnHeapGraphIndex heapIndex = new OnHeapGraphIndex(
immutableGraphIndex.getDimension(),
maxDegrees,
overflowRatio, // overflow ratio
new VamanaDiversityProvider(bsp, alpha) // diversity provider - can be null for basic usage
);

// Copy all nodes and their connections from disk to heap
try (var view = immutableGraphIndex.getView()) {
// Copy nodes level by level
for (int level = 0; level <= immutableGraphIndex.getMaxLevel(); level++) {
final NodesIterator nodesIterator = immutableGraphIndex.getNodes(level);
final Map<Integer, NodeArray> levelNeighborsScoreCache = perLevelNeighborsScoreCache.getNeighborsScoresInLevel(level);
if (levelNeighborsScoreCache == null) {
throw new IllegalStateException("No neighbors score cache found for level " + level);
}
if (nodesIterator.size() != levelNeighborsScoreCache.size()) {
throw new IllegalStateException("Neighbors score cache size mismatch for level " + level +
". Expected (currently in index): " + nodesIterator.size() + ", but got (in cache): " + levelNeighborsScoreCache.size());
}

while (nodesIterator.hasNext()) {
int nodeId = nodesIterator.next();

// Copy neighbors
final NodeArray neighbors = levelNeighborsScoreCache.get(nodeId).copy();

// Add the node with its neighbors
heapIndex.connectNode(level, nodeId, neighbors);
heapIndex.markComplete(new NodeAtLevel(level, nodeId));
}
}

// Set the entry point
heapIndex.updateEntryNode(view.entryNode());
}

return heapIndex;
}
}
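
For the conversion itself, here is a hedged sketch of thawing an immutable index back into a mutable heap graph via convertToHeap. The class, method, and variable names and the 1.2f values are illustrative assumptions, and import paths not shown in this diff are assumed.

import io.github.jbellis.jvector.graph.ImmutableGraphIndex;
import io.github.jbellis.jvector.graph.OnHeapGraphIndex;
import io.github.jbellis.jvector.graph.disk.NeighborsScoreCache;
import io.github.jbellis.jvector.graph.similarity.BuildScoreProvider;

import java.io.IOException;

class ConvertSketch {
    static OnHeapGraphIndex thaw(ImmutableGraphIndex frozen,
                                 NeighborsScoreCache scoreCache,
                                 BuildScoreProvider bsp) throws IOException {
        // Copies each level's adjacency lists onto the heap, taking neighbor scores from
        // the cache rather than recomputing distances during the copy.
        OnHeapGraphIndex heap = OnHeapGraphIndex.convertToHeap(frozen, scoreCache, bsp, 1.2f, 1.2f);
        // Entry point and level structure are preserved by the copy.
        assert heap.getMaxLevel() == frozen.getMaxLevel();
        assert heap.size() == frozen.size();
        return heap;
    }
}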