Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,13 @@ set(CMAKE_EXPORT_COMPILE_COMMANDS ON)

SET(CMAKE_CXX_FLAGS "-Wall -Ofast -Wextra -lrt -march=native -fpic -fopenmp -ftree-vectorize -fexceptions")

add_subdirectory(sample)
add_subdirectory(sample/cpp)

option(RABITQ_BUILD_PYTHON_BINDINGS "Build Python bindings" OFF)

if(RABITQ_BUILD_PYTHON_BINDINGS)
add_subdirectory(python_bindings)
endif()

add_library(rabitq_headers INTERFACE)

Expand Down
20 changes: 20 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
[build-system]
# numpy is needed at build time because the CMake script queries numpy.get_include().
requires = ["scikit-build-core>=0.10.0", "pybind11>=2.12", "numpy"]
build-backend = "scikit_build_core.build"

[project]
name = "rabitqlib"
version = "0.1.0"
description = "RaBitQ Python bindings for HNSW, IVF, and SymQG"
readme = "README.md"
requires-python = ">=3.9"
# The compiled extension accepts and returns NumPy arrays, so NumPy must also
# be importable at runtime, not only while the wheel is being built.
dependencies = ["numpy"]

[tool.scikit-build]
# Disable directory auto-discovery so we can map it via CMake's install target
wheel.packages = []

# Pass the toggle to the root CMake file automatically on pip install
cmake.args = [
"-DRABITQ_BUILD_PYTHON_BINDINGS=ON",
"-DCMAKE_BUILD_TYPE=Release"
]
39 changes: 39 additions & 0 deletions python_bindings/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
cmake_minimum_required(VERSION 3.15)

# 1. Locate pybind11 matching the pip environment
find_package(pybind11 CONFIG REQUIRED)

# 2. Grab the NumPy include path from the active Python environment.
#    Depending on how Python was discovered, the interpreter path may be
#    exported as PYTHON_EXECUTABLE (pybind11's classic mode) instead of
#    Python_EXECUTABLE, so fall back to it before giving up.
if(NOT Python_EXECUTABLE AND PYTHON_EXECUTABLE)
  set(Python_EXECUTABLE "${PYTHON_EXECUTABLE}")
endif()
if(NOT Python_EXECUTABLE)
  message(FATAL_ERROR
    "No Python interpreter is known to CMake; pass -DPython_EXECUTABLE=<path> "
    "so the NumPy include directory can be located.")
endif()

execute_process(
  COMMAND "${Python_EXECUTABLE}" "-c" "import numpy; print(numpy.get_include())"
  RESULT_VARIABLE NUMPY_QUERY_RESULT
  OUTPUT_VARIABLE NUMPY_INCLUDE_DIR
  OUTPUT_STRIP_TRAILING_WHITESPACE
)
# Fail loudly instead of silently continuing with an empty include path.
if(NOT NUMPY_QUERY_RESULT EQUAL 0 OR NOT NUMPY_INCLUDE_DIR)
  message(FATAL_ERROR
    "Could not query the NumPy include directory using ${Python_EXECUTABLE}; "
    "is NumPy installed in that environment?")
endif()

# 3. Find OpenMP to ensure parallel batched searches compile correctly
find_package(OpenMP REQUIRED)

# 4. Define the shared pybind11 module
pybind11_add_module(_rabitqlib
  rabitq_bindings.cpp
  hnsw_bindings.cpp
  ivf_bindings.cpp
  symqg_bindings.cpp
)

# 5. Mirror the exact include paths from the old setup.py
target_include_directories(_rabitqlib PRIVATE
  ${NUMPY_INCLUDE_DIR}
  ${CMAKE_CURRENT_SOURCE_DIR} # For bindings_common.hpp
  ${PROJECT_SOURCE_DIR}/include # For rabitqlib core headers
)

# 6. Link OpenMP flags natively
target_link_libraries(_rabitqlib PRIVATE
  OpenMP::OpenMP_CXX
)

