-
Notifications
You must be signed in to change notification settings - Fork 668
[Common][JAX] Add CUB TopK MaxPairs interface #2784
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,39 @@ | ||
/*************************************************************************
 * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#ifndef TRANSFORMER_ENGINE_CUB_H_
#define TRANSFORMER_ENGINE_CUB_H_

#include "transformer_engine.h"

#ifdef __cplusplus
extern "C" {
#endif

/*! \brief Compute the top-K largest (key, value) pairs using CUB.
 *
 * \param[in]     stream          CUDA stream used for the operation.
 * \param[in]     keys_in         Input 1D keys tensor, shape (num_items,)
 * \param[in]     values_in       Input 1D values tensor, shape (num_items,)
 * \param[out]    keys_out        Output 1D keys tensor, shape (k,)
 * \param[out]    values_out      Output 1D values tensor, shape (k,)
 * \param[in,out] workspace       Scratch workspace tensor, shape (workspace_bytes,)
 * \param[in]     num_items       Number of items in the input tensors
 * \param[in]     k               Number of top-K largest values to return
 * \param[in]     workspace_bytes Workspace size in bytes
 *
 * Requirements:
 *  - Only supports float32, float16, bfloat16 keys and int32 values.
 *
 * NOTE(review): output ordering of the k pairs is not specified here — the
 * implementation requests unsorted output from CUB; callers must not assume
 * the results are sorted.
 */
void nvte_cub_topk(cudaStream_t stream, const NVTETensor keys_in, const NVTETensor values_in,
                   NVTETensor keys_out, NVTETensor values_out, NVTETensor workspace,
                   const int num_items, const int k, const size_t workspace_bytes);

#ifdef __cplusplus
}  // extern "C"
#endif

#endif  // TRANSFORMER_ENGINE_CUB_H_
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,54 @@ | ||
| /************************************************************************* | ||
| * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| * | ||
| * See LICENSE for license information. | ||
| ************************************************************************/ | ||
|
|
||
| #include <transformer_engine/cub.h> | ||
|
|
||
| #include <cub/device/device_topk.cuh> | ||
| #include <cuda/std/execution> | ||
|
|
||
| #include "../common.h" | ||
|
|
||
/*! \brief Top-K largest (key, value) pairs via cub::DeviceTopK::MaxPairs.
 *
 * Dispatches on the key dtype (fp32/fp16/bf16, values must be int32) and runs
 * the CUB kernel on \p stream using the caller-provided \p workspace buffer.
 * Output ordering is explicitly relaxed to "unsorted" and determinism to
 * "not_guaranteed" so CUB may pick its fastest implementation.
 */
