Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions c_api/IndexIVF_c_ex.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "macros_impl.h"

using faiss::IndexIVF;
using faiss::IndexIVFRaBitQ;
using faiss::SearchParameters;
using faiss::SearchParametersIVF;

Expand Down Expand Up @@ -206,3 +207,41 @@ int faiss_SearchParametersRaBitQ_new_with(
}
CATCH_AND_HANDLE
}

// C entry point for IndexIVFRaBitQ::compute_distance_to_codes_with_precomputed.
//
// The opaque handle is viewed as an IndexIVF and then checked at runtime to
// be an IndexIVFRaBitQ; any other index type (or a null handle, for which
// dynamic_cast yields null) is rejected with an error. All remaining
// arguments are forwarded unchanged. Returns 0 on success; exceptions are
// translated by CATCH_AND_HANDLE.
int faiss_IndexIVFRaBitQ_compute_distance_with_precomputed(
        FaissIndexIVF* index,
        idx_t list_no,
        const float* x,
        idx_t n,
        const uint8_t* codes,
        float* dists,
        uint8_t* query_bp,
        size_t* query_bp_size) {
    try {
        IndexIVF* ivf_view = reinterpret_cast<IndexIVF*>(index);
        // The variable name is part of the assertion text below, so it is
        // kept stable.
        IndexIVFRaBitQ* rabitq_index = dynamic_cast<IndexIVFRaBitQ*>(ivf_view);
        FAISS_THROW_IF_NOT_MSG(
                rabitq_index, "index is not an IndexIVFRaBitQ instance");

        rabitq_index->compute_distance_to_codes_with_precomputed(
                list_no, x, n, codes, dists, query_bp, query_bp_size);
        return 0;
    }
    CATCH_AND_HANDLE
}

// C entry point for IndexIVFRaBitQ::query_bitplanes_size.
//
// Writes the number of bytes a caller must allocate for the precomputed
// query buffer into *size. The handle must refer to an IndexIVFRaBitQ and
// `size` must be a valid pointer; both conditions are reported as errors
// through CATCH_AND_HANDLE instead of crashing the process. Returns 0 on
// success.
int faiss_IndexIVFRaBitQ_query_bitplanes_size(
        FaissIndexIVF* index,
        size_t* size) {
    try {
        // Validate the out-pointer before doing any work so a NULL from the
        // C side becomes an error code rather than a null dereference.
        FAISS_THROW_IF_NOT_MSG(size, "size output pointer must not be null");
        auto* rabitq_index = dynamic_cast<IndexIVFRaBitQ*>(
                reinterpret_cast<IndexIVF*>(index));
        FAISS_THROW_IF_NOT_MSG(
                rabitq_index, "index is not an IndexIVFRaBitQ instance");
        *size = rabitq_index->query_bitplanes_size();
        return 0;
    }
    CATCH_AND_HANDLE
}
40 changes: 40 additions & 0 deletions c_api/IndexIVF_c_ex.h
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,46 @@ int faiss_IndexIVF_compute_distance_to_codes_for_list(
float* dists,
float* dist_table);

/*
Compute distances from one query vector to a batch of codes belonging to a
single inverted list, optionally reusing precomputed query state that lives
in a caller-owned buffer.

query_bp: caller-allocated buffer of
faiss_IndexIVFRaBitQ_query_bitplanes_size() bytes. The buffer starts with a
small header that records the list number and quantizer configuration it was
built for.
query_bp_size: in/out. On entry, the buffer is reused only when the value
equals faiss_IndexIVFRaBitQ_query_bitplanes_size() AND the in-buffer header
matches this call (list_no and index configuration). Any other value
(including 0) or a header mismatch makes the call (re)compute and fill
query_bp. When a (re)compute happens, the value is set on return to the
number of bytes written.

The header does not record the query vector itself: pass 0 in query_bp_size
whenever x changes.

If n is 0 the call validates its arguments and returns without touching
dists or the buffer.

 @param index - the IVF index (must be IndexIVFRaBitQ)
 @param list_no - list number for inverted list
 @param x - input query vector
 @param n - number of codes
 @param codes - input codes
 @param dists - output computed distances (n values)
 @param query_bp - precomputed query bitplanes buffer (must not be NULL)
 @param query_bp_size - in/out (must not be NULL): 0 forces a compute; the
 required size plus a matching header reuses the buffer
 @return 0 on success; failures are reported through the C API error
 handling (CATCH_AND_HANDLE)
*/
int faiss_IndexIVFRaBitQ_compute_distance_with_precomputed(
FaissIndexIVF* index,
idx_t list_no,
const float* x,
idx_t n,
const uint8_t* codes,
float* dists,
uint8_t* query_bp,
size_t* query_bp_size);

