Multi-GPU Batched KMeans #2017
Conversation
Here are some instructions to test the Multi-GPU Batched KMeans API with RAFT comms (to be used with Ray/Dask):
- RAFT comms (Ray/Dask) demo code
- Compilation command
- Launch command
```cpp
/*
 * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION.
 * SPDX-License-Identifier: Apache-2.0
 */
```
Can we get rid of this file entirely and combine with regular mg kmeans (just as we are doing in PR #2015)? Is that possible?
Also, MNMG should be able to reuse the snmg_fit function (for a single worker) as is, right? Except that the nccl reduce macro will be replaced by something like comms.allreduce()
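For concreteness, here is a minimal sketch of the substitution being suggested (not code from this PR; the buffer names are illustrative, and the resource accessor is my assumption): the raw NCCL reduce in `snmg_fit` becomes a comms call, so the same Lloyd's step works for both SNMG and MNMG.

```cpp
#include <raft/core/comms.hpp>
#include <raft/core/resource/comms.hpp>

// Sketch only: `centroid_sums` / `cluster_counts` stand in for the per-rank
// accumulators produced by one Lloyd's iteration.
auto const& comm = raft::resource::get_comms(handle);

// Sum per-cluster centroid accumulators and counts across all ranks,
// in place (NCCL permits sendbuff == recvbuff).
comm.allreduce(centroid_sums.data_handle(), centroid_sums.data_handle(),
               centroid_sums.size(), raft::comms::op_t::SUM, stream);
comm.allreduce(cluster_counts.data_handle(), cluster_counts.data_handle(),
               cluster_counts.size(), raft::comms::op_t::SUM, stream);
comm.sync_stream(stream);  // reduced values are visible before the centroid update
```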
I refactored the code to work flawlessly with the single GPU refactor. The process_batch function is now being used. The weight normalization is improved (rel_tol and zero/invalid check). The inertia_check field is not used anymore.
However, I do not feel confident implementing a massive unifying refactor in this PR that is originally dedicated to introducing MG Batched KMeans. Would it be fine to leave things as is for now? We could come back to it in a dedicated follow-up PR.
If we can get this and the refactor into 26.06, it's OK to break it down. We'll need to track the MG refactor in an issue, and that would be a high priority.
Hi @viclafargue @tarang-jain, thanks for the excellent work on this PR! We have a large-scale clustering pipeline that requires multi-GPU batched KMeans. Could you share a rough timeline for when this might be ready to merge? We'd be happy to be among the first adopters once this is in a testable state. Thanks again for pushing this forward!
Hi @zhilizju, This should be merged by early next week, but just know that we are in the process of doing large-scale benchmarks, and those will likely not be complete before merging (so mileage may vary). This will be part of our 26.06 release, which will be officially released in early June. Will you be integrating this through C++ or another one of the language wrappers?
@tarang-jain it's usually safe to assume the data is randomly scattered. Kmeans++ is also expensive enough, especially with larger k, that I think doing the initialization on a single partition is going to be a good trade off most of the time. I'd even go as far as to say that we could offer an option for this in the mnmg version, and default it to preferring a single shard. Think of it like taking a simple random sample of a simple random sample (aka a random sample from the entire training set). |
yeah I think that makes sense -- random sample of a random sample. |
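As a tiny illustration of that trade-off (sketch only; all names are hypothetical), drawing the KMeans++ candidates from one randomly assigned shard is itself a simple random sample of the full training set:

```cpp
#include <algorithm>
#include <cstdint>
#include <iterator>
#include <random>
#include <vector>

// A simple random sample of a shard that is itself a random split of the
// training set behaves like a simple random sample of the whole set.
std::vector<std::int64_t> sample_init_ids(const std::vector<std::int64_t>& shard_row_ids,
                                          std::size_t init_size,
                                          std::mt19937_64& rng)
{
  std::vector<std::int64_t> out;
  out.reserve(init_size);
  // std::sample draws without replacement; shard_row_ids is assumed to be a
  // random partition of the global row IDs.
  std::sample(shard_row_ids.begin(), shard_row_ids.end(),
              std::back_inserter(out), init_size, rng);
  return out;
}
```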
Thanks @cjnolet! Great to hear it's targeting early next week. We're already running the RAFT comms + Ray path (via cuML 25.10 `_fit(multigpu=True)`) with NCCL across hundreds of GPUs. We'll primarily use the Python wrapper. A couple of questions:
c142245 to 41c66b8 (Compare)
There was a problem hiding this comment.
Caution: Some comments are outside the diff and can't be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
cpp/tests/cluster/kmeans.cu (2)
552-576: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

These regressions still bypass the new host/batched path.

`fitKMeansPlusPlus()` and `runZeroCost()` both call `fit()` with `d_X->view()`, so CI is only exercising the existing device-resident overload here. That leaves `n_init > 1` and the zero-cost edge case unvalidated for the host-streaming implementation this PR adds.

Also applies to: 617-624
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@cpp/tests/cluster/kmeans.cu` around lines 552 - 576, Tests are currently calling the device-resident overload by passing raft::make_const_mdspan(d_X->view()) from fitKMeansPlusPlus() (and similarly in runZeroCost()), so CI never exercises the new host/batched fit path; change those calls to pass a host-resident/batched view instead (e.g., create or reuse a host matrix/mdspan for X and call cuvs::cluster::kmeans::fit with raft::make_const_mdspan(h_X->view()) or the library’s host-mdspan helper) so the host-streaming overload is invoked (apply same change for the similar call at lines 617-624) and run with n_init>1 and the zero-cost scenario to validate the new path.
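A hedged sketch of the change this comment asks for (the exact `fit()` overload and argument order are assumptions; the host overload is the one this PR introduces):

```cpp
// Stage X on host so the new host/batched overload is exercised instead of
// the device-resident one.
auto h_X = raft::make_host_matrix<T, IdxT>(n_samples, n_features);
raft::copy(h_X.data_handle(), d_X->data_handle(),
           static_cast<std::size_t>(n_samples) * n_features,
           raft::resource::get_cuda_stream(handle));
raft::resource::sync_stream(handle);

cuvs::cluster::kmeans::fit(handle, params,
                           raft::make_const_mdspan(h_X.view()),
                           std::nullopt,                      // sample weights
                           d_centroids.view(),
                           raft::make_host_scalar_view(&inertia),
                           raft::make_host_scalar_view(&n_iter));
```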
524-549: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Don't assert that larger `init_size` must improve inertia.

For a single `KMeans++` trial (`n_init == 1`), increasing the candidate set from the default subsample to `n_samples` is not a monotonic guarantee. `ASSERT_LE(inertia_full, inertia_default * (1 + rel))` can fail even when the implementation is correct, so this introduces avoidable test flakiness.

Possible adjustment:

```diff
- // Full-dataset seeding has at least as much information as the subsample
- // default, so the converged inertia should not be worse.
- ASSERT_LE(inertia_full, inertia_default * (T(1) + rel));
+ // `init_size = 0` should resolve to the documented default.
+ // Avoid asserting against `inertia_full`: a larger candidate set is not a
+ // monotonic guarantee for a single KMeans++ run.
```

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@cpp/tests/cluster/kmeans.cu` around lines 524 - 549, The test runInitSizeCompare asserts that inertia_full <= inertia_default*(1+rel), which is flaky because a single KMeans++ trial (n_init==1) does not guarantee better inertia with larger init_size; remove that ASSERT_LE(inertia_full, inertia_default * (T(1) + rel)) or make it conditional on a deterministic multi-init case (e.g., only check when n_init>1), keeping the other assertions on finiteness and positivity of inertia_default, inertia_explicit, and inertia_full unchanged; locate the check in runInitSizeCompare where inertia_full and inertia_default are compared and remove or guard that comparison.
🧹 Nitpick comments (1)
cpp/tests/cluster/kmeans.cu (1)
362-365: ⚡ Quick win

Weighted batched coverage only hits the all-ones case.

Every weighted batched test fills `kmeans_weight_mode::uniform`, so the suite still doesn't exercise zero-sum or sparse-weight inputs. Given the recent weight-handling changes, I'd add at least one host-data case with all-zero weights and one mixed zero/non-zero case to lock down that behavior.
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@cpp/tests/cluster/kmeans.cu` around lines 362 - 365, The weighted-batched tests currently only use fill_kmeans_test_weights(..., kmeans_weight_mode::uniform) so they never exercise zero-sum or sparse weights; update the test generation in the block guarded by testparams.weighted to add at least two additional weight scenarios: one where you create host-side all-zero weights and one mixed zero/non-zero (sparse) weights, transfer them into d_sample_weight (same device vector used now) and run the same batched KMeans paths; use or extend fill_kmeans_test_weights (or a new helper) to produce kmeans_weight_mode variants (e.g., zero/sparse) and ensure the test loop iterates over these modes so d_sample_weight, the KMeans call, and any assertions are exercised for zero-sum and mixed-weight cases.
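A rough sketch of the extra coverage being requested (the new `kmeans_weight_mode` enumerators and the helper signature are hypothetical; only `uniform` appears in the current tests):

```cpp
// Sketch: iterate over weight scenarios instead of only the all-ones case.
for (auto mode : {kmeans_weight_mode::uniform,
                  kmeans_weight_mode::all_zero,   // hypothetical: every weight == 0
                  kmeans_weight_mode::sparse}) {  // hypothetical: mixed zero / non-zero
  fill_kmeans_test_weights(h_sample_weight, mode);
  raft::copy(d_sample_weight.data_handle(), h_sample_weight.data(),
             h_sample_weight.size(), stream);
  // ... run the same batched KMeans fit and assertions as the uniform case ...
}
```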
📒 Files selected for processing (4)

- cpp/src/cluster/detail/kmeans_mg_batched.cuh
- cpp/tests/cluster/kmeans.cu
- cpp/tests/cluster/kmeans_mg_batched.cu
- cpp/tests/cluster/kmeans_test_blobs.cuh
✅ Files skipped from review due to trivial changes (1)
- cpp/tests/cluster/kmeans_test_blobs.cuh
🚧 Files skipped from review as they are similar to previous changes (2)
- cpp/tests/cluster/kmeans_mg_batched.cu
- cpp/src/cluster/detail/kmeans_mg_batched.cuh
Hi @zhilizju,
The cuVS nightly package will be available following the merge. However, it will only offer a C++ interface.
The plan so far is to have the Python wrapper live in cuML, in a similar fashion to the regular MG KMeans.
/ok to test 41c66b8

/ok to test f664c2c
jinsolp left a comment
Thanks @viclafargue ! No major concerns from my side
```cpp
if (rank == 0) {
  cuvs::cluster::kmeans::detail::init_centroids_from_host_sample<T, IdxT>(
    dev_res, iter_params, streaming_batch_size, X_local, rank_centroids.view(), workspace);
```
Can we document that we require a well-shuffled dataset to run MG kmeans? Otherwise this will end up doing a bad init, because it only uses rank 0's local data.
We should definitely go one step further here and add an option to switch between "use kmeans++ globally" and "use kmeans++ locally".
The issue of the bias introduced by single-rank sampling is now solved.
This keeps the implementation simple and memory bounded: the full input remains on host, and only the initialization sample is materialized on device. After the global sample is built, there is no additional multi-rank communication during KMeans++ initialization. Using true multi-GPU KMeans++ for very large
Given the added review surface, I think the current sampled-root initialization is the right approach for this PR. True distributed KMeans++ initialization can be tracked separately if we see workloads where |
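For concreteness, a rough sketch of the sampled-root flow described above (all names hypothetical; the comms gather and the device copy are elided):

```cpp
// Each rank owns a contiguous block of global row IDs starting at row_offset.
// It extracts the sampled rows it owns from host memory; rank 0 then gathers
// every rank's pieces and runs KMeans++ on the device-resident sample only,
// so device memory is bounded by init_size rather than the full dataset.
std::vector<T> local_sample_rows;
for (IdxT gid : sampled_global_ids) {
  if (gid >= row_offset && gid < row_offset + n_local_rows) {
    const T* src = h_X_local + static_cast<std::size_t>(gid - row_offset) * n_features;
    local_sample_rows.insert(local_sample_rows.end(), src, src + n_features);
  }
}
// comms gather to rank 0, then copy the assembled sample to device for init.
```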
/ok to test a14a6bc |
tarang-jain left a comment
This idea of global samples seems sound to me. I think the underlying assumption of contiguous sample IDs across ranks holds. We should definitely benchmark such an initialization, since the creation of the init sample is done on host, one sample at a time?
```cpp
{
  IndexT default_init_size = std::min(static_cast<IndexT>(std::int64_t{3} * n_clusters), global_n);
  IndexT init_sample_size  = params.init_size > 0
                               ? std::min(static_cast<IndexT>(params.init_size), global_n)
```
Question: the same init_size will be available to all the ranks, right?
```cpp
std::unordered_set<IndexT> selected;
selected.reserve(static_cast<std::size_t>(sample_size));

for (IndexT j = n_samples - sample_size; j < n_samples; ++j) {
```
why not do this on device? I see below you are copying the sampled IDs to device anyway on rank 0.
In the kmeans++ init, I think there is one more assumption -- which is that the init sample fits on rank 0. We should document this, otherwise the user might get a surprise OOM (for example, an edge case can be when rank 0 has less available GPU RAM).
Closes #1989.
Adds multi-GPU support to KMeans fit for host-resident data, with two modes:
- single-process multi-GPU, using `device_resources_snmg`
- multi-node multi-GPU, using RAFT comms (NCCL), e.g. with Ray/Dask

Both modes share the same core Lloyd's loop: batched streaming of host data, NCCL/comms allreduce of centroid sums and counts, and synchronized convergence. Supports sample weights, `n_init` best-of-N restarts, KMeansPlusPlus initialization, and float/double. Falls back to single-GPU when neither multi-GPU resources nor comms are present.
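For illustration, a hedged usage sketch of the single-process multi-GPU mode (parameter names and the exact overload are assumptions based on this thread, not a verbatim API listing):

```cpp
#include <cuvs/cluster/kmeans.hpp>
#include <raft/core/device_resources_snmg.hpp>

raft::device_resources_snmg res;  // one process driving all visible GPUs

cuvs::cluster::kmeans::params params;
params.n_clusters = 1024;
params.n_init     = 3;  // best-of-N restarts

// Host-resident input selects the batched multi-GPU path; with a plain
// raft::resources and no comms, the same call falls back to single-GPU.
auto h_X       = raft::make_host_matrix<float, int64_t>(n_rows, n_cols);
auto centroids = raft::make_device_matrix<float, int64_t>(res, params.n_clusters, n_cols);
float   inertia = 0;
int64_t n_iter  = 0;

cuvs::cluster::kmeans::fit(res, params,
                           raft::make_const_mdspan(h_X.view()),
                           std::nullopt,  // optional sample weights
                           centroids.view(),
                           raft::make_host_scalar_view(&inertia),
                           raft::make_host_scalar_view(&n_iter));
```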