# 7. Map files to a 'rabitqlib' directory
#    inside the target wheel, completely bypassing the local layout.
install(TARGETS _rabitqlib DESTINATION rabitqlib)
install(FILES __init__.py DESTINATION rabitqlib)
9 changes: 9 additions & 0 deletions python_bindings/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
"""Python bindings for RaBitQLib's three index types.

This module re-exports the index classes from the compiled `_rabitqlib`
extension so code can import them from the installed package, e.g.
`from rabitqlib import HnswIndex`.
"""

from ._rabitqlib import HnswIndex, IvfIndex, SymqgIndex

__all__ = ["HnswIndex", "IvfIndex", "SymqgIndex"]
73 changes: 73 additions & 0 deletions python_bindings/bindings_common.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
#pragma once

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <string>

#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>

#include "rabitqlib/defines.hpp"
#include "rabitqlib/utils/rotator.hpp"

namespace py = pybind11;

namespace rabitqlib::python_bindings {

/// Translates a user-facing metric name into the library's MetricType.
///
/// "l2" selects METRIC_L2; "ip" and "innerproduct" both select METRIC_IP.
/// @throws std::invalid_argument for any other spelling.
inline rabitqlib::MetricType metric_from_string(const std::string& metric) {
    const bool wants_inner_product = (metric == "ip") || (metric == "innerproduct");
    if (!wants_inner_product && metric != "l2") {
        throw std::invalid_argument("Unsupported metric. Use 'l2' or 'ip'.");
    }
    return wants_inner_product ? rabitqlib::METRIC_IP : rabitqlib::METRIC_L2;
}

/// Returns the short name of a metric: "ip" for METRIC_IP, and "l2" for
/// every other value.
inline std::string metric_to_string(rabitqlib::MetricType metric) {
    if (metric == rabitqlib::METRIC_IP) {
        return "ip";
    }
    return "l2";
}

/// Translates a user-facing rotator name into the library's RotatorType.
///
/// "matrix" selects MatrixRotator; "fht_kac" and "fht" both select
/// FhtKacRotator.
/// @throws std::invalid_argument for any other spelling.
inline rabitqlib::RotatorType rotator_from_string(const std::string& method) {
    const bool wants_fht = (method == "fht_kac") || (method == "fht");
    if (wants_fht) {
        return rabitqlib::RotatorType::FhtKacRotator;
    }
    if (method != "matrix") {
        throw std::invalid_argument("Unsupported rotator method. Use 'fht_kac' or 'matrix'.");
    }
    return rabitqlib::RotatorType::MatrixRotator;
}

/// Coerces a Python object into a C-contiguous array of element type T and
/// verifies that it has exactly two dimensions.
///
/// @param value  Python object to coerce via array_t::ensure.
/// @param name   Argument name used as the prefix of the error message.
/// @return The coerced array.
/// @throws std::invalid_argument if the coercion fails or the rank is not 2.
template <typename T>
inline py::array_t<T, py::array::c_style | py::array::forcecast> ensure_2d_array(
    py::handle value,
    const char* name
) {
    using Array = py::array_t<T, py::array::c_style | py::array::forcecast>;

    // The message is assembled only on the failure path.
    const auto reject = [name](const char* reason) {
        throw std::invalid_argument(std::string(name) + reason);
    };

    Array converted = Array::ensure(value);
    if (!converted) {
        reject(" must be a NumPy array");
    }
    if (converted.ndim() != 2) {
        reject(" must be a 2D NumPy array");
    }
    return converted;
}

/// Coerces a Python object into a C-contiguous array of element type T and
/// verifies that it has exactly one dimension.
///
/// @param value  Python object to coerce via array_t::ensure.
/// @param name   Argument name used as the prefix of the error message.
/// @return The coerced array.
/// @throws std::invalid_argument if the coercion fails or the rank is not 1.
template <typename T>
inline py::array_t<T, py::array::c_style | py::array::forcecast> ensure_1d_array(
    py::handle value,
    const char* name
) {
    using Array = py::array_t<T, py::array::c_style | py::array::forcecast>;

    // The message is assembled only on the failure path.
    const auto reject = [name](const char* reason) {
        throw std::invalid_argument(std::string(name) + reason);
    };

    Array converted = Array::ensure(value);
    if (!converted) {
        reject(" must be a NumPy array");
    }
    if (converted.ndim() != 1) {
        reject(" must be a 1D NumPy array");
    }
    return converted;
}

} // namespace rabitqlib::python_bindings
201 changes: 201 additions & 0 deletions python_bindings/hnsw_bindings.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
#include <algorithm>
#include <cstddef>
#include <cstring>
#include <limits>
#include <memory>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