/*
Get the byte size needed for the precomputed query bitplanes buffer passed
to faiss_IndexIVFRaBitQ_compute_distance_with_precomputed. The size covers
the self-describing header plus the precomputed query payload.

 @param index - the IVF index (must be IndexIVFRaBitQ)
 @param size - output: required buffer size in bytes
 @return 0 on success; failures (e.g. index of the wrong type) are reported
 through the C API error handling (CATCH_AND_HANDLE)
*/
int faiss_IndexIVFRaBitQ_query_bitplanes_size(
FaissIndexIVF* index,
size_t* size);

/*
Get centroid information and cardinality for all centroids in an IVF index.

Expand Down
171 changes: 166 additions & 5 deletions faiss/IndexIVFRaBitQ.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,15 @@

#include <cstddef>
#include <cstdint>
#include <cstring>
#include <limits>
#include <memory>
#include <vector>

#include <faiss/impl/FaissAssert.h>
#include <faiss/impl/RaBitQUtils.h>
#include <faiss/impl/RaBitQuantizer.h>
#include <faiss/utils/rabitq_simd.h>

namespace faiss {

Expand Down Expand Up @@ -364,12 +367,12 @@ void IndexIVFRaBitQ::compute_distance_to_codes_for_list(
std::vector<float> centroid(d);
quantizer->reconstruct(list_no, centroid.data());

// Note: "centered" query quantization is a per-search parameter in Faiss.
// compute_distance_to_codes_for_list does not take IVFSearchParameters, so
// we use centered=false here (consistent with get_distance_computer()). In
// future, we can look into setting centered and qb per call if needed.
// NOTE: centered=false is hardcoded here and in
// compute_distance_to_codes_with_precomputed. Do not flip without also
// updating the inline distance loop in the precomputed path.
std::unique_ptr<FlatCodesDistanceComputer> dc(
rabitq.get_distance_computer(qb, centroid.data(), /*centered=*/false));
rabitq.get_distance_computer(
qb, centroid.data(), /*centered=*/false));
dc->set_query(x);

const uint8_t* code = codes;
Expand All @@ -378,6 +381,164 @@ void IndexIVFRaBitQ::compute_distance_to_codes_for_list(
}
}

// Number of bytes a caller must provide for the precomputed query buffer:
// the fixed header followed by the quantizer's precomputed payload for the
// index's current qb. Throws if qb is outside (0, 8].
size_t IndexIVFRaBitQ::query_bitplanes_size() const {
    FAISS_THROW_IF_NOT(qb > 0 && qb <= 8);
    const size_t payload_bytes = rabitq.precomputed_query_size(qb);
    const size_t header_bytes = sizeof(PrecomputedQueryHeader);
    return header_bytes + payload_bytes;
}

