Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 27 additions & 13 deletions cpp/src/neighbors/ivf_pq/ivf_pq_search.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -781,21 +781,35 @@ inline auto get_max_coarse_batch_size(raft::resources const& res,
const search_params& params,
uint32_t n_probes,
uint32_t n_lists,
uint32_t n_queries) -> uint32_t
uint32_t n_queries,
uint32_t dim_ext,
uint32_t rot_dim) -> uint32_t
{
size_t data_size = 4;
size_t gemm_elem_size;
size_t qc_elem_size;
switch (params.coarse_search_dtype) {
case CUDA_R_32F: data_size = 4; break;
case CUDA_R_16F: data_size = 2; break;
case CUDA_R_8I: data_size = 1; break;
case CUDA_R_32F:
gemm_elem_size = 4;
qc_elem_size = 4;
break;
case CUDA_R_16F:
gemm_elem_size = 2;
qc_elem_size = 2;
break;
case CUDA_R_8I:
gemm_elem_size = 1;
qc_elem_size = 4;
break;
default: RAFT_FAIL("Unexpected coarse_search_dtype (%d)", int(params.coarse_search_dtype));
}
// How much data we allocate for coarse GEMM.
// This is NOT all the memory we need; as a rule of thumb, cap it at half of the workspace.
// We don't reach this limit by default, only when max_internal_batch_size is increased by
// a lot.
auto bytes_per_query = static_cast<size_t>(n_probes + n_lists) * data_size;
auto max_per_ws = raft::resource::get_workspace_free_bytes(res) / bytes_per_query;
// Persistent allocations that live for the entire search call.
auto persistent_per_query = static_cast<size_t>(dim_ext) * gemm_elem_size +
static_cast<size_t>(rot_dim) * sizeof(float) +
static_cast<size_t>(n_probes) * sizeof(uint32_t);
// Transient allocations during coarse search (select_clusters): qc_distances + cluster_dists.
auto transient_per_query = static_cast<size_t>(n_lists + n_probes) * qc_elem_size;
auto total_per_query = persistent_per_query + transient_per_query;
auto max_per_ws = raft::resource::get_workspace_free_bytes(res) / total_per_query;
return std::max<uint32_t>(
1,
std::min<uint32_t>(max_per_ws / 2,
Expand Down Expand Up @@ -889,8 +903,8 @@ inline void search(raft::resources const& handle,

// Maximum number of query vectors to search at the same time.
// Number of queries in the outer loop, which includes query transform and coarse search.
const auto max_bs_outer =
get_max_coarse_batch_size(handle, params, n_probes, index.n_lists(), n_queries);
const auto max_bs_outer = get_max_coarse_batch_size(
handle, params, n_probes, index.n_lists(), n_queries, dim_ext, index.rot_dim());
// Number of queries in the inner loop, which includes the fine search;
// This is usually smaller than the outer loop when the non-fused kernel has to keep intermediate
// results in the device memory.
Expand Down
Loading