#include <pybind11/stl.h>

#include "bindings_common.hpp"
#include "rabitqlib/index/hnsw/hnsw.hpp"

namespace py = pybind11;

namespace rabitqlib::python_bindings {

class HnswIndex {
public:
HnswIndex(
size_t dim,
size_t max_elements,
size_t M,
size_t ef_construction,
size_t nbits,
const std::string& metric = "l2",
size_t random_seed = 100
)
: dim_(dim)
, max_elements_(max_elements)
, M_(M)
, ef_construction_(ef_construction)
, nbits_(nbits)
, metric_(metric_from_string(metric))
, random_seed_(random_seed)
, index_(std::make_unique<rabitqlib::hnsw::HierarchicalNSW>(
max_elements,
dim,
nbits,
M,
ef_construction,
random_seed,
metric_
)) {}

void build(
py::handle data,
py::handle centroids,
py::handle cluster_ids,
size_t num_threads = 1,
bool fast_quantization = false
) {
auto data_array = ensure_2d_array<float>(data, "data");
auto centroids_array = ensure_2d_array<float>(centroids, "centroids");
auto cluster_ids_array = ensure_1d_array<rabitqlib::PID>(cluster_ids, "cluster_ids");

if (static_cast<size_t>(data_array.shape(1)) != dim_) {
throw std::invalid_argument("data dimension does not match index dim");
}
if (static_cast<size_t>(centroids_array.shape(1)) != dim_) {
throw std::invalid_argument("centroid dimension does not match index dim");
}
if (static_cast<size_t>(cluster_ids_array.shape(0)) != static_cast<size_t>(data_array.shape(0))) {
throw std::invalid_argument("cluster_ids length must match number of rows in data");
}

const size_t num_clusters = static_cast<size_t>(centroids_array.shape(0));
num_clusters_ = num_clusters;

// Ensure cluster_ids are writable for the C++ API by making a copy
std::vector<rabitqlib::PID> cluster_ids_vec(static_cast<size_t>(cluster_ids_array.shape(0)));
std::memcpy(cluster_ids_vec.data(), cluster_ids_array.data(), cluster_ids_vec.size() * sizeof(rabitqlib::PID));

py::gil_scoped_release release;
index_->construct(
num_clusters,
centroids_array.data(),
static_cast<size_t>(data_array.shape(0)),
data_array.data(),
cluster_ids_vec.data(),
num_threads,
fast_quantization
);
built_ = true;
}

py::tuple search(py::handle queries, size_t k, size_t ef = 0, size_t num_threads = 1) {
auto query_array = ensure_2d_array<float>(queries, "queries");
if (dim_ != 0 && static_cast<size_t>(query_array.shape(1)) != dim_) {
throw std::invalid_argument("query dimension does not match index dim");
}
if (ef == 0) {
ef = std::max<size_t>(k, 10);
}

const auto shape = std::vector<ssize_t>{
static_cast<ssize_t>(query_array.shape(0)), static_cast<ssize_t>(k)};
auto ids = py::array_t<rabitqlib::PID>(shape);
auto dists = py::array_t<float>(shape);
auto ids_buf = ids.mutable_unchecked<2>();
auto dists_buf = dists.mutable_unchecked<2>();
std::vector<std::vector<std::pair<float, rabitqlib::PID>>> results;
{
py::gil_scoped_release release;
results = index_->search(
query_array.data(),
static_cast<size_t>(query_array.shape(0)),
k,
ef,
num_threads
);
}


for (ssize_t i = 0; i < static_cast<ssize_t>(results.size()); ++i) {
for (
ssize_t j = 0;
j < static_cast<ssize_t>(std::min<size_t>(k, results[static_cast<size_t>(i)].size()));
++j
) {
ids_buf(i, j) = results[static_cast<size_t>(i)][static_cast<size_t>(j)].second;
dists_buf(i, j) = results[static_cast<size_t>(i)][static_cast<size_t>(j)].first;
}
}
return py::make_tuple(ids, dists);
}

void save(const std::string& path) const {
py::gil_scoped_release release;
index_->save(path.c_str());
}

static HnswIndex load(const std::string& path) {
HnswIndex wrapper;
wrapper.index_ = std::make_unique<rabitqlib::hnsw::HierarchicalNSW>();
py::gil_scoped_release release;
wrapper.index_->load(path.c_str());
wrapper.dim_ = wrapper.index_->dimension();
wrapper.max_elements_ = wrapper.index_->max_elements();
wrapper.M_ = wrapper.index_->M();
wrapper.ef_construction_ = wrapper.index_->ef_construction();
wrapper.nbits_ = wrapper.index_->nbits();
wrapper.num_clusters_ = wrapper.index_->num_clusters();
wrapper.metric_ = wrapper.index_->metric_type();
wrapper.built_ = true;
return wrapper;
}

[[nodiscard]] size_t dim() const { return dim_; }
[[nodiscard]] size_t max_elements() const { return max_elements_; }
[[nodiscard]] size_t nbits() const { return nbits_; }
[[nodiscard]] bool is_built() const { return built_; }
[[nodiscard]] size_t num_clusters() const { return num_clusters_; }

private:
HnswIndex() = default;

size_t dim_ = 0;
size_t max_elements_ = 0;
size_t M_ = 0;
size_t ef_construction_ = 0;
size_t nbits_ = 0;
rabitqlib::MetricType metric_ = rabitqlib::METRIC_L2;
size_t random_seed_ = 100;
size_t num_clusters_ = 0;
bool built_ = false;
std::unique_ptr<rabitqlib::hnsw::HierarchicalNSW> index_;
};

} // namespace rabitqlib::python_bindings