// Computes distances from query `x` to `n` codes of inverted list `list_no`,
// caching the per-(query, list) precomputed state in the caller's buffer.
//
// Buffer layout: [PrecomputedQueryHeader][PrecomputedQueryScalars]
//                [d floats: rotated query][rearranged query bitplanes].
//
// The buffer is reused only when *query_bp_size equals
// query_bitplanes_size() and the stored header matches
// (magic, version, nb_bits, qb, d, list_no). Otherwise the payload is
// recomputed, the header is stamped, and *query_bp_size is set to the number
// of bytes written. The header does not record `x`, so a caller switching to
// a different query must force a recompute (e.g. by passing 0).
//
// All pointer arguments must be non-null; violations and out-of-range
// list_no / n / qb are reported via FAISS_THROW_IF_NOT. With n == 0 the
// function returns after validation without touching `dists` or the buffer.
void IndexIVFRaBitQ::compute_distance_to_codes_with_precomputed(
        idx_t list_no,
        const float* x,
        idx_t n,
        const uint8_t* codes,
        float* dists,
        uint8_t* query_bp,
        size_t* query_bp_size) const {
    FAISS_THROW_IF_NOT(qb > 0 && qb <= 8);
    FAISS_THROW_IF_NOT(n >= 0);
    FAISS_THROW_IF_NOT(list_no >= 0 && (size_t)list_no < nlist);
    FAISS_THROW_IF_NOT(x != nullptr);
    FAISS_THROW_IF_NOT(codes != nullptr);
    FAISS_THROW_IF_NOT(dists != nullptr);
    FAISS_THROW_IF_NOT(code_size > 0);
    FAISS_THROW_IF_NOT(query_bp != nullptr);
    FAISS_THROW_IF_NOT(query_bp_size != nullptr);
    if (n == 0) {
        return;
    }
    // Guard the pointer walk over `codes` against size_t overflow.
    FAISS_THROW_IF_NOT(
            (size_t)n <= (std::numeric_limits<size_t>::max)() / code_size);

    const size_t header_size = sizeof(PrecomputedQueryHeader);
    const size_t byte_dim = (d + 7) / 8;
    const size_t required_bp_size = query_bitplanes_size();

    // A buffer is reusable only if:
    //   (1) it has the size we expect, AND
    //   (2) its header identifies the same (list_no, qb, nb_bits, d, magic,
    //       version) we're being called with.
    // Either mismatch triggers a fresh compute. The query vector and the
    // identity of the index are NOT part of the header, so those remain the
    // caller's responsibility.
    bool use_precomputed = false;
    if (*query_bp_size == required_bp_size) {
        // Copy the header out instead of casting the byte buffer: the
        // caller's allocation is not guaranteed to satisfy the header's
        // alignment.
        PrecomputedQueryHeader hdr;
        std::memcpy(&hdr, query_bp, header_size);
        use_precomputed = hdr.magic == PrecomputedQueryHeader::kMagic &&
                hdr.version == PrecomputedQueryHeader::kVersion &&
                hdr.nb_bits == rabitq.nb_bits && hdr.qb == qb &&
                hdr.d == (uint32_t)d && hdr.list_no == list_no;
    }

    if (!use_precomputed) {
        std::vector<float> centroid(d);
        quantizer->reconstruct(list_no, centroid.data());

        // Invalidate whatever header is currently in the buffer before the
        // payload is overwritten. If anything below throws, the buffer must
        // not look reusable: a valid header over a stale or half-written
        // payload would silently yield wrong distances on the next call.
        std::memset(query_bp, 0, header_size);

        size_t written = 0;
        // NOTE: centered=false is hardcoded here and in
        // compute_distance_to_codes_for_list. Do not flip without also
        // updating the inline distance loop below (which only implements the
        // !centered branch).
        rabitq.compute_query_precomputed(
                x,
                qb,
                centroid.data(),
                /*centered=*/false,
                query_bp + header_size,
                &written);

        FAISS_THROW_IF_NOT(written == rabitq.precomputed_query_size(qb));

        // Payload is complete: only now stamp the header that marks the
        // buffer as valid for (list_no, qb, nb_bits, d).
        PrecomputedQueryHeader hdr{}; // value-init also zeroes pad[]
        hdr.magic = PrecomputedQueryHeader::kMagic;
        hdr.version = PrecomputedQueryHeader::kVersion;
        hdr.nb_bits = (uint16_t)rabitq.nb_bits;
        hdr.qb = qb;
        hdr.d = (uint32_t)d;
        hdr.list_no = list_no;
        std::memcpy(query_bp, &hdr, header_size);

        *query_bp_size = header_size + written;
    }
    FAISS_THROW_IF_NOT(*query_bp_size == required_bp_size);

    // Zero-copy: read the payload directly from the buffer past the header.
    // NOTE(review): these in-place views assume query_bp is suitably aligned
    // for float access; confirm against callers that allocate the buffer.
    size_t offset = header_size;

    const auto* query_fac =
            reinterpret_cast<const PrecomputedQueryScalars*>(query_bp + offset);
    offset += sizeof(PrecomputedQueryScalars);

    const float* rotated_q = reinterpret_cast<const float*>(query_bp + offset);
    offset += d * sizeof(float);

    const uint8_t* rearranged_qq = query_bp + offset;

    // Compute distances inline rather than via get_distance_computer +
    // set_query + distance_to_code: get_distance_computer alone costs
    // ~1.5 us per call, which dominates 1-bit branch at the small batch
    // sizes typical in certain workloads.
    //
    // - Multi-bit branch delegates to
    //   rabitq_utils::compute_full_multibit_distance (same utility used
    //   by RaBitQDistanceComputerQ::distance_to_code_full), so changes
    //   there propagate automatically.
    // - 1-bit branch reproduces the math from
    //   RaBitQDistanceComputerQ::distance_to_code_1bit. Keep in sync.
    const size_t ex_bits = rabitq.nb_bits - 1;
    const bool is_one_bit = (ex_bits == 0);
    // Per-code offsets are identical for every code, so compute them once.
    const size_t ex_code_offset =
            byte_dim + sizeof(rabitq_utils::SignBitFactorsWithError);
    const size_t ex_code_bytes = (d * ex_bits + 7) / 8;

    const uint8_t* code = codes;
    for (idx_t i = 0; i < n; i++, code += code_size) {
        const uint8_t* binary_data = code;
        if (is_one_bit) {
            // 1-bit path (same as distance_to_code_1bit): sign bits are
            // followed directly by the per-vector factors.
            const auto* base_fac =
                    reinterpret_cast<const rabitq_utils::SignBitFactors*>(
                            code + byte_dim);

            // SIMD dot products on rearranged bitplanes
            auto dot_qo = rabitq::bitwise_and_dot_product(
                    rearranged_qq, binary_data, byte_dim, qb);
            auto sum_q = rabitq::popcount(binary_data, byte_dim);

            // Match exact operation sequence from
            // RaBitQDistanceComputerQ::distance_to_code_1bit
            float final_dot = 0;
            final_dot += query_fac->c1 * dot_qo;
            final_dot += query_fac->c2 * sum_q;
            final_dot -= query_fac->c34;

            float pre_dist = base_fac->or_minus_c_l2sqr +
                    query_fac->qr_to_c_L2sqr -
                    2 * base_fac->dp_multiplier * final_dot;

            if (metric_type == MetricType::METRIC_L2) {
                dists[i] = pre_dist;
            } else {
                dists[i] = -0.5f * (pre_dist - query_fac->qr_norm_L2sqr);
            }
        } else {
            // Multi-bit path (same as distance_to_code_full): sign bits,
            // sign-bit factors, extra-bit codes, then extra-bit factors.
            const uint8_t* ex_code = code + ex_code_offset;
            const auto* ex_fac =
                    reinterpret_cast<const rabitq_utils::ExtraBitsFactors*>(
                            ex_code + ex_code_bytes);

            dists[i] = rabitq_utils::compute_full_multibit_distance(
                    binary_data,
                    ex_code,
                    *ex_fac,
                    rotated_q,
                    query_fac->qr_to_c_L2sqr,
                    query_fac->qr_norm_L2sqr,
                    d,
                    ex_bits,
                    metric_type);
        }
    }
}