void nvte_cub_topk(cudaStream_t stream, const NVTETensor keys_in, const NVTETensor values_in,
                   NVTETensor keys_out, NVTETensor values_out, NVTETensor workspace, int num_items,
                   int k, size_t workspace_bytes) {
  NVTE_API_CALL(nvte_cub_topk);
  using namespace transformer_engine;

  const Tensor *keys_in_tensor = convertNVTETensorCheck(keys_in);
  const Tensor *values_in_tensor = convertNVTETensorCheck(values_in);
  Tensor *keys_out_tensor = convertNVTETensor(keys_out);
  Tensor *values_out_tensor = convertNVTETensor(values_out);
  Tensor *workspace_tensor = convertNVTETensor(workspace);
  auto keys_in_dtype = keys_in_tensor->data.dtype;
  auto values_in_dtype = values_in_tensor->data.dtype;

  // Validate problem sizes up front — CUB does not range-check them, and an
  // out-of-range k would read/write out of bounds.
  NVTE_CHECK(num_items >= 0, "num_items (", num_items, ") must be non-negative");
  NVTE_CHECK(k >= 0 && k <= num_items, "k (", k, ") must satisfy 0 <= k <= num_items (", num_items,
             ")");

  // Relax determinism and output ordering so CUB can select the fastest path.
  auto requirements = cuda::execution::require(cuda::execution::determinism::not_guaranteed,
                                               cuda::execution::output_ordering::unsorted);
  cuda::stream_ref stream_ref{stream};
  auto env = cuda::std::execution::env{stream_ref, requirements};

// Dispatch helper: reinterpret the tensor pointers at the concrete dtype and
// invoke CUB, surfacing any CUDA error instead of silently discarding it.
#define DISPATCH_CUB_TOPK(KeyT, ValueT)                                                           \
  do {                                                                                            \
    KeyT *d_keys_in = reinterpret_cast<KeyT *>(keys_in_tensor->data.dptr);                        \
    KeyT *d_keys_out = reinterpret_cast<KeyT *>(keys_out_tensor->data.dptr);                      \
    ValueT *d_values_in = reinterpret_cast<ValueT *>(values_in_tensor->data.dptr);                \
    ValueT *d_values_out = reinterpret_cast<ValueT *>(values_out_tensor->data.dptr);              \
    void *d_workspace = reinterpret_cast<void *>(workspace_tensor->data.dptr);                    \
    NVTE_CHECK_CUDA(cub::DeviceTopK::MaxPairs(d_workspace, workspace_bytes, d_keys_in,            \
                                              d_keys_out, d_values_in, d_values_out, num_items,   \
                                              k, env));                                           \
  } while (0);

  if (keys_in_dtype == DType::kFloat32 && values_in_dtype == DType::kInt32) {
    DISPATCH_CUB_TOPK(float, int);
  } else if (keys_in_dtype == DType::kFloat16 && values_in_dtype == DType::kInt32) {
    DISPATCH_CUB_TOPK(__half, int);
  } else if (keys_in_dtype == DType::kBFloat16 && values_in_dtype == DType::kInt32) {
    DISPATCH_CUB_TOPK(__nv_bfloat16, int);
  } else {
    NVTE_ERROR("Unsupported input key and value data types");
  }
#undef DISPATCH_CUB_TOPK
}
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,111 @@ | ||
| # Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| # | ||
| # See LICENSE for license information. | ||
| """CUB custom ops""" | ||
|
|
||
| from typing import Tuple | ||
|
|
||
| import jax | ||
| import jax.numpy as jnp | ||
| from jax import dtypes, ffi | ||
|
|
||
| from .base import BasePrimitive, register_primitive | ||
|
|
||
| __all__ = ["CubTopkPrimitive"] | ||
|
|
||
|
|
||
def get_cub_topk_workspace_bytes(num_items: int = None, k: int = None) -> int:
    """
    Get the workspace size (in bytes) for CUB TopK.

    The safe way is calling the CUB kernel with a null workspace pointer to
    query the exact required size. For convenience, we use a heuristic value
    based on experiments: 4 MiB is enough for N up to 5,000,000 and K up to
    100,000.

    Args:
        num_items: optional problem size; if provided, validated against the
            tested limit so an under-sized workspace is rejected loudly
            instead of risking silent corruption in the CUB kernel.
        k: optional top-K count; validated the same way.

    Returns:
        Workspace size in bytes (currently a fixed 4 MiB heuristic).

    Raises:
        ValueError: if ``num_items`` or ``k`` exceeds the experimentally
            validated limits for the 4 MiB heuristic.
    """
    # Limits covered by the experiments backing the 4 MiB heuristic.
    max_tested_items = 5_000_000
    max_tested_k = 100_000
    if num_items is not None and num_items > max_tested_items:
        raise ValueError(
            f"num_items={num_items} exceeds the tested limit ({max_tested_items}) "
            "for the fixed CUB TopK workspace heuristic"
        )
    if k is not None and k > max_tested_k:
        raise ValueError(
            f"k={k} exceeds the tested limit ({max_tested_k}) "
            "for the fixed CUB TopK workspace heuristic"
        )
    return 4 * 1024 * 1024
