[Common][JAX] Add CUB TopK MaxPairs interface #2784
huanghua1994 wants to merge 2 commits into NVIDIA:main
Conversation
Signed-off-by: Hua Huang <huah@nvidia.com>
for more information, see https://pre-commit.ci
Greptile Summary

This PR introduces a new CUB TopK interface. Key items to address before merging:
Confidence Score: 2/5
Important Files Changed
Flowchart

```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["JAX: cub_topk call"] --> B["CubTopkPrimitive bind"]
    B --> C["FFI lowering via te_cub_topk_ffi"]
    C --> D["CubTopkFFI: validate dtypes and shapes"]
    D --> E["nvte_cub_topk: dispatch on dtype"]
    E --> F["cub DeviceTopK MaxPairs on GPU"]
    F --> G["Return top-k keys and indices"]
```
Last reviewed commit: "[pre-commit.ci] auto..."
```python
class CubTopkPrimitive(BasePrimitive):
    """
    CUB TopK Primitive
    """

    name = "te_cub_topk_ffi"
    multiple_results = True
    impl_static_args = (2,)  # k_value
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
        in_keys_aval,
        in_values_aval,
        *,
        k_value,
    ):
        keys_dtype = dtypes.canonicalize_dtype(in_keys_aval.dtype)
        values_dtype = dtypes.canonicalize_dtype(in_values_aval.dtype)
        assert keys_dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert values_dtype == jnp.int32

        workspace_bytes = get_cub_topk_workspace_bytes()
        out_keys_aval = jax.core.ShapedArray(shape=(k_value,), dtype=keys_dtype)
        out_values_aval = jax.core.ShapedArray(shape=(k_value,), dtype=jnp.int32)
        workspace_aval = jax.core.ShapedArray(shape=(workspace_bytes,), dtype=jnp.uint8)
        return (out_keys_aval, out_values_aval, workspace_aval)

    @staticmethod
    def outer_abstract(*args, **kwargs):
        out_keys_aval, out_values_aval, _workspace_aval = CubTopkPrimitive.abstract(*args, **kwargs)
        return (out_keys_aval, out_values_aval)

    @staticmethod
    def lowering(
        ctx,
        in_keys,
        in_values,
        k_value,
    ):
        workspace_bytes = get_cub_topk_workspace_bytes()
        return ffi.ffi_lowering(
            CubTopkPrimitive.name,
        )(
            ctx,
            in_keys,
            in_values,
            k_value=k_value,
            workbuf_bytes=workspace_bytes,
        )

    @staticmethod
    def impl(
        in_keys,
        in_values,
        k_value,
    ):
        assert CubTopkPrimitive.inner_primitive is not None
        out_keys, out_values, _workspace = CubTopkPrimitive.inner_primitive.bind(
            in_keys,
            in_values,
            k_value=k_value,
        )
        return (out_keys, out_values)


register_primitive(CubTopkPrimitive)
```
Missing `batcher`, `partition`, and `shardy_sharding_rule` methods
CubTopkPrimitive extends BasePrimitive, which declares `batcher()`, `partition()`, and `shardy_sharding_rule()` as abstract methods. CubTopkPrimitive does not implement any of them.
When `register_primitive(CubTopkPrimitive)` is called, base.py does:

```python
batching.primitive_batchers[outer_p] = cls.batcher  # resolves to abstract method → returns NotImplemented
outer_p_lower.def_partition(partition=cls.partition, ...)  # same
```

This means:
- Any attempt to use `vmap` over `cub_topk` will fail at runtime because the registered batcher is the abstract method (which returns `NotImplemented`, not a callable that returns batched results).
- Multi-device / sharding via `custom_partitioning` will similarly fail when `partition` is invoked.
Every other primitive in the codebase (e.g., in router.py) implements all three of these methods. If sharding and batching are intentionally unsupported for now, the methods should at minimum raise a clear NotImplementedError (rather than silently returning NotImplemented), and this limitation should be documented.
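If batching and sharding stay unsupported for now, explicit stubs along these lines would make the failure loud and actionable. This is a sketch only: the class is a stand-in and the method signatures are assumptions modeled on typical JAX primitive hooks, not the real BasePrimitive interface.

```python
# Hypothetical stubs that fail loudly instead of registering abstract methods
# which silently return NotImplemented. The class name and signatures are
# illustrative stand-ins, not the actual BasePrimitive contract.
class CubTopkPrimitiveStub:
    @staticmethod
    def batcher(batched_args, batch_dims, *, k_value):
        raise NotImplementedError("vmap over cub_topk is not supported yet")

    @staticmethod
    def partition(k_value, mesh, arg_infos, result_infos):
        raise NotImplementedError("partitioned cub_topk is not supported yet")

    @staticmethod
    def shardy_sharding_rule(*args, **kwargs):
        raise NotImplementedError("shardy sharding rule for cub_topk is not defined yet")
```

Raising `NotImplementedError` eagerly turns a confusing downstream tracer error into a one-line message at the call site.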
```python
def get_cub_topk_workspace_bytes() -> int:
    """
    Get the workspace size for CUB TopK.
    The safe way is calling the CUB kernel to query the workspace size.
    For convenience, we use a heuristic value based on experiments:
    4 MiB is enough for N up to 5,000,000 and K up to 100,000.
    """
    return 4 * 1024 * 1024
```
Hardcoded workspace size may silently corrupt memory for large inputs
get_cub_topk_workspace_bytes() always returns a fixed 4 MiB and the docstring itself acknowledges this only covers "N up to 5,000,000 and K up to 100,000." However, there is no validation in the Python or C++ layer that the user's actual N and K do not exceed these limits.
If a caller passes N > 5_000_000 or K > 100_000, cub::DeviceTopK::MaxPairs will be given an undersized workspace buffer and will write out-of-bounds on the GPU — a silent CUDA memory corruption with no error raised back to the caller.
The correct approach is to call cub::DeviceTopK::MaxPairs(nullptr, temp_storage_bytes, ...) with a null workspace pointer to query the required size at runtime, then allocate that exact amount. The current heuristic should at minimum be accompanied by a runtime guard that raises an error when the inputs exceed the documented limits.
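Until the size is queried from CUB directly, a minimal Python-side guard could mirror the documented limits. The constants come from the docstring above; the function name is illustrative, not part of the PR.

```python
# Hypothetical runtime guard for the 4 MiB workspace heuristic.
MAX_NUM_ITEMS = 5_000_000  # documented N limit for the fixed 4 MiB workspace
MAX_K = 100_000            # documented K limit


def check_cub_topk_limits(num_items: int, k: int) -> None:
    """Raise before kernel launch instead of risking silent GPU memory corruption."""
    if num_items > MAX_NUM_ITEMS or k > MAX_K:
        raise ValueError(
            f"cub_topk heuristic workspace only covers N <= {MAX_NUM_ITEMS} and "
            f"K <= {MAX_K}; got N={num_items}, K={k}"
        )
    if k > num_items:
        raise ValueError(f"k ({k}) must be <= num_items ({num_items})")
```

A check like this is cheap relative to the kernel launch and turns undefined GPU behavior into an immediate, debuggable Python exception.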
```cpp
int num_items = static_cast<int>(keys_in_shape[0]);
int k = static_cast<int>(k_value);
```
No validation that `k <= num_items`
There is no check that `k_value` is less than or equal to `num_items` (the size of the input array). CUB's `DeviceTopK::MaxPairs` requires `k <= num_items`; if `k > num_items`, the behavior is undefined and will likely produce a CUDA error or garbage output.
A guard should be added here alongside the existing shape checks:

```cpp
NVTE_CHECK(k <= num_items, "k (", k, ") must be <= num_items (", num_items, ")");
```
```python
from .base import BasePrimitive, register_primitive

__all__ = ["CubTopkPrimitive"]
```
Public function `cub_topk` not exported in `__all__`
The user-facing function `cub_topk` (defined at line 97) is not included in `__all__`; only `CubTopkPrimitive` is listed. Tools and users relying on `__all__` for the module's public API won't discover `cub_topk`. Since the test imports it as a primary API (`from transformer_engine.jax.cpp_extensions.cub import cub_topk`), it should be exported.
```diff
-__all__ = ["CubTopkPrimitive"]
+__all__ = ["CubTopkPrimitive", "cub_topk"]
```
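As a quick illustration of why this matters, star-imports only see names listed in `__all__`. The module below is a throwaway stand-in built on the fly, mirroring the reviewed file in miniature; only the two public names come from the PR.

```python
import sys
import types

# Build a stand-in module whose __all__ omits cub_topk, like the reviewed file.
mod = types.ModuleType("cub_stub")
exec(
    "__all__ = ['CubTopkPrimitive']\n"
    "class CubTopkPrimitive:\n"
    "    pass\n"
    "def cub_topk(keys, values, k):\n"
    "    pass\n",
    mod.__dict__,
)
sys.modules["cub_stub"] = mod

# A star-import honors __all__, so cub_topk is invisible to the importer.
ns = {}
exec("from cub_stub import *", ns)
```

Here `ns` ends up containing `CubTopkPrimitive` but not `cub_topk`, which is exactly the discoverability gap the review points out.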
Description
This PR introduces the new CUB TopK API for large N and K values.
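The intended semantics, the k largest keys together with their original indices, can be sketched with a NumPy reference. The function name and return order here are assumptions based on the primitive's outputs (keys plus int32 indices), not the actual TE API.

```python
import numpy as np


def topk_max_pairs_reference(keys: np.ndarray, k: int):
    """NumPy reference for MaxPairs-style top-k: the k largest keys and their
    original indices. Tie-breaking order among equal keys is unspecified here."""
    order = np.argsort(keys)[::-1][:k]  # indices of the k largest keys, descending
    return keys[order], order.astype(np.int32)


keys = np.array([0.1, 3.0, 2.5, -1.0, 2.5], dtype=np.float32)
top_keys, top_idx = topk_max_pairs_reference(keys, k=2)
```

A reference like this is also a natural oracle for the PR's tests: sort-based top-k is O(N log N) but unambiguous, while the CUB path is what makes large N and K fast on GPU.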
Type of change
Changes

- `3rdparty/cccl` as a dependency, since the CTK on the machine might not be new enough
- `transformer_engine/common/util/cub.cu` as the entry point to the CUB TopK function

Checklist: