[FEA] Add Batching to KMeans #1886
tarang-jain wants to merge 133 commits into rapidsai:release/26.04
…into minibatch-kmeans
624e2fb to 552f736
aamijar left a comment
Thanks for the new feature Tarang! Adding my review here:
```cpp
IndexT n_clusters,
raft::device_matrix_view<DataT, IndexT, raft::row_major> centroid_sums,
raft::device_vector_view<DataT, IndexT> weight_per_cluster,
rmm::device_uvector<char>& workspace)
```
Can we use raft::device_vector instead of rmm::device_uvector?
```cpp
rmm::device_uvector<int> d_labels;
rmm::device_uvector<int> d_labels_ref;
rmm::device_uvector<T> d_centroids;
rmm::device_uvector<T> d_centroids_ref;
```
```python
if not isinstance(X, np.ndarray):
    X = np.asarray(X)
if not X.flags['C_CONTIGUOUS']:
    X = np.ascontiguousarray(X)
```
Should we inform the user that their dataset will be copied when it's not in a compatible format, in either of these two if cases?
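One way to surface the copy, as a rough sketch only: the helper name `as_c_contiguous_host_array` and the warning wording are hypothetical and not part of this PR. It wraps the same two checks from the snippet above and emits a `UserWarning` before each conversion that may copy.

```python
import warnings

import numpy as np


def as_c_contiguous_host_array(X):
    """Hypothetical helper: same two checks as above, with a warning per copy."""
    if not isinstance(X, np.ndarray):
        # np.asarray on a non-ndarray input (list, etc.) generally allocates.
        warnings.warn(
            "X is not a numpy.ndarray; converting with np.asarray may copy the data.",
            UserWarning,
            stacklevel=2,
        )
        X = np.asarray(X)
    if not X.flags["C_CONTIGUOUS"]:
        # np.ascontiguousarray copies whenever the layout is not already C order.
        warnings.warn(
            "X is not C-contiguous; a C-contiguous copy will be made.",
            UserWarning,
            stacklevel=2,
        )
        X = np.ascontiguousarray(X)
    return X


# A transposed 2-D array is Fortran-ordered, so this triggers the second warning.
converted = as_c_contiguous_host_array(np.ones((4, 3)).T)
```

Whether this should be a warning or a debug-level log message is a judgment call; for very large host datasets the extra copy can matter, which argues for making it visible by default.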
```python
assert np.allclose(
    centroids_regular, centroids_batched, rtol=1e-3, atol=1e-3
), f"max diff: {np.max(np.abs(centroids_regular - centroids_batched))}"
```
Are we checking ARI score in any of the existing pytests? Or maybe on the cuml side?
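For context on why an ARI check complements the centroid comparison: the adjusted Rand index compares two labelings up to a permutation of cluster ids, so it stays meaningful even if the batched and regular runs order their centroids differently. Below is a minimal pure-Python sketch of the standard pair-counting formula. The function name is illustrative; an actual test would more likely call an existing implementation such as scikit-learn's `adjusted_rand_score`.

```python
from collections import Counter
from math import comb


def adjusted_rand_index(labels_a, labels_b):
    """Pair-counting ARI between two labelings of the same samples."""
    if len(labels_a) != len(labels_b):
        raise ValueError("labelings must have the same length")
    total_pairs = comb(len(labels_a), 2)
    if total_pairs == 0:
        return 1.0  # fewer than two samples: trivially identical
    # Pairs that land in the same cluster in both labelings.
    together = sum(comb(c, 2) for c in Counter(zip(labels_a, labels_b)).values())
    pairs_a = sum(comb(c, 2) for c in Counter(labels_a).values())
    pairs_b = sum(comb(c, 2) for c in Counter(labels_b).values())
    expected = pairs_a * pairs_b / total_pairs
    max_index = (pairs_a + pairs_b) / 2
    if max_index == expected:
        return 1.0  # degenerate case, e.g. both labelings use a single cluster
    return (together - expected) / (max_index - expected)


# Same partition with swapped cluster ids still scores as a perfect match.
score = adjusted_rand_index([0, 0, 1, 1], [1, 1, 0, 0])
```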
jinsolp left a comment
Adding some questions and a suggestion!
```python
@pytest.mark.parametrize("n_rows", [1000, 5000])
@pytest.mark.parametrize("n_cols", [10, 100])
@pytest.mark.parametrize("n_clusters", [8, 16])
@pytest.mark.parametrize("batch_size", [0, 100, 500])
```
Can we add a test with a batch_size that isn't a divisor of n_rows? Some arbitrary number like 237, maybe?
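To make the non-divisor case concrete, here is a small sketch of how row ranges fall out when `batch_size` does not divide `n_rows`. The helper `batch_ranges` is hypothetical, and treating `batch_size <= 0` as "use the whole dataset" is an assumption made only for this illustration, not a statement about how the PR interprets `batch_size=0`.

```python
def batch_ranges(n_rows, batch_size):
    """Return [start, stop) row ranges covering n_rows in batch_size chunks."""
    if n_rows <= 0:
        raise ValueError("n_rows must be positive")
    # Assumption for this sketch: a non-positive or oversized batch_size
    # means a single batch over the whole dataset.
    if batch_size <= 0 or batch_size > n_rows:
        batch_size = n_rows
    return [
        (start, min(start + batch_size, n_rows))
        for start in range(0, n_rows, batch_size)
    ]


ranges = batch_ranges(1000, 237)
print(ranges[-1])  # (948, 1000): the last batch holds 52 rows, not 237
```

With 1000 rows and a batch size of 237 there are four full batches and one short batch of 52 rows, which is exactly the path that sizes like 100 and 500 never exercise.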
```cpp
cuvs::cluster::kmeans::detail::computeClusterCost(handle,
                                                  minClusterAndDistance.view(),
                                                  workspace,
                                                  clustering_cost.view(),
                                                  raft::value_op{},
                                                  raft::add_op{});
```
Are the stale values in minClusterAndDistance.view() handled properly for the last batch?
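To illustrate the concern in isolation (this is a pure-Python stand-in, not the cuVS code, and it makes no claim about whether the PR actually has the issue): if a buffer sized for a full batch is reused across batches and the last batch is shorter, the tail of the buffer still holds distances from the previous batch. A reduction over the whole buffer then over-counts, while a reduction over only the valid prefix does not.

```python
def batched_cost(distances, batch_size):
    """Sum distances batch by batch through one reused fixed-size buffer.

    Returns (cost over the valid prefix only, cost over the whole buffer).
    """
    buffer = [0.0] * batch_size  # reused scratch space, like a per-batch device buffer
    cost_valid = 0.0
    cost_whole_buffer = 0.0
    for start in range(0, len(distances), batch_size):
        chunk = distances[start:start + batch_size]
        # Only the first len(chunk) slots are overwritten; the rest keep old values.
        buffer[:len(chunk)] = chunk
        cost_valid += sum(buffer[:len(chunk)])
        cost_whole_buffer += sum(buffer)
    return cost_valid, cost_whole_buffer


# 10 samples with distance 1.0 each, batches of 4, 4 and 2.
valid, whole = batched_cost([1.0] * 10, 4)
print(valid, whole)  # 10.0 12.0
```

The two stale slots in the final batch add 2.0 to the naive total, so the view passed to the reduction needs to be restricted to the current batch's row count (or the tail needs to be reset) for the last batch.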
```cpp
if (n_init > 1) {
  inertia[0] = best_inertia;
  n_iter[0]  = best_n_iter;
  RAFT_LOG_DEBUG("KMeans batched: Best of %d runs: inertia=%f, n_iter=%d",
                 n_init,
                 static_cast<double>(best_inertia),
                 best_n_iter);
}
```
Don't we have to copy the best centroids to the centroids output here?
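A sketch of the pattern being asked about, in plain Python rather than the PR's C++: when every initialization writes into the same output buffer, the best run's centroids have to be snapshotted at the moment a new best inertia is seen, and that snapshot returned at the end. All names here (`best_of_n_init`, `fake_run`) are hypothetical stand-ins.

```python
def best_of_n_init(run_once, n_init):
    """Run n_init initializations and keep the centroids of the lowest-inertia run."""
    best_inertia = float("inf")
    best_n_iter = 0
    best_centroids = None
    for seed in range(n_init):
        centroids, inertia, n_iter = run_once(seed)
        if inertia < best_inertia:
            best_inertia = inertia
            best_n_iter = n_iter
            # Deep-copy the rows: `centroids` may be a buffer the next run overwrites.
            best_centroids = [row[:] for row in centroids]
    return best_centroids, best_inertia, best_n_iter


# Stand-in for a KMeans run that always writes into one shared output buffer.
shared_output = [[0.0]]


def fake_run(seed):
    inertias = [5.0, 2.0, 9.0]
    shared_output[0][0] = float(seed)  # overwrite the shared centroids in place
    return shared_output, inertias[seed], seed + 1


centroids, inertia, n_iter = best_of_n_init(fake_run, 3)
print(centroids, inertia, n_iter)  # [[1.0]] 2.0 2
```

Without the copy, the returned centroids would be whatever the last run left in the shared buffer (`[[2.0]]` here), paired with the inertia and iteration count of a different run.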
Merge after #1880