[FEA] Add Batching to KMeans #1886

Open
tarang-jain wants to merge 133 commits into rapidsai:release/26.04 from tarang-jain:batched-kmeans

tarang-jain (Contributor) commented Mar 6, 2026

Merge after #1880

tarang-jain requested review from a team as code owners March 12, 2026 20:42
tarang-jain requested a review from bdice March 12, 2026 20:42
tarang-jain removed request for a team March 12, 2026 20:45
tarang-jain removed the request for review from bdice March 13, 2026 20:56
aamijar (Member) left a comment

Thanks for the new feature, Tarang! Adding my review here:

IndexT n_clusters,
raft::device_matrix_view<DataT, IndexT, raft::row_major> centroid_sums,
raft::device_vector_view<DataT, IndexT> weight_per_cluster,
rmm::device_uvector<char>& workspace)
Member

Can we use raft::device_vector instead of rmm::device_uvector?

rmm::device_uvector<int> d_labels;
rmm::device_uvector<int> d_labels_ref;
rmm::device_uvector<T> d_centroids;
rmm::device_uvector<T> d_centroids_ref;
Member

Same for these.

if not isinstance(X, np.ndarray):
    X = np.asarray(X)
if not X.flags['C_CONTIGUOUS']:
    X = np.ascontiguousarray(X)
aamijar (Member), Mar 21, 2026

Should we inform the user that their dataset will be copied because it is not in a compatible format? That applies to either of these two if cases.
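One way to surface the copy to the user is sketched below. This is an illustration of the suggestion, not code from the PR: the helper name `as_c_contiguous_host_array` and the warning messages are hypothetical, and whether a `UserWarning` or a log message is preferable is left open.

```python
import warnings

import numpy as np


def as_c_contiguous_host_array(X):
    """Hypothetical helper: return X as a C-contiguous ndarray and
    warn whenever the conversion may copy the dataset."""
    if not isinstance(X, np.ndarray):
        warnings.warn(
            "X is not a NumPy array; converting it may copy the dataset.",
            UserWarning,
            stacklevel=2,
        )
        X = np.asarray(X)
    if not X.flags["C_CONTIGUOUS"]:
        warnings.warn(
            "X is not C-contiguous; a contiguous copy will be made.",
            UserWarning,
            stacklevel=2,
        )
        X = np.ascontiguousarray(X)
    return X
```

An array that is already a C-contiguous `np.ndarray` passes through without a warning and without a copy.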


assert np.allclose(
    centroids_regular, centroids_batched, rtol=1e-3, atol=1e-3
), f"max diff: {np.max(np.abs(centroids_regular - centroids_batched))}"
Member

Are we checking the ARI score in any of the existing pytests? Or maybe on the cuML side?
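For context, an element-wise centroid comparison assumes both runs return clusters in the same order, whereas the adjusted Rand index (ARI) compares label assignments and is unaffected by relabeling. The sketch below is a small pure-Python ARI, written only to illustrate what such a check would compute; the function name is hypothetical, and an existing implementation such as scikit-learn's `adjusted_rand_score` would likely be used in practice.

```python
from collections import Counter
from math import comb


def adjusted_rand_index(labels_a, labels_b):
    """Illustrative ARI between two label assignments of equal length."""
    if len(labels_a) != len(labels_b):
        raise ValueError("label sequences must have the same length")
    total_pairs = comb(len(labels_a), 2)
    if total_pairs == 0:
        return 1.0  # zero or one sample: the partitions trivially agree

    # Pairs of samples that are grouped together in both, a, and b.
    together_both = sum(comb(c, 2) for c in Counter(zip(labels_a, labels_b)).values())
    together_a = sum(comb(c, 2) for c in Counter(labels_a).values())
    together_b = sum(comb(c, 2) for c in Counter(labels_b).values())

    expected = together_a * together_b / total_pairs
    maximum = (together_a + together_b) / 2
    if maximum == expected:
        return 1.0  # degenerate case, e.g. both partitions have one cluster
    return (together_both - expected) / (maximum - expected)
```

A test could then assert that the ARI between the labels from the regular and batched runs is above a chosen threshold, in addition to or instead of comparing centroids directly.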

jinsolp (Contributor) left a comment

Adding some questions and a suggestion!

@pytest.mark.parametrize("n_rows", [1000, 5000])
@pytest.mark.parametrize("n_cols", [10, 100])
@pytest.mark.parametrize("n_clusters", [8, 16])
@pytest.mark.parametrize("batch_size", [0, 100, 500])
Contributor

Can we add a test with a batch_size that isn't a divisor of n_rows? Maybe an arbitrary number like 237?
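To show why a non-divisor is worth covering, the sketch below enumerates row ranges per batch. It is a hypothetical helper, not the PR's batching logic, and it deliberately rejects non-positive sizes because how the PR interprets `batch_size=0` is not shown in this thread.

```python
def batch_ranges(n_rows, batch_size):
    """Yield (start, end) row ranges; the final batch may be shorter."""
    if batch_size <= 0:
        raise ValueError("this sketch only handles positive batch sizes")
    for start in range(0, n_rows, batch_size):
        yield start, min(start + batch_size, n_rows)


# With n_rows=1000 and batch_size=237 the last batch holds only 52 rows.
ranges = list(batch_ranges(1000, 237))
last_start, last_end = ranges[-1]
print(len(ranges), last_end - last_start)  # 5 52
```

With the currently parametrized values (100 and 500 against 1000 and 5000 rows), every batch is full, so the short trailing batch is never exercised.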

Comment on lines +313 to +318
cuvs::cluster::kmeans::detail::computeClusterCost(handle,
                                                  minClusterAndDistance.view(),
                                                  workspace,
                                                  clustering_cost.view(),
                                                  raft::value_op{},
                                                  raft::add_op{});
Contributor

Are the stale values in minClusterAndDistance.view() handled properly for the last batch?
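The concern can be illustrated with a small Python model, under the assumption that a buffer sized for a full batch is reused across batches. This is not the PR's code and does not establish that the PR has this problem; both function names are hypothetical.

```python
def cost_full_buffer(distances, batch_size):
    """Reduces the whole reused buffer, so a short last batch also
    sums leftover entries written by the previous batch."""
    buffer = [0.0] * batch_size
    total = 0.0
    for start in range(0, len(distances), batch_size):
        n_valid = min(batch_size, len(distances) - start)
        buffer[:n_valid] = distances[start:start + n_valid]
        total += sum(buffer)  # includes stale entries past n_valid
    return total


def cost_valid_prefix(distances, batch_size):
    """Reduces only the entries written for the current batch."""
    buffer = [0.0] * batch_size
    total = 0.0
    for start in range(0, len(distances), batch_size):
        n_valid = min(batch_size, len(distances) - start)
        buffer[:n_valid] = distances[start:start + n_valid]
        total += sum(buffer[:n_valid])
    return total
```

For five rows with distance 1.0 each and a batch size of 2, the first variant returns 6.0 and the second returns the correct 5.0, which is the kind of discrepancy a test with a non-divisor batch size would reveal.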

Comment on lines +376 to +383
if (n_init > 1) {
  inertia[0] = best_inertia;
  n_iter[0] = best_n_iter;
  RAFT_LOG_DEBUG("KMeans batched: Best of %d runs: inertia=%f, n_iter=%d",
                 n_init,
                 static_cast<double>(best_inertia),
                 best_n_iter);
}
Contributor

Don't we have to copy the best centroids to the centroids output here?
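The pattern behind this question can be sketched in Python as follows. It assumes each run writes into a shared output buffer, which is why the best centroids need to be snapshotted and returned along with the best inertia and iteration count; `best_of_n_init` and `run_once` are hypothetical names, and the snippet above does not show how the PR actually stores its centroids.

```python
def best_of_n_init(run_once, n_init):
    """Run `run_once(seed)` n_init times and keep the lowest-inertia result.

    `run_once` returns (centroids, inertia, n_iter) and may reuse the same
    centroids buffer across calls, so the best centroids are copied.
    """
    best = None
    for seed in range(n_init):
        centroids, inertia, n_iter = run_once(seed)
        if best is None or inertia < best[1]:
            snapshot = [list(row) for row in centroids]  # copy, do not alias
            best = (snapshot, inertia, n_iter)
    return best
```

If only the inertia and iteration count were restored, the caller would see the centroids of the last run paired with the inertia of the best run.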

Labels: cpp, feature request (New feature or request), non-breaking (Introduces a non-breaking change)

6 participants