85 commits
66d7fd3
combine impls
tarang-jain Apr 10, 2026
07707af
Multi-GPU Batched KMeans
viclafargue Apr 13, 2026
efc270f
Merge branch 'main' into mg-batched-kmeans
viclafargue Apr 13, 2026
0a09e6f
rm inertia_check
tarang-jain Apr 13, 2026
99a5730
change to warning
tarang-jain Apr 13, 2026
a077406
style
tarang-jain Apr 13, 2026
d659875
add init_size param
tarang-jain Apr 13, 2026
ec2e8b7
Merge branch 'main' into combine-batch
tarang-jain Apr 13, 2026
03a6473
docs
tarang-jain Apr 13, 2026
42a8d9d
Merge branch 'combine-batch' of https://github.com/tarang-jain/cuvs i…
tarang-jain Apr 13, 2026
86af2fa
rm direct cuda api calls
tarang-jain Apr 13, 2026
d4e4e2c
std::swap instead of raft::copy
tarang-jain Apr 14, 2026
0819af5
cache batch norms
tarang-jain Apr 14, 2026
e0f079c
centroid norms can also be cached per iteration
tarang-jain Apr 14, 2026
c2f7390
mg n_iter
tarang-jain Apr 14, 2026
b9c3102
pre-commit
tarang-jain Apr 14, 2026
e3956c1
do not break c abi
tarang-jain Apr 14, 2026
986d78a
Merge branch 'main' into combine-batch
tarang-jain Apr 14, 2026
7197b71
cluster_cost on device
viclafargue Apr 14, 2026
84ab315
Updated testing
viclafargue Apr 14, 2026
47d4b94
templating
viclafargue Apr 15, 2026
a8e1d26
Merge branch 'main' into combine-batch
tarang-jain Apr 16, 2026
384d054
fix checkWeight
tarang-jain Apr 21, 2026
455b286
merge upstream:
tarang-jain Apr 21, 2026
5462809
Merge branch 'combine-batch' of https://github.com/tarang-jain/cuvs i…
tarang-jain Apr 21, 2026
6ba759c
fix compilation
tarang-jain Apr 21, 2026
e76eaac
rel_tol
tarang-jain Apr 22, 2026
afbefdf
pass workspace
tarang-jain Apr 22, 2026
e62a63c
Merge branch 'combine-batch' of https://github.com/tarang-jain/cuvs i…
tarang-jain Apr 22, 2026
e4f08bf
style
tarang-jain Apr 22, 2026
6e4a8f0
Merge branch 'main' of https://github.com/rapidsai/cuvs into combine-…
tarang-jain Apr 22, 2026
4a8a85c
do not use batch scratch space; rm update_centroids
tarang-jain Apr 22, 2026
bbf2a9f
move the debug log
tarang-jain Apr 22, 2026
410092c
add new suffixed param struct
tarang-jain Apr 22, 2026
c515c1e
address pr reviews
tarang-jain Apr 22, 2026
e8e63ab
fix docstring
tarang-jain Apr 22, 2026
30c457c
fix wt_sum warning
tarang-jain Apr 22, 2026
ab96623
rm deprecationwarning and instead add FutureWarning:=
tarang-jain Apr 22, 2026
269f23c
unweighted to never materialize batch weights
tarang-jain Apr 22, 2026
80a22ca
add cpp tests
tarang-jain Apr 23, 2026
ac06b05
update cpp tests
tarang-jain Apr 23, 2026
855624a
Merge branch 'main' into mg-batched-kmeans
viclafargue Apr 23, 2026
0a6748d
refactor
viclafargue Apr 23, 2026
7055272
rename to mnmg_fit
viclafargue Apr 23, 2026
0569340
revert batch norms cache
tarang-jain Apr 23, 2026
8cac63a
increase zero cost threshold
tarang-jain Apr 24, 2026
f6df4ae
apply cuda event plus re-add h_norm_cache
tarang-jain Apr 24, 2026
9fc74b1
rm cosine expanded stuff
tarang-jain Apr 24, 2026
dec3dc4
resolve merge conflicts
tarang-jain Apr 28, 2026
0d030a2
change suffix of the params struct
tarang-jain Apr 28, 2026
b1c034e
replace 06 by 08, add todo and note
tarang-jain Apr 28, 2026
a482495
update to v2
tarang-jain Apr 28, 2026
8ecfdc1
avoid stream sync inside weight sum
tarang-jain Apr 29, 2026
1e1525e
Merge branch 'combine-batch' of https://github.com/tarang-jain/cuvs i…
tarang-jain Apr 29, 2026
ec22e07
empty
tarang-jain Apr 29, 2026
d2e410d
empty
tarang-jain Apr 29, 2026
b791c38
Merge branch 'main' into combine-batch
tarang-jain Apr 29, 2026
a05a006
new signatures with new struct
tarang-jain Apr 29, 2026
73293cf
Merge branch 'combine-batch' of https://github.com/tarang-jain/cuvs i…
tarang-jain Apr 29, 2026
880c7b9
Merge branch 'main' of https://github.com/rapidsai/cuvs into combine-…
tarang-jain Apr 30, 2026
e2035ec
revert change to calls in py and rust; add c tests
tarang-jain Apr 30, 2026
e28c200
Merge branch 'main' into combine-batch
tarang-jain May 1, 2026
55bbdad
use to_dlpack
tarang-jain May 5, 2026
9a9b8ee
cache device weights
tarang-jain May 5, 2026
a800b27
rm event
tarang-jain May 5, 2026
3db8582
update names
tarang-jain May 5, 2026
c048352
rename
tarang-jain May 5, 2026
2f968f8
rm docs
tarang-jain May 5, 2026
affe85a
empty
tarang-jain May 5, 2026
c6dea64
fix norm cache
tarang-jain May 5, 2026
7dfab3e
revert changes to minClusterDistanceCompute
tarang-jain May 6, 2026
7a383da
update tests to use mdspan instead of rmm
tarang-jain May 6, 2026
ce6c4b5
Merge branch 'main' into combine-batch
tarang-jain May 6, 2026
5a06a44
Merge branch 'main' into combine-batch
tarang-jain May 6, 2026
28cda6a
Merge branch 'combine-batch' into mg-batched-kmeans
viclafargue May 7, 2026
bfb5290
Addressing review
viclafargue May 7, 2026
add9db1
optimize convergence check
viclafargue May 7, 2026
af606bc
Adressing review
viclafargue May 8, 2026
41c66b8
Merge branch 'main' into mg-batched-kmeans
viclafargue May 8, 2026
f664c2c
results on all ranks for RAFT + small optimization
viclafargue May 8, 2026
6c2c03d
reviews
viclafargue May 11, 2026
7f6d664
Global sampling for init
viclafargue May 11, 2026
f8270e2
SNMG -> MNMG
viclafargue May 11, 2026
bbf0302
Merge branch 'main' into mg-batched-kmeans
viclafargue May 11, 2026
a14a6bc
adding asserts
viclafargue May 11, 2026
10 changes: 9 additions & 1 deletion cpp/include/cuvs/cluster/kmeans.hpp
@@ -179,6 +179,13 @@ enum class kmeans_type { KMeans = 0, KMeansBalanced = 1 };
* on the host. Data is processed in GPU-sized batches, streaming from host to device.
* The batch size is controlled by params.streaming_batch_size.
*
* Multi-GPU dispatch is selected automatically based on the handle state:
* - If `raft::resource::is_multi_gpu(handle)` (cuVS SNMG): the full dataset X
* is split across GPUs internally with an OpenMP parallel region and NCCL.
* - If `raft::resource::comms_initialized(handle)` (Dask/Ray/MPI): X is treated as
* this worker's partition, and RAFT communicators are used for collectives.
* - Otherwise: single-GPU batched k-means.
*
* @code{.cpp}
* #include <raft/core/resources.hpp>
* #include <cuvs/cluster/kmeans.hpp>
@@ -208,7 +215,8 @@ enum class kmeans_type { KMeans = 0, KMeansBalanced = 1 };
* raft::make_host_scalar_view(&n_iter));
* @endcode
*
* @param[in] handle The raft handle.
* @param[in] handle The raft handle. When a multi-GPU resource is
* attached, multi-GPU dispatch is used automatically.
* @param[in] params Parameters for KMeans model. Batch size is read from
* params.streaming_batch_size.
* @param[in] X Training instances on HOST memory. The data must
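The dispatch order documented above can be read off directly from the handle state. Below is a minimal illustrative sketch (the helper name `select_kmeans_dispatch` and the enum are assumptions introduced here, not part of this PR), using only the two resource queries named in the doc comment; the exact header providing `raft::resource::is_multi_gpu` depends on the RAFT version.

```cpp
#include <raft/core/resource/comms.hpp>
#include <raft/core/resources.hpp>
// Note: the header providing raft::resource::is_multi_gpu is version-dependent
// and is assumed to be available alongside the cuVS multi-GPU resources.

// Illustrative only: mirrors the dispatch precedence documented for the
// batched k-means fit overload above.
enum class kmeans_dispatch { multi_gpu_snmg, multi_gpu_comms, single_gpu };

inline kmeans_dispatch select_kmeans_dispatch(raft::resources const& handle)
{
  // cuVS SNMG: the full host dataset is split across local GPUs internally
  // (OpenMP parallel region + NCCL).
  if (raft::resource::is_multi_gpu(handle)) { return kmeans_dispatch::multi_gpu_snmg; }
  // Dask/Ray/MPI: X is this worker's partition; RAFT communicators
  // perform the collectives.
  if (raft::resource::comms_initialized(handle)) { return kmeans_dispatch::multi_gpu_comms; }
  // Otherwise: single-GPU batched k-means.
  return kmeans_dispatch::single_gpu;
}
```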
15 changes: 8 additions & 7 deletions cpp/src/cluster/detail/kmeans_common.cuh
@@ -596,26 +596,27 @@ void compute_centroid_shift(raft::resources const& handle,
* @brief Evaluate convergence criteria entirely on device.
*
* Checks the cost-ratio and centroid-shift stopping conditions and writes
* a boolean result (0 or 1) into @p done_flag. Also advances
* @p prior_clustering_cost to the current cost for the next iteration.
* 0 or 1 into @p done_flag, and advances @p prior_clustering_cost.
* @p FlagT is deduced from @p done_flag (default `int`); MG callers pass
* `int64_t` for NCCL allreduce compatibility.
*/
template <typename DataT>
template <typename DataT, typename FlagT = int>
__device__ void check_convergence(raft::device_scalar_view<const DataT> clustering_cost,
raft::device_scalar_view<DataT> prior_clustering_cost,
raft::device_scalar_view<const DataT> sqrd_norm_error,
DataT tol,
int n_iter,
raft::device_scalar_view<int> done_flag)
raft::device_scalar_view<FlagT> done_flag)
{
DataT cur_cost = *clustering_cost.data_handle();
DataT norm_err = *sqrd_norm_error.data_handle();
int done = 0;
FlagT done = FlagT{0};

if (cur_cost != DataT{0} && n_iter > 1) {
DataT delta = cur_cost / *prior_clustering_cost.data_handle();
if (delta > DataT{1} - tol) done = 1;
if (delta > DataT{1} - tol) done = FlagT{1};
}
if (norm_err < tol) done = 1;
if (norm_err < tol) done = FlagT{1};

*prior_clustering_cost.data_handle() = cur_cost;
*done_flag.data_handle() = done;
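Restated outside of device code, the stopping rule implemented by the hunk above is small; the sketch below is a host-side paraphrase for readability (names are illustrative), not the `__device__` function itself. In the multi-GPU path `FlagT` is instantiated as `int64_t` so the flag can be combined across ranks with an NCCL allreduce, as the new doc comment states.

```cpp
// Host-side paraphrase of the convergence test (illustrative; the actual
// implementation is the __device__ check_convergence in kmeans_common.cuh).
template <typename DataT>
bool kmeans_converged(DataT cur_cost,
                      DataT& prior_cost,      // advanced to cur_cost on return
                      DataT sqrd_norm_error,  // squared centroid shift
                      DataT tol,
                      int n_iter)
{
  bool done = false;
  // Cost-ratio test: stop once the relative improvement in clustering cost
  // drops below tol (skipped on the first iteration or when the cost is 0).
  if (cur_cost != DataT{0} && n_iter > 1) {
    DataT delta = cur_cost / prior_cost;
    if (delta > DataT{1} - tol) { done = true; }
  }
  // Centroid-shift test: stop when the centroids have essentially stopped moving.
  if (sqrd_norm_error < tol) { done = true; }
  prior_cost = cur_cost;
  return done;
}
```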
36 changes: 16 additions & 20 deletions cpp/src/cluster/detail/kmeans_mg.cuh
@@ -463,35 +463,31 @@ void checkWeights(const raft::resources& handle,
raft::device_vector_view<DataT, IndexT> weight)
{
cudaStream_t stream = raft::resource::get_cuda_stream(handle);
rmm::device_scalar<DataT> wt_aggr(stream);
auto d_wt_sum = raft::make_device_scalar<DataT>(handle, DataT{0});

const auto& comm = raft::resource::get_comms(handle);

auto n_samples = weight.extent(0);
raft::linalg::mapThenSumReduce(
wt_aggr.data(), n_samples, raft::identity_op{}, stream, weight.data_handle());
d_wt_sum.data_handle(), n_samples, raft::identity_op{}, stream, weight.data_handle());

comm.allreduce<DataT>(wt_aggr.data(), // sendbuff
wt_aggr.data(), // recvbuff
1, // count
comm.allreduce<DataT>(d_wt_sum.data_handle(), // sendbuff
d_wt_sum.data_handle(), // recvbuff
1, // count
raft::comms::op_t::SUM,
stream);
DataT wt_sum = wt_aggr.value(stream);
raft::resource::sync_stream(handle, stream);
RAFT_EXPECTS(wt_sum > DataT{0}, "invalid parameter (sum of sample weights must be positive)");

if (wt_sum != n_samples) {
CUVS_LOG_KMEANS(handle,
"[Warning!] KMeans: normalizing the user provided sample weights to "
"sum up to %d samples",
n_samples);

raft::linalg::map(handle,
weight,
raft::compose_op(raft::mul_const_op<DataT>{static_cast<DataT>(n_samples)},
raft::div_const_op<DataT>{wt_sum}),
raft::make_const_mdspan(weight));
}
// Normalize weights so they sum to n_samples (per rank). Reading the sum from
// a device pointer avoids a host copy / stream sync. When the sum already
// equals n_samples this is a numerical no-op (matches single-GPU behavior).
const DataT* d_wt_sum_ptr = d_wt_sum.data_handle();
raft::linalg::map(
handle,
weight,
[n_samples, d_wt_sum_ptr] __device__(DataT w) {
return w * static_cast<DataT>(n_samples) / *d_wt_sum_ptr;
},
raft::make_const_mdspan(weight));
}

template <typename DataT, typename IndexT>
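The rewritten `checkWeights` boils down to one rescaling: after the per-rank weight sums are allreduced, each weight is multiplied by `n_samples / wt_sum` so the weights sum to `n_samples` (per rank), with the scale factor read from device memory inside `raft::linalg::map` to avoid a host copy and stream sync. The snippet below is a host-side sketch of that rescaling only (function and variable names are illustrative), not the device code in the diff.

```cpp
#include <cstddef>
#include <vector>

// Host-side sketch of the per-rank weight normalization (illustrative only).
// global_wt_sum is the sum of sample weights after the allreduce across ranks.
template <typename DataT>
void normalize_sample_weights(std::vector<DataT>& weight,
                              std::size_t n_samples,
                              DataT global_wt_sum)
{
  // When global_wt_sum already equals n_samples this is a numerical no-op,
  // matching the single-GPU behavior noted in the diff comment.
  const DataT scale = static_cast<DataT>(n_samples) / global_wt_sum;
  for (auto& w : weight) { w *= scale; }
}
```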