|
Comment on lines
+17
to
+24
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
If a caller passes `num_items` or `k` beyond the experimentally tested limits, the fixed 4 MiB workspace may be too small, and the CUB kernel could fail or corrupt memory. The correct approach is to call the CUB kernel once with a null workspace pointer to query the exact required size.
||
|
|
||
|
|
||
class CubTopkPrimitive(BasePrimitive):
    """
    CUB Topk Primitive

    Thin JAX primitive wrapper around the ``te_cub_topk_ffi`` XLA custom call.
    The inner primitive declares three results (keys, values, scratch
    workspace); the outer primitive exposes only (keys, values).

    NOTE(review): no batcher/partition/shardy overrides are defined here, so
    registration presumably falls back to BasePrimitive's abstract versions —
    confirm that vmap and sharded execution are either implemented or
    explicitly rejected, matching the other primitives in this package.
    """

    name = "te_cub_topk_ffi"
    multiple_results = True  # inner primitive returns (keys, values, workspace)
    impl_static_args = (2,)  # k_value is a compile-time static argument
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
        in_keys_aval,
        in_values_aval,
        *,
        k_value,
    ):
        # Shape/dtype inference for the inner primitive: keys must be
        # fp32/fp16/bf16 and values int32 (mirrors the C++ FFI checks).
        keys_dtype = dtypes.canonicalize_dtype(in_keys_aval.dtype)
        values_dtype = dtypes.canonicalize_dtype(in_values_aval.dtype)
        assert keys_dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert values_dtype == jnp.int32

        # Workspace is declared as a uint8 output so XLA allocates it;
        # its size comes from the fixed heuristic, not a CUB query.
        workspace_bytes = get_cub_topk_workspace_bytes()
        out_keys_aval = jax.core.ShapedArray(shape=(k_value,), dtype=keys_dtype)
        out_values_aval = jax.core.ShapedArray(shape=(k_value,), dtype=jnp.int32)
        workspace_aval = jax.core.ShapedArray(shape=(workspace_bytes,), dtype=jnp.uint8)
        return (out_keys_aval, out_values_aval, workspace_aval)

    @staticmethod
    def outer_abstract(*args, **kwargs):
        # Outer primitive hides the workspace buffer from callers.
        out_keys_aval, out_values_aval, _workspace_aval = CubTopkPrimitive.abstract(*args, **kwargs)
        return (out_keys_aval, out_values_aval)

    @staticmethod
    def lowering(
        ctx,
        in_keys,
        in_values,
        k_value,
    ):
        # Lower to the registered FFI target; the workspace size is passed as
        # the static "workbuf_bytes" attribute consumed by the C++ handler.
        workspace_bytes = get_cub_topk_workspace_bytes()
        return ffi.ffi_lowering(
            CubTopkPrimitive.name,
        )(
            ctx,
            in_keys,
            in_values,
            k_value=k_value,
            workbuf_bytes=workspace_bytes,
        )

    @staticmethod
    def impl(
        in_keys,
        in_values,
        k_value,
    ):
        # Eager path: bind the inner primitive and drop the workspace result.
        assert CubTopkPrimitive.inner_primitive is not None
        out_keys, out_values, _workspace = CubTopkPrimitive.inner_primitive.bind(
            in_keys,
            in_values,
            k_value=k_value,
        )
        return (out_keys, out_values)


