Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions diskann-disk/src/search/pq/pq_scratch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,12 @@ impl PQScratch {
self.query_scratch.copy_from_slice(&query[..dim]);
Ok(())
}

/// Return the largest number of PQ vectors whose distances can be computed using this
/// scratch data structure.
///
/// The value is the current length of the `aligned_dist_scratch` buffer, so it reflects
/// however that buffer was sized when the scratch was constructed.
///
/// NOTE(review): callers that batch PQ distance computations presumably must not pass
/// more than this many vectors at once; confirm against the distance routine that
/// consumes `aligned_dist_scratch`.
pub(crate) fn max_vectors(&self) -> usize {
self.aligned_dist_scratch.len()
}
}

#[cfg(test)]
Expand Down Expand Up @@ -112,6 +118,8 @@ mod tests {
0
);

assert_eq!(pq_scratch.max_vectors(), graph_degree);

// Test set() method
let query: Vec<f32> = (1..=dim).map(|i| i as f32).collect();
pq_scratch.set(&query).unwrap();
Expand Down
37 changes: 29 additions & 8 deletions diskann-disk/src/search/provider/disk_provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -924,21 +924,42 @@ where
let mut accessor = strategy
.search_accessor(provider, &DefaultContext)
.into_ann_result()?;
let computer = accessor.build_query_computer(query).into_ann_result()?;

// Derive the batch size from the scratch data structure. Providing too many vectors
// will panic.
let batch_size = accessor.scratch.pq_scratch.max_vectors();

// This condition should never be true, since `graph_degree` comes from
// `diskann::graph::Config` and is forced to be non-zero. The check is purely
// defensive against misconfiguration.
if batch_size == 0 {
return Err(ANNError::message(
diskann::ANNErrorKind::IndexError,
"pq scratch must support at least one vector",
));
}

let mut id_buffer = Vec::with_capacity(batch_size);

let mut best = NeighborPriorityQueue::new(neighbors_before_reranking);
let mut cmps = 0u32;

let num_points = provider.num_points as u32;
for id in 0..num_points {
if vector_filter(&id) {
let element = accessor.get_element(id).await.into_ann_result()?;
let dist = computer.evaluate_similarity(element);
best.insert(Neighbor::new(id, dist));
cmps += 1;
let mut iter = (0..provider.num_points as u32).filter(vector_filter);
loop {
id_buffer.clear();
id_buffer.extend(iter.by_ref().take(batch_size));

if id_buffer.is_empty() {
break;
}
Comment thread
hildebrandmw marked this conversation as resolved.

accessor.pq_distances(&id_buffer, |dist, id| best.insert(Neighbor::new(id, dist)))?;
Comment thread
hildebrandmw marked this conversation as resolved.
cmps += id_buffer.len() as u32;
}

// FIXME: This is a temporary bridge. We don't really need the query computer, but
// we do need to satisfy the trait definition until PR 1067 lands.
let computer = accessor.build_query_computer(query).into_ann_result()?;
let result_count = strategy
.default_post_processor()
.post_process(&mut accessor, query, &computer, best.iter(), output)
Expand Down
Loading