
Multi-GPU Batched KMeans #2017

Open
viclafargue wants to merge 85 commits into rapidsai:main from viclafargue:mg-batched-kmeans

Conversation

@viclafargue
Contributor

Closes #1989.

Adds multi-GPU support to KMeans fit for host-resident data, with two modes:

  • OpenMP (cuVS SNMG): A single process drives all local GPUs via OMP threads and raw NCCL. Activated automatically when the handle is a device_resources_snmg.
  • RAFT comms (Ray / Dask / MPI): Each rank is a separate process that calls fit with its own data shard and an initialized RAFT communicator. Coordination uses the RAFT comms.

Both modes share the same core Lloyd's loop, batched streaming of host data, NCCL/comms allreduce of centroid sums and counts, and synchronized convergence. Supports sample weights, n_init best-of-N restarts, KMeansPlusPlus initialization, and float/double. Falls back to single-GPU when neither multi-GPU resources nor comms are present.
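
For the SNMG path, a minimal sketch of driving the fit from a single process, assuming (as described above) that fit() dispatches to the multi-GPU path when handed a raft::device_resources_snmg handle; shapes and parameter values are illustrative, not the PR's exact API:

#include <cuvs/cluster/kmeans.hpp>

#include <raft/core/device_mdarray.hpp>
#include <raft/core/device_resources_snmg.hpp>
#include <raft/core/host_mdspan.hpp>

#include <cstdint>
#include <optional>

void fit_snmg_example(const float* h_X, std::int64_t n_rows, std::int64_t n_cols)
{
  // One process drives all local GPUs; no comms setup is needed in this mode.
  raft::device_resources_snmg handle;

  constexpr int n_clusters = 10;
  auto X = raft::make_host_matrix_view<const float, std::int64_t>(h_X, n_rows, n_cols);
  auto centroids =
    raft::make_device_matrix<float, std::int64_t>(handle, n_clusters, n_cols);

  cuvs::cluster::kmeans::params params;
  params.n_clusters = n_clusters;

  float inertia       = 0.0f;
  std::int64_t n_iter = 0;
  cuvs::cluster::kmeans::fit(handle,
                             params,
                             X,
                             std::nullopt,  // no sample weights
                             centroids.view(),
                             raft::make_host_scalar_view(&inertia),
                             raft::make_host_scalar_view(&n_iter));
}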

@viclafargue viclafargue self-assigned this Apr 13, 2026
@viclafargue viclafargue requested review from a team as code owners April 13, 2026 14:34
@viclafargue viclafargue added the improvement (Improves an existing functionality) and non-breaking (Introduces a non-breaking change) labels
@viclafargue
Contributor Author

Here are some instructions to test the multi-GPU batched KMeans API with RAFT comms (to be used with Ray/Dask):

RAFT comms (Ray/Dask) demo code
#include <cuvs/cluster/kmeans.hpp>

#include <raft/comms/std_comms.hpp>
#include <raft/core/device_mdarray.hpp>
#include <raft/core/host_mdspan.hpp>
#include <raft/core/resource/comms.hpp>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/resources.hpp>

#include <cuda_runtime.h>
#include <mpi.h>
#include <nccl.h>

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <numeric>
#include <optional>  // std::nullopt for the optional sample-weight arguments
#include <random>
#include <vector>

#define CHECK_CUDA(call)                                                 \
  do {                                                                   \
    cudaError_t e = (call);                                              \
    if (e != cudaSuccess) {                                              \
      std::fprintf(stderr, "CUDA error %s @ %s:%d\n",                   \
                   cudaGetErrorString(e), __FILE__, __LINE__);           \
      MPI_Abort(MPI_COMM_WORLD, 1);                                      \
    }                                                                    \
  } while (0)

#define CHECK_NCCL(call)                                                 \
  do {                                                                   \
    ncclResult_t r = (call);                                             \
    if (r != ncclSuccess) {                                              \
      std::fprintf(stderr, "NCCL error %s @ %s:%d\n",                   \
                   ncclGetErrorString(r), __FILE__, __LINE__);           \
      MPI_Abort(MPI_COMM_WORLD, 1);                                      \
    }                                                                    \
  } while (0)

int main(int argc, char** argv)
{
  MPI_Init(&argc, &argv);

  int rank, num_ranks;
  MPI_Comm_rank(MPI_COMM_WORLD, &rank);
  MPI_Comm_size(MPI_COMM_WORLD, &num_ranks);

  // Assumes a single-node launch with one GPU per rank (rank == device id).
  CHECK_CUDA(cudaSetDevice(rank));

  ncclUniqueId nccl_id;
  if (rank == 0) CHECK_NCCL(ncclGetUniqueId(&nccl_id));
  MPI_Bcast(&nccl_id, sizeof(nccl_id), MPI_BYTE, 0, MPI_COMM_WORLD);

  ncclComm_t nccl_comm;
  CHECK_NCCL(ncclCommInitRank(&nccl_comm, num_ranks, nccl_id, rank));

  raft::resources handle;
  raft::comms::build_comms_nccl_only(&handle, nccl_comm, num_ranks, rank);

  // --- Demo parameters ---
  constexpr int64_t n_samples       = 100'000;
  constexpr int64_t n_features      = 32;
  constexpr int     n_clusters      = 10;
  constexpr int64_t streaming_batch = 10'000;
  constexpr float   cluster_spread  = 1.0f;
  constexpr float   center_range    = 30.0f;

  if (rank == 0) {
    std::printf("=== Multi-GPU KMeans Demo (%d ranks) ===\n", num_ranks);
    std::printf("Samples: %ld | Features: %ld | k: %d | batch: %ld\n\n",
                long(n_samples), long(n_features), n_clusters, long(streaming_batch));
  }

  // Generate synthetic blobs with well-separated cluster centers
  std::vector<float> h_data(n_samples * n_features);
  std::vector<int>   h_true_labels(n_samples);
  std::vector<float> cluster_centers(n_clusters * n_features);
  {
    std::mt19937 gen(12345);
    std::uniform_real_distribution<float> center_dist(-center_range, center_range);
    std::normal_distribution<float> noise(0.0f, cluster_spread);

    for (int c = 0; c < n_clusters; ++c)
      for (int d = 0; d < n_features; ++d)
        cluster_centers[c * n_features + d] = center_dist(gen);

    for (int64_t i = 0; i < n_samples; ++i) {
      int label = static_cast<int>(i % n_clusters);
      h_true_labels[i] = label;
      for (int d = 0; d < n_features; ++d)
        h_data[i * n_features + d] = cluster_centers[label * n_features + d] + noise(gen);
    }

    // Shuffle so labels aren't just sequential runs
    std::vector<int64_t> perm(n_samples);
    std::iota(perm.begin(), perm.end(), 0);
    std::shuffle(perm.begin(), perm.end(), gen);

    std::vector<float> tmp_data(h_data);
    std::vector<int>   tmp_labels(h_true_labels);
    for (int64_t i = 0; i < n_samples; ++i) {
      std::memcpy(h_data.data() + i * n_features,
                  tmp_data.data() + perm[i] * n_features,
                  n_features * sizeof(float));
      h_true_labels[i] = tmp_labels[perm[i]];
    }
  }

  int64_t base    = n_samples / num_ranks;
  int64_t rem     = n_samples % num_ranks;
  int64_t offset  = rank * base + std::min<int64_t>(rank, rem);
  int64_t n_local = base + (rank < rem ? 1 : 0);

  std::printf("[rank %d / GPU %d]  rows [%ld .. %ld)  (%ld samples)\n",
              rank, rank, long(offset), long(offset + n_local), long(n_local));

  auto X_local = raft::make_host_matrix_view<const float, int64_t>(
    h_data.data() + offset * n_features, n_local, n_features);

  auto d_centroids = raft::make_device_matrix<float, int64_t>(handle, n_clusters, n_features);

  cuvs::cluster::kmeans::params params;
  params.n_clusters           = n_clusters;
  params.max_iter             = 50;
  params.tol                  = 1e-4;
  params.init                 = cuvs::cluster::kmeans::params::KMeansPlusPlus;
  params.rng_state.seed       = 42;
  params.inertia_check        = true;
  params.streaming_batch_size = streaming_batch;

  float   inertia = 0.0f;
  int64_t n_iter  = 0;

  cuvs::cluster::kmeans::fit(handle,
                             params,
                             X_local,
                             std::nullopt,
                             d_centroids.view(),
                             raft::make_host_scalar_view(&inertia),
                             raft::make_host_scalar_view(&n_iter));

  auto stream = raft::resource::get_cuda_stream(handle);
  CHECK_CUDA(cudaStreamSynchronize(stream));

  if (rank == 0) {
    // --- Predict labels on the full dataset (on rank 0) ---
    auto d_X = raft::make_device_matrix<float, int64_t>(handle, n_samples, n_features);
    CHECK_CUDA(cudaMemcpy(d_X.data_handle(), h_data.data(),
                          sizeof(float) * n_samples * n_features, cudaMemcpyHostToDevice));

    auto d_labels = raft::make_device_vector<int64_t, int64_t>(handle, n_samples);
    float predict_inertia = 0.0f;

    cuvs::cluster::kmeans::predict(
      handle, params,
      raft::make_device_matrix_view<const float, int64_t>(d_X.data_handle(), n_samples, n_features),
      std::nullopt,
      raft::make_device_matrix_view<const float, int64_t>(
        d_centroids.data_handle(), n_clusters, n_features),
      d_labels.view(),
      false,
      raft::make_host_scalar_view(&predict_inertia));
    CHECK_CUDA(cudaStreamSynchronize(stream));

    std::vector<int64_t> h_labels(n_samples);
    CHECK_CUDA(cudaMemcpy(h_labels.data(), d_labels.data_handle(),
                          sizeof(int64_t) * n_samples, cudaMemcpyDeviceToHost));

    // --- Quality: permutation-invariant accuracy via majority voting ---
    // For each predicted cluster, find which true label appears most often.
    std::vector<std::vector<int64_t>> confusion(n_clusters, std::vector<int64_t>(n_clusters, 0));
    for (int64_t i = 0; i < n_samples; ++i)
      confusion[h_labels[i]][h_true_labels[i]]++;

    // Greedy matching: assign each predicted cluster to its dominant true label
    std::vector<int> pred_to_true(n_clusters, -1);
    std::vector<bool> true_taken(n_clusters, false);
    for (int round = 0; round < n_clusters; ++round) {
      int64_t best_count = -1;
      int best_pred = -1, best_true = -1;
      for (int p = 0; p < n_clusters; ++p) {
        if (pred_to_true[p] >= 0) continue;
        for (int t = 0; t < n_clusters; ++t) {
          if (true_taken[t]) continue;
          if (confusion[p][t] > best_count) {
            best_count = confusion[p][t];
            best_pred = p;
            best_true = t;
          }
        }
      }
      pred_to_true[best_pred] = best_true;
      true_taken[best_true] = true;
    }

    int64_t correct = 0;
    std::vector<int64_t> cluster_sizes(n_clusters, 0);
    std::vector<int64_t> cluster_correct(n_clusters, 0);
    for (int64_t i = 0; i < n_samples; ++i) {
      int p = static_cast<int>(h_labels[i]);
      cluster_sizes[p]++;
      if (h_true_labels[i] == pred_to_true[p]) {
        ++correct;
        ++cluster_correct[p];
      }
    }
    double accuracy = 100.0 * correct / n_samples;

    // --- Compute centroid-to-true-center distances ---
    std::vector<float> h_centroids(n_clusters * n_features);
    CHECK_CUDA(cudaMemcpy(h_centroids.data(), d_centroids.data_handle(),
                          sizeof(float) * n_clusters * n_features, cudaMemcpyDeviceToHost));

    std::printf("\n============ Multi-GPU KMeans Results ============\n");
    std::printf("  Ranks             : %d\n", num_ranks);
    std::printf("  Total samples     : %ld\n", long(n_samples));
    std::printf("  Features          : %ld\n", long(n_features));
    std::printf("  Clusters (k)      : %d\n", n_clusters);
    std::printf("  Streaming batch   : %ld\n", long(streaming_batch));
    std::printf("  Lloyd iterations  : %ld\n", long(n_iter));
    std::printf("  Final inertia     : %.6f\n", double(inertia));
    std::printf("  Predict inertia   : %.6f\n", double(predict_inertia));
    std::printf("\n  --- Clustering Quality ---\n");
    std::printf("  Overall accuracy  : %.2f%% (%ld / %ld)\n",
                accuracy, long(correct), long(n_samples));

    std::printf("\n  Per-cluster breakdown:\n");
    std::printf("  %6s  %10s  %10s  %8s  %12s\n",
                "Pred", "TrueLabel", "Size", "Acc%", "CentroidErr");
    for (int p = 0; p < n_clusters; ++p) {
      int t = pred_to_true[p];
      double pct = cluster_sizes[p] > 0
                     ? 100.0 * cluster_correct[p] / cluster_sizes[p]
                     : 0.0;

      // L2 distance between learned centroid and ground truth center
      double dist2 = 0.0;
      for (int d = 0; d < n_features; ++d) {
        double diff = h_centroids[p * n_features + d] - cluster_centers[t * n_features + d];
        dist2 += diff * diff;
      }
      std::printf("  %6d  %10d  %10ld  %7.2f%%  %12.4f\n",
                  p, t, long(cluster_sizes[p]), pct, std::sqrt(dist2));
    }

    std::printf("\n  Expected accuracy for well-separated blobs: >99%%\n");
    if (accuracy >= 99.0)
      std::printf("  PASS: Clustering quality is high.\n");
    else if (accuracy >= 90.0)
      std::printf("  WARN: Clustering quality is acceptable but not ideal.\n");
    else
      std::printf("  FAIL: Clustering quality is poor!\n");

    std::printf("==================================================\n");
  }

  CHECK_NCCL(ncclCommDestroy(nccl_comm));
  MPI_Finalize();
  return 0;
}
Compilation command
nvcc -std=c++17 -x cu --extended-lambda -arch=native \
  -I$CONDA_PREFIX/include/rapids \
  -I$CONDA_PREFIX/include \
  demo_mg_kmeans_raft_comms.cu \
  -L$CONDA_PREFIX/lib -lcuvs -lnccl -lrmm -lmpi \
  -lucxx -lucp -lucs \
  -Xlinker=-rpath,$CONDA_PREFIX/lib \
  -o demo_mg_kmeans
Launch command

mpirun -np 2 ./demo_mg_kmeans

@viclafargue viclafargue requested a review from tarang-jain April 13, 2026 14:42
Comment thread on cpp/src/cluster/detail/kmeans_mg_batched.cuh:
/*
* SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION.
* SPDX-License-Identifier: Apache-2.0
*/
Contributor

Can we get rid of this file entirely and combine with regular mg kmeans (just as we are doing in PR #2015)? Is that possible?

Contributor

Also, MNMG should be able to reuse the snmg_fit function (for a single worker) as is, right? Except that the NCCL reduce macro would be replaced by something like comms.allreduce().
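
For concreteness, a hypothetical sketch of that substitution, assuming a RAFT comms_t has been attached to the handle (the helper and buffer names are illustrative, not the PR's code):

#include <raft/core/resource/comms.hpp>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/resources.hpp>

// Hypothetical replacement for the raw-NCCL reduce macro: allreduce the
// per-rank centroid accumulators through the RAFT communicator.
template <typename T, typename IdxT>
void allreduce_centroid_accumulators(raft::resources const& handle,
                                     T* d_sums,    // n_clusters x n_features
                                     T* d_counts,  // n_clusters
                                     IdxT n_clusters,
                                     IdxT n_features)
{
  auto const& comm = raft::resource::get_comms(handle);
  auto stream      = raft::resource::get_cuda_stream(handle);
  // In-place sum across ranks; every rank ends up with the global totals.
  comm.allreduce(d_sums, d_sums, n_clusters * n_features, raft::comms::op_t::SUM, stream);
  comm.allreduce(d_counts, d_counts, n_clusters, raft::comms::op_t::SUM, stream);
  comm.sync_stream(stream);
}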

Contributor Author

I refactored the code to work flawlessly with the single-GPU refactor. The process_batch function is now being used. The weight normalization is improved (rel_tol and zero/invalid checks). The inertia_check field is no longer used.

However, I do not feel confident implementing a massive unifying refactor in this PR, which is originally dedicated to introducing MG batched KMeans. Would it be fine to leave things as they are for now? We could come back to it in a dedicated follow-up PR.

Contributor

If we can get this and the refactor into 26.06, it's okay to break it down. We'll need to track the MG refactor in an issue, and that would be a high priority.

@zhilizju

zhilizju commented May 8, 2026

Hi @viclafargue @tarang-jain, thanks for the excellent work on this PR!

We have a large-scale clustering pipeline that requires multi-GPU batched KMeans — we're running on hundreds of nodes and need this capability to handle our dataset scale. We're eagerly waiting to adopt this feature once it lands.

Could you share a rough timeline for when this might be ready to merge? Do you think the remaining items (rebase on #2015, the KMeans++ init concern, etc.) could be wrapped up by this weekend, or would it need another week or so? We'd like to plan our integration work accordingly.

We'd be happy to be among the first adopters — once this is in a testable state, we can run it against our production workload and provide detailed feedback on correctness, scalability, and API ergonomics. Let us know if there's a branch or pre-release build we can start testing against.

Thanks again for pushing this forward!

@cjnolet
Member

cjnolet commented May 8, 2026

Hi @zhilizju,

This should be merged by early next week, but just know that we are in the process of doing large-scale benchmarks, so those will likely not be complete before merging (so your mileage may vary).

This will be part of our 26.06 release, which will be officially released in early June.

Will you be integrating this through C++ or another of the language wrappers?

@cjnolet
Member

cjnolet commented May 8, 2026

> The shard might not be representative of the whole dataset

@tarang-jain it's usually safe to assume the data is randomly scattered. KMeans++ is also expensive enough, especially with larger k, that I think doing the initialization on a single partition is going to be a good trade-off most of the time. I'd even go as far as to say that we could offer an option for this in the MNMG version, and default it to preferring a single shard. Think of it like taking a simple random sample of a simple random sample (aka a random sample from the entire training set).

@tarang-jain
Contributor

yeah I think that makes sense -- random sample of a random sample.
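
One way to ground that intuition (a standard sampling fact, not specific to this PR): if each rank's shard of size m is itself a uniform random subset of the n training rows, then drawing a uniform sample of size s from a single shard gives every row the same inclusion probability, (m/n) * (s/m) = s/n, which matches drawing s rows uniformly from the full dataset.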

@zhilizju

zhilizju commented May 8, 2026

> Hi @zhilizju,
>
> This should be merged by early next week, but just know that we are in the process of doing large-scale benchmarks, so those will likely not be complete before merging (so your mileage may vary).
>
> This will be part of our 26.06 release, which will be officially released in early June.
>
> Will you be integrating this through C++ or another of the language wrappers?

Thanks @cjnolet! Great to hear it's targeting early next week.

We're already running the RAFT comms + Ray path (via cuML 25.10 _fit(multigpu=True) with NCCL across hundreds of nodes), so we're very familiar with this architecture. Looking forward to the cuVS native implementation for better performance with host-resident batched streaming.

We'll primarily use the Python wrapper. A couple of questions:

  1. Once merged, will the multi-GPU batched KMeans be accessible from Python with the RAFT comms path? We'd like to integrate it into our existing Ray + NCCL setup.
  2. Is there a nightly conda package that will pick this up automatically after merge, or would we need to build from source?

@viclafargue viclafargue force-pushed the mg-batched-kmeans branch from c142245 to 41c66b8 May 8, 2026 09:14

@coderabbitai coderabbitai Bot left a comment


Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
cpp/tests/cluster/kmeans.cu (2)

552-576: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

These regressions still bypass the new host/batched path.

fitKMeansPlusPlus() and runZeroCost() both call fit() with d_X->view(), so CI is only exercising the existing device-resident overload here. That leaves n_init > 1 and the zero-cost edge case unvalidated for the host-streaming implementation this PR adds.

Also applies to: 617-624

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@cpp/tests/cluster/kmeans.cu` around lines 552 - 576, Tests are currently
calling the device-resident overload by passing
raft::make_const_mdspan(d_X->view()) from fitKMeansPlusPlus() (and similarly in
runZeroCost()), so CI never exercises the new host/batched fit path; change
those calls to pass a host-resident/batched view instead (e.g., create or reuse
a host matrix/mdspan for X and call cuvs::cluster::kmeans::fit with
raft::make_const_mdspan(h_X->view()) or the library’s host-mdspan helper) so the
host-streaming overload is invoked (apply same change for the similar call at
lines 617-624) and run with n_init>1 and the zero-cost scenario to validate the
new path.

524-549: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Don't assert that larger init_size must improve inertia.

For a single KMeans++ trial (n_init == 1), increasing the candidate set from the default subsample to n_samples is not a monotonic guarantee. ASSERT_LE(inertia_full, inertia_default * (1 + rel)) can fail even when the implementation is correct, so this introduces avoidable test flakiness.

Possible adjustment
-    // Full-dataset seeding has at least as much information as the subsample
-    // default, so the converged inertia should not be worse.
-    ASSERT_LE(inertia_full, inertia_default * (T(1) + rel));
+    // `init_size = 0` should resolve to the documented default.
+    // Avoid asserting against `inertia_full`: a larger candidate set is not a
+    // monotonic guarantee for a single KMeans++ run.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@cpp/tests/cluster/kmeans.cu` around lines 524 - 549, The test
runInitSizeCompare asserts that inertia_full <= inertia_default*(1+rel), which
is flaky because a single KMeans++ trial (n_init==1) does not guarantee better
inertia with larger init_size; remove that ASSERT_LE(inertia_full,
inertia_default * (T(1) + rel)) or make it conditional on a deterministic
multi-init case (e.g., only check when n_init>1), keeping the other assertions
on finiteness and positivity of inertia_default, inertia_explicit, and
inertia_full unchanged; locate the check in runInitSizeCompare where
inertia_full and inertia_default are compared and remove or guard that
comparison.
🧹 Nitpick comments (1)
cpp/tests/cluster/kmeans.cu (1)

362-365: ⚡ Quick win

Weighted batched coverage only hits the all-ones case.

Every weighted batched test fills kmeans_weight_mode::uniform, so the suite still doesn't exercise zero-sum or sparse-weight inputs. Given the recent weight-handling changes, I'd add at least one host-data case with all-zero weights and one mixed zero/non-zero case to lock down that behavior.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@cpp/tests/cluster/kmeans.cu` around lines 362 - 365, The weighted-batched
tests currently only use fill_kmeans_test_weights(...,
kmeans_weight_mode::uniform) so they never exercise zero-sum or sparse weights;
update the test generation in the block guarded by testparams.weighted to add at
least two additional weight scenarios: one where you create host-side all-zero
weights and one mixed zero/non-zero (sparse) weights, transfer them into
d_sample_weight (same device vector used now) and run the same batched KMeans
paths; use or extend fill_kmeans_test_weights (or a new helper) to produce
kmeans_weight_mode variants (e.g., zero/sparse) and ensure the test loop
iterates over these modes so d_sample_weight, the KMeans call, and any
assertions are exercised for zero-sum and mixed-weight cases.


📥 Commits

Reviewing files that changed from the base of the PR and between add9db1 and 41c66b8.

📒 Files selected for processing (4)
  • cpp/src/cluster/detail/kmeans_mg_batched.cuh
  • cpp/tests/cluster/kmeans.cu
  • cpp/tests/cluster/kmeans_mg_batched.cu
  • cpp/tests/cluster/kmeans_test_blobs.cuh
✅ Files skipped from review due to trivial changes (1)
  • cpp/tests/cluster/kmeans_test_blobs.cuh
🚧 Files skipped from review as they are similar to previous changes (2)
  • cpp/tests/cluster/kmeans_mg_batched.cu
  • cpp/src/cluster/detail/kmeans_mg_batched.cuh

@viclafargue
Contributor Author

Hi @zhilizju,

> Is there a nightly conda package that will pick this up automatically after merge, or would we need to build from source?

The cuVS nightly package will be available following the merge. However, it will only offer a C++ interface.

> Once merged, will the multi-GPU batched KMeans be accessible from Python with the RAFT comms path? We'd like to integrate it into our existing Ray + NCCL setup.

The plan so far is to place the Python wrapper in cuML, in similar fashion to the regular MG KMeans.

@viclafargue
Contributor Author

/ok to test 41c66b8

@viclafargue
Contributor Author

/ok to test f664c2c

@viclafargue viclafargue requested a review from tarang-jain May 8, 2026 17:01
Contributor

@jinsolp jinsolp left a comment

Thanks @viclafargue! No major concerns from my side.

Comment on lines +370 to +372 of cpp/src/cluster/detail/kmeans_mg_batched.cuh
if (rank == 0) {
cuvs::cluster::kmeans::detail::init_centroids_from_host_sample<T, IdxT>(
dev_res, iter_params, streaming_batch_size, X_local, rank_centroids.view(), workspace);
Contributor

Can we document that we require a well-shuffled dataset to run MG KMeans? Otherwise this will end up doing a bad init, because it only uses rank 0's local data.

Member

We should definitely go one step further here and add an option to switch between "use kmeans++ globally" and "use kmeans++ locally".

@viclafargue
Contributor Author

The issue of the bias introduced by single-rank sampling is now solved.
Here is how the current implementation works:

  • Array: simple broadcast
  • Random: samples n_clusters unique rows globally
  • KMeansPlusPlus: samples init_sample_size unique rows globally; the root runs local kmeansPlusPlus or initScalableKMeansPlusPlus on this sample data. A sketch of the final broadcast step is shown below.

This keeps the implementation simple and memory-bounded: the full input remains on host, and only the initialization sample is materialized on device. After the global sample is built, there is no additional multi-rank communication during KMeans++ initialization.
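
To make the described flow concrete, here is a hedged sketch of its final step (the helper name and buffer are illustrative, not the PR's code): after the root seeds centroids from the global sample, they are broadcast so every rank enters Lloyd's loop with identical state.

#include <raft/core/resource/comms.hpp>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/resources.hpp>

#include <cstddef>

// Illustrative sketch: broadcast the root-initialized centroids to all ranks.
// d_centroids points at n_clusters * n_features values on each rank's device.
template <typename T>
void broadcast_initial_centroids(raft::resources const& handle,
                                 T* d_centroids,
                                 std::size_t n_clusters,
                                 std::size_t n_features)
{
  auto const& comm = raft::resource::get_comms(handle);
  auto stream      = raft::resource::get_cuda_stream(handle);
  comm.bcast(d_centroids, n_clusters * n_features, /*root=*/0, stream);
  comm.sync_stream(stream);
}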

Using true multi-GPU KMeans++ for very large init_sample_size is possible, but it would be a larger follow-up. It would likely require:

  • An API/ABI-sensitive parameter or policy to opt into distributed initialization.
  • A large refactoring of the MG KMeans++ implementation to accommodate different comm systems, distributed host-row sampling, handling of empty-rank cases, and variable-size candidate gathers.

Given the added review surface, I think the current sampled-root initialization is the right approach for this PR. True distributed KMeans++ initialization can be tracked separately if we see workloads where init_sample_size becomes a bottleneck.

@viclafargue
Contributor Author

/ok to test a14a6bc

Contributor

@tarang-jain tarang-jain left a comment

This idea of global samples seems sound to me. I think the underlying assumption of contiguous sample IDs across ranks holds. We should definitely benchmark this initialization, though, since the creation of the init sample is done on host, one sample at a time, right?

{
IndexT default_init_size = std::min(static_cast<IndexT>(std::int64_t{3} * n_clusters), global_n);
IndexT init_sample_size = params.init_size > 0
? std::min(static_cast<IndexT>(params.init_size), global_n)
Contributor

Question: the same init_size will be available to all the ranks, right?

std::unordered_set<IndexT> selected;
selected.reserve(static_cast<std::size_t>(sample_size));

for (IndexT j = n_samples - sample_size; j < n_samples; ++j) {
Contributor

why not do this on device? I see below you are copying the sampled IDs to device anyway on rank 0.
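
For context on what the quoted loop appears to implement: it looks like Robert Floyd's algorithm for drawing sample_size distinct indices without replacement. A self-contained host-side sketch (my reading, not the PR's exact code):

#include <cstdint>
#include <random>
#include <unordered_set>

// Floyd's sampling: picks `sample_size` distinct indices from [0, n_samples)
// in O(sample_size) expected time, with no shuffle of the full index range.
std::unordered_set<std::int64_t> floyd_sample(std::int64_t n_samples,
                                              std::int64_t sample_size,
                                              std::mt19937& gen)
{
  std::unordered_set<std::int64_t> selected;
  selected.reserve(static_cast<std::size_t>(sample_size));
  for (std::int64_t j = n_samples - sample_size; j < n_samples; ++j) {
    std::int64_t t = std::uniform_int_distribution<std::int64_t>(0, j)(gen);
    // If t is already selected, take j itself: j cannot have been chosen yet,
    // and the substitution keeps every size-k subset equally likely.
    if (!selected.insert(t).second) { selected.insert(j); }
  }
  return selected;
}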

@tarang-jain
Contributor

In the KMeans++ init, I think there is one more assumption -- that the init sample fits on rank 0. We should document this; otherwise the user might get a surprise OOM (an edge case, for example, is when rank 0 has less available GPU RAM).
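
For a rough sense of scale (illustrative numbers, not from this PR): an init sample of 300,000 rows with 256 float32 features materializes 300,000 * 256 * 4 B, roughly 307 MB, on rank 0 before any KMeans++ working buffers are allocated.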


Labels

improvement (Improves an existing functionality), non-breaking (Introduces a non-breaking change)

Projects

Status: No status

Development

Successfully merging this pull request may close these issues.

[FEA] Multi-node Multi-GPU Kmeans (C++) to support new out-of-core batching

5 participants