register_primitive(CubTopkPrimitive)
|
Comment on lines
+27
to
+94
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
When batching.primitive_batchers[outer_p] = cls.batcher # resolves to abstract method → returns NotImplemented
outer_p_lower.def_partition(partition=cls.partition, ...) # same. This means:
Every other primitive in the codebase (e.g., in |
||
|
|
||
|
|
||
def cub_topk(
    x: jnp.ndarray,
    k_value: int,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """
    CUB Topk max pairs

    Selects the ``k_value`` largest entries of the 1D array ``x`` along with
    their positions (int32 indices) via the CUB top-k primitive.
    """
    # Pair each key with its original position so the kernel returns indices.
    positions = jnp.arange(x.shape[0], dtype=jnp.int32)
    topk_keys, topk_indices = CubTopkPrimitive.outer_primitive.bind(
        x,
        positions,
        k_value=k_value,
    )
    return topk_keys, topk_indices
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,73 @@ | ||
| /************************************************************************* | ||
| * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| * | ||
| * See LICENSE for license information. | ||
| ************************************************************************/ | ||
|
|
||
| #include "transformer_engine/cub.h" | ||
|
|
||
| #include "../extensions.h" | ||
| #include "xla/ffi/api/c_api.h" | ||
|
|
||
| namespace transformer_engine { | ||
| namespace jax { | ||
|
|
||
| Error_Type CubTopkFFI(cudaStream_t stream, Buffer_Type keys_in_buf, Buffer_Type values_in_buf, | ||
| Result_Type keys_out_buf, Result_Type values_out_buf, | ||
| Result_Type workspace_buf, int64_t k_value, int64_t workbuf_bytes) { | ||
| auto keys_in_dtype = convert_ffi_datatype_to_te_dtype(keys_in_buf.element_type()); | ||
| auto values_in_dtype = convert_ffi_datatype_to_te_dtype(values_in_buf.element_type()); | ||
| auto keys_out_dtype = convert_ffi_datatype_to_te_dtype(keys_out_buf->element_type()); | ||
| auto values_out_dtype = convert_ffi_datatype_to_te_dtype(values_out_buf->element_type()); | ||
| NVTE_CHECK(keys_in_dtype == keys_out_dtype, "Input and output keys must have the same datatype"); | ||
| NVTE_CHECK(values_in_dtype == values_out_dtype, | ||
| "Input and output values must have the same datatype"); | ||
| NVTE_CHECK(values_in_dtype == DType::kInt32, "CubTopkFFI() only supports int32 values for now"); | ||
|
|
||
| auto keys_in_shape = keys_in_buf.dimensions(); | ||
| auto values_in_shape = values_in_buf.dimensions(); | ||
| auto keys_out_shape = keys_out_buf->dimensions(); | ||
| auto values_out_shape = values_out_buf->dimensions(); | ||
| NVTE_CHECK(keys_in_shape.size() == 1, "Keys input must have 1 dimension"); | ||
| NVTE_CHECK(values_in_shape.size() == 1, "Values input must have 1 dimension"); | ||
| NVTE_CHECK(keys_out_shape.size() == 1, "Keys output must have 1 dimension"); | ||
| NVTE_CHECK(values_out_shape.size() == 1, "Values output must have 1 dimension"); | ||
| NVTE_CHECK(keys_in_shape[0] == values_in_shape[0], | ||
| "Keys and values input must have the same number of items"); | ||
| NVTE_CHECK(keys_out_shape[0] == values_out_shape[0], | ||
| "Keys and values output must have the same number of items"); | ||
| int num_items = static_cast<int>(keys_in_shape[0]); | ||
| int k = static_cast<int>(k_value); | ||
|
Comment on lines
+39
to
+40
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There is no check that A guard should be added here alongside the existing shape checks: NVTE_CHECK(k <= num_items, "k (", k, ") must be <= num_items (", num_items, ")"); |
||
|
|
||
| auto input_shape = std::vector<size_t>{keys_in_shape[0]}; | ||
| auto output_shape = std::vector<size_t>{keys_out_shape[0]}; | ||
| auto workspace_shape = std::vector<size_t>{workbuf_bytes}; | ||
|
|
||
| auto keys_in_tensor = TensorWrapper(keys_in_buf.untyped_data(), input_shape, keys_in_dtype); | ||
| auto values_in_tensor = TensorWrapper(values_in_buf.untyped_data(), input_shape, values_in_dtype); | ||
| auto keys_out_tensor = TensorWrapper(keys_out_buf->untyped_data(), output_shape, keys_out_dtype); | ||
| auto values_out_tensor = | ||
| TensorWrapper(values_out_buf->untyped_data(), output_shape, values_out_dtype); | ||
| auto workspace_tensor = | ||
| TensorWrapper(workspace_buf->untyped_data(), workspace_shape, DType::kByte); | ||
|
|
||
| nvte_cub_topk(stream, keys_in_tensor.data(), values_in_tensor.data(), keys_out_tensor.data(), | ||
| values_out_tensor.data(), workspace_tensor.data(), num_items, k, workbuf_bytes); | ||
|
|
||
| return ffi_with_cuda_error_check(); | ||
| } | ||
|
|
||
| XLA_FFI_DEFINE_HANDLER_SYMBOL(CubTopkHandler, CubTopkFFI, | ||
| FFI::Bind() | ||
| .Ctx<FFI_Stream_Type>() // stream | ||
| .Arg<Buffer_Type>() // keys_buf | ||
| .Arg<Buffer_Type>() // values_buf | ||
| .Ret<Buffer_Type>() // topk_buf | ||
| .Ret<Buffer_Type>() // indices_buf | ||
| .Ret<Buffer_Type>() // workspace_buf | ||
| .Attr<int64_t>("k_value") | ||
| .Attr<int64_t>("workbuf_bytes"), | ||
| FFI_CudaGraph_Traits); | ||
|
|
||
| } // namespace jax | ||
| } // namespace transformer_engine | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
`cub_topk` not exported in `__all__`. The user-facing function
`cub_topk` (defined at line 97) is not included in `__all__`; only `CubTopkPrimitive` is listed. Tools and users relying on `__all__` for the module's public API won't discover `cub_topk`. Since the test imports it as a primary API (`from transformer_engine.jax.cpp_extensions.cub import cub_topk`), it should be exported.