// Register into combined module
/// Registers the HnswIndex wrapper class on the given pybind11 module.
///
/// Exposes the constructor, build/search/save methods, the static load
/// factory and the read-only properties under the Python name "HnswIndex".
void register_hnsw(py::module_ &m) {
    using Index = rabitqlib::python_bindings::HnswIndex;

    py::class_<Index> cls(m, "HnswIndex");

    // Constructor: dim and max_elements are required, the rest have defaults.
    cls.def(
        py::init<size_t, size_t, size_t, size_t, size_t, const std::string&, size_t>(),
        py::arg("dim"),
        py::arg("max_elements"),
        py::arg("M") = 16,
        py::arg("ef_construction") = 200,
        py::arg("nbits") = 8,
        py::arg("metric") = "l2",
        py::arg("random_seed") = 100
    );

    // Index lifecycle and querying.
    cls.def(
        "build",
        &Index::build,
        py::arg("data"),
        py::arg("centroids"),
        py::arg("cluster_ids"),
        py::arg("num_threads") = 1,
        py::arg("fast_quantization") = false
    );
    cls.def(
        "search",
        &Index::search,
        py::arg("queries"),
        py::arg("k"),
        py::arg("ef") = 0,
        py::arg("num_threads") = 1
    );
    cls.def("save", &Index::save, py::arg("path"));
    cls.def_static("load", &Index::load, py::arg("path"));

    // Read-only introspection properties.
    cls.def_property_readonly("dim", &Index::dim);
    cls.def_property_readonly("max_elements", &Index::max_elements);
    cls.def_property_readonly("nbits", &Index::nbits);
    cls.def_property_readonly("num_clusters", &Index::num_clusters);
    cls.def_property_readonly("is_built", &Index::is_built);
}
Loading