struct IVFRaBitDistanceComputer : DistanceComputer {
const float* q = nullptr;
const IndexIVFRaBitQ* parent = nullptr;
Expand Down
52 changes: 52 additions & 0 deletions faiss/IndexIVFRaBitQ.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,58 @@ struct IndexIVFRaBitQ : IndexIVF {
float* dists,
float* dist_table) const override;

/// Self-describing header stamped at the start of the precomputed query
/// buffer. It is compared against the current call on every reuse attempt,
/// so a buffer built for a different list / qb / nb_bits / d is detected
/// and recomputed instead of silently producing wrong distances. The query
/// vector itself is not recorded here. POD, fixed 24 bytes.
struct PrecomputedQueryHeader {
static constexpr uint32_t kMagic = 0x52424351; // ASCII bytes 'R','B','C','Q'
static constexpr uint16_t kVersion = 1; // bump on layout change

// Must equal kMagic for the buffer to be considered initialized.
uint32_t magic;
// Must equal kVersion; guards against stale buffer layouts.
uint16_t version;
// Quantizer bit count (rabitq.nb_bits) the payload was built with.
uint16_t nb_bits;
// Query quantization bits (qb) the payload was built with.
uint8_t qb;
// Explicit padding, written as zero; not inspected when validating.
uint8_t pad[3];
// Vector dimensionality the payload was built with.
uint32_t d;
// Inverted list whose centroid the payload was computed against.
int64_t list_no;
};
// Locks the layout: field sizes sum to 24 bytes, so any implicit padding or
// accidental field change breaks the build.
static_assert(
sizeof(PrecomputedQueryHeader) == 24,
"PrecomputedQueryHeader must be exactly 24 bytes");

/// Compute distances to codes with optional precomputed query state.
///
/// The buffer carries a self-describing header so the function can
/// distinguish a fresh reusable buffer from a stale / wrong-list /
/// wrong-config one without trusting the caller's bookkeeping. The
/// caller MUST allocate at least query_bitplanes_size() bytes.
///
/// The header records list_no, qb, nb_bits and d only. It does not record
/// the query vector, so when x changes the caller must force a recompute
/// by passing 0 in query_bp_size.
///
/// All pointer arguments must be non-null. If n is 0 the call returns
/// after argument validation without writing dists or the buffer.
///
/// @param list_no IVF list the codes belong to (0 <= list_no < nlist).
/// @param x query vector (d floats).
/// @param n number of codes (>= 0).
/// @param codes input codes, n * code_size bytes.
/// @param dists output distances, n floats.
/// @param query_bp caller-allocated buffer of query_bitplanes_size()
/// bytes.
/// @param query_bp_size in/out: 0 to (re)compute, otherwise must equal
/// query_bitplanes_size() AND the in-buffer header
/// must match (list_no, qb, nb_bits, d). Either
/// mismatch triggers a fresh compute. On return
/// after a (re)compute: set to bytes written.
void compute_distance_to_codes_with_precomputed(
idx_t list_no,
const float* x,
idx_t n,
const uint8_t* codes,
float* dists,
uint8_t* query_bp,
size_t* query_bp_size) const;

/// Returns the byte size needed for precomputed query bitplanes buffer.
/// Includes the self-describing header (sizeof(PrecomputedQueryHeader))
/// plus the quantizer's precomputed payload for the current qb.
/// Throws if qb is not in the range 1..8.
size_t query_bitplanes_size() const;

// unfortunately
DistanceComputer* get_distance_computer() const override;
};
Expand Down
Loading