Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,6 @@
[submodule "3rdparty/cutlass"]
path = 3rdparty/cutlass
url = https://github.com/NVIDIA/cutlass.git
[submodule "3rdparty/cccl"]
path = 3rdparty/cccl
url = https://github.com/NVIDIA/cccl.git
1 change: 1 addition & 0 deletions 3rdparty/cccl
Submodule cccl added at c262ef
40 changes: 40 additions & 0 deletions tests/jax/test_custom_call_compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
from transformer_engine.jax.activation import activation
from transformer_engine.jax.dense import dense, grouped_dense
from transformer_engine.jax.layernorm_dense import layernorm_dense
from transformer_engine.jax.cpp_extensions.cub import cub_topk

GEMM_CASES = [
(256, 256, 512),
Expand Down Expand Up @@ -1955,3 +1956,42 @@ def f(x):
actual = load_array_dump("my_tensor_gpu0.bin", shape, dtype)

assert_allclose(actual, expected, dtype=dtype)


@pytest.mark.parametrize("dtype", [jnp.bfloat16, jnp.float16, jnp.float32])
@pytest.mark.parametrize(
    "problem_size", [(10000, 100), (50000, 200), (100000, 500), (1000000, 1000), (5000000, 2000)]
)
class TestCubOps:
    """Tests for the CUB-backed custom ops."""

    def test_cub_topk(self, dtype, problem_size):
        """Check cub_topk against jax.lax.top_k on a shuffled two-band input."""
        n, k = problem_size

        # Build an input whose top-k entries are well separated from the rest:
        # k values drawn from [1.5, 2.5) and n-k values from [0.0, 1.0),
        # then randomly permuted so position carries no information.
        key_top, key_bottom, key_perm = jax.random.split(jax.random.PRNGKey(0), 3)
        winners = jax.random.uniform(key_top, shape=(k,), dtype=dtype, minval=1.5, maxval=2.5)
        losers = jax.random.uniform(
            key_bottom, shape=(n - k,), dtype=dtype, minval=0.0, maxval=1.0
        )
        x = jax.random.permutation(key_perm, jnp.concatenate([winners, losers]))

        reference_fn = jax.jit(jax.lax.top_k, static_argnums=(1,))
        candidate_fn = jax.jit(cub_topk, static_argnums=(1,))

        ref_topk, ref_indices = reference_fn(x, k)
        prim_topk, prim_indices = candidate_fn(x, k)

        # CUB output does not guarantee the order of the topk values, sort them for comparison
        ref_topk, ref_indices = jax.lax.sort_key_val(ref_topk, ref_indices)
        prim_topk, prim_indices = jax.lax.sort_key_val(prim_topk, prim_indices)

        assert_allclose(ref_topk, prim_topk, dtype=dtype)

        # sort and sort_key_val are ascending, make sure the smallest topk value
        # prim_topk[0] is not smaller than the k+1 largest value in the original array
        sorted_x = jax.lax.sort(x)
        assert prim_topk[0] >= sorted_x[-(k + 1)]

        # TopK values can be duplicated, instead of directly comparing the indices, we check
        # if the values at the returned indices are the same
        assert_allclose(x[ref_indices], x[prim_indices], dtype=dtype)
16 changes: 15 additions & 1 deletion transformer_engine/common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,16 @@ set(CUTLASS_INCLUDE_DIR
set(CUTLASS_TOOLS_INCLUDE_DIR
"${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/cutlass/tools/util/include")

# CCCL (CUDA Core Compute Libraries)
set(CCCL_INCLUDE_DIR
"${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/cccl")
if(NOT EXISTS "${CCCL_INCLUDE_DIR}")
message(FATAL_ERROR
"Could not find CCCL at ${CCCL_INCLUDE_DIR}. "
"Try running 'git submodule update --init --recursive' "
"within the Transformer Engine source.")
endif()

# Python
find_package(Python COMPONENTS Interpreter Development.Module REQUIRED)

Expand Down Expand Up @@ -151,6 +161,7 @@ list(APPEND transformer_engine_cuda_sources
normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu
permutation/permutation.cu
util/padding.cu
util/cub.cu
swizzle/swizzle.cu
swizzle/swizzle_block_scaling.cu
fused_softmax/scaled_masked_softmax.cu
Expand Down Expand Up @@ -262,8 +273,11 @@ target_link_libraries(transformer_engine PUBLIC

target_include_directories(transformer_engine PRIVATE
${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
# Use CCCL from 3rdparty instead of the one from CUDA Toolkit
target_include_directories(transformer_engine SYSTEM PRIVATE
${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}/cccl)
${CCCL_INCLUDE_DIR}/thrust
${CCCL_INCLUDE_DIR}/cub
${CCCL_INCLUDE_DIR}/libcudacxx/include)
target_include_directories(transformer_engine PRIVATE "${CUDNN_FRONTEND_INCLUDE_DIR}")
target_include_directories(transformer_engine PRIVATE
${CUTLASS_INCLUDE_DIR}
Expand Down
39 changes: 39 additions & 0 deletions transformer_engine/common/include/transformer_engine/cub.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/*************************************************************************
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/

#ifndef TRANSFORMER_ENGINE_CUB_H_
#define TRANSFORMER_ENGINE_CUB_H_

#include "transformer_engine.h"

#ifdef __cplusplus
extern "C" {
#endif

/*! \brief Compute the top-K largest (key, value) pairs using CUB.
 *
 * Wraps cub::DeviceTopK::MaxPairs. The returned pairs are NOT guaranteed to
 * be sorted in any particular order; callers must not rely on the ordering.
 *
 * \param[in] stream CUDA stream used for the operation.
 * \param[in] keys_in Input 1D keys tensor, shape (num_items,)
 * \param[in] values_in Input 1D values tensor, shape (num_items,)
 * \param[out] keys_out Output 1D keys tensor, shape (k,)
 * \param[out] values_out Output 1D values tensor, shape (k,)
 * \param[in,out] workspace Temporary workspace tensor, shape (workspace_bytes,)
 * \param[in] num_items Number of items in the input tensor
 * \param[in] k Number of top-K largest values to return; must satisfy
 *              0 <= k <= num_items.
 * \param[in] workspace_bytes Workspace size in bytes
 *
 * Requirements:
 * - Only supports float32, float16, bfloat16 keys and int32 values.
 */
void nvte_cub_topk(cudaStream_t stream, const NVTETensor keys_in, const NVTETensor values_in,
                   NVTETensor keys_out, NVTETensor values_out, NVTETensor workspace,
                   const int num_items, const int k, const size_t workspace_bytes);

#ifdef __cplusplus
} // extern "C"
#endif

#endif
54 changes: 54 additions & 0 deletions transformer_engine/common/util/cub.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
/*************************************************************************
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/

#include <transformer_engine/cub.h>

#include <cub/device/device_topk.cuh>
#include <cuda/std/execution>

#include "../common.h"

/*! \brief Top-K largest (key, value) pairs via cub::DeviceTopK::MaxPairs.
 *
 * Supported dtypes: float32/float16/bfloat16 keys with int32 values.
 * Output order is deliberately unspecified (output_ordering::unsorted).
 */
void nvte_cub_topk(cudaStream_t stream, const NVTETensor keys_in, const NVTETensor values_in,
                   NVTETensor keys_out, NVTETensor values_out, NVTETensor workspace, int num_items,
                   int k, size_t workspace_bytes) {
  NVTE_API_CALL(nvte_cub_topk);
  using namespace transformer_engine;

  const Tensor *keys_in_tensor = convertNVTETensorCheck(keys_in);
  const Tensor *values_in_tensor = convertNVTETensorCheck(values_in);
  Tensor *keys_out_tensor = convertNVTETensor(keys_out);
  Tensor *values_out_tensor = convertNVTETensor(values_out);
  Tensor *workspace_tensor = convertNVTETensor(workspace);
  auto keys_in_dtype = keys_in_tensor->data.dtype;
  auto values_in_dtype = values_in_tensor->data.dtype;

  // cub::DeviceTopK requires 0 <= k <= num_items; violating this is undefined
  // behavior on the device, so fail fast on the host.
  NVTE_CHECK(num_items >= 0, "num_items must be non-negative");
  NVTE_CHECK(k >= 0 && k <= num_items, "k must satisfy 0 <= k <= num_items");

  // Run the CUB call on the caller's stream, allowing nondeterminism and
  // unsorted output (callers must not rely on the order of the top-k pairs).
  auto requirements = cuda::execution::require(cuda::execution::determinism::not_guaranteed,
                                               cuda::execution::output_ordering::unsorted);
  cuda::stream_ref stream_ref{stream};
  auto env = cuda::std::execution::env{stream_ref, requirements};

// Dispatch helper: reinterpret the tensor buffers at the given key/value types
// and launch the CUB top-k. The cudaError_t returned by CUB is checked so
// launch failures surface as NVTE errors instead of being silently dropped.
#define DISPATCH_CUB_TOPK(KeyT, ValueT)                                                      \
  do {                                                                                       \
    KeyT *d_keys_in = reinterpret_cast<KeyT *>(keys_in_tensor->data.dptr);                   \
    KeyT *d_keys_out = reinterpret_cast<KeyT *>(keys_out_tensor->data.dptr);                 \
    ValueT *d_values_in = reinterpret_cast<ValueT *>(values_in_tensor->data.dptr);           \
    ValueT *d_values_out = reinterpret_cast<ValueT *>(values_out_tensor->data.dptr);         \
    void *d_workspace = reinterpret_cast<void *>(workspace_tensor->data.dptr);               \
    const cudaError_t status =                                                               \
        cub::DeviceTopK::MaxPairs(d_workspace, workspace_bytes, d_keys_in, d_keys_out,       \
                                  d_values_in, d_values_out, num_items, k, env);             \
    NVTE_CHECK(status == cudaSuccess,                                                        \
               "cub::DeviceTopK::MaxPairs failed: ", cudaGetErrorString(status));            \
  } while (0)

  if (keys_in_dtype == DType::kFloat32 && values_in_dtype == DType::kInt32) {
    DISPATCH_CUB_TOPK(float, int);
  } else if (keys_in_dtype == DType::kFloat16 && values_in_dtype == DType::kInt32) {
    DISPATCH_CUB_TOPK(__half, int);
  } else if (keys_in_dtype == DType::kBFloat16 && values_in_dtype == DType::kInt32) {
    DISPATCH_CUB_TOPK(__nv_bfloat16, int);
  } else {
    NVTE_ERROR("Unsupported input key and value data types");
  }
#undef DISPATCH_CUB_TOPK
}
111 changes: 111 additions & 0 deletions transformer_engine/jax/cpp_extensions/cub.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""CUB custom ops"""

from typing import Tuple

import jax
import jax.numpy as jnp
from jax import dtypes, ffi

from .base import BasePrimitive, register_primitive

# cub_topk is the primary user-facing API (imported directly by tests), so it
# must be part of the module's declared public interface.
__all__ = ["CubTopkPrimitive", "cub_topk"]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Public function cub_topk not exported in __all__

The user-facing function cub_topk (defined at line 97) is not included in __all__. Only CubTopkPrimitive is listed. Tools and users relying on __all__ for the module's public API won't discover cub_topk. Since the test imports it as a primary API (from transformer_engine.jax.cpp_extensions.cub import cub_topk), it should be exported.

Suggested change
__all__ = ["CubTopkPrimitive"]
__all__ = ["CubTopkPrimitive", "cub_topk"]



def get_cub_topk_workspace_bytes(num_items: int = 0, k: int = 0) -> int:
    """
    Get the workspace size in bytes for the CUB top-k kernel.

    The safe way is calling the CUB kernel to query the workspace size.
    For convenience, we use a heuristic value based on experiments:
    4 MiB is enough for N up to 5,000,000 and K up to 100,000.

    Args:
        num_items: optional problem size N, used to validate that the
            heuristic workspace is large enough. Defaults to 0 (skip the
            check) for backward compatibility with zero-argument callers.
        k: optional top-k size K, validated the same way. Defaults to 0.

    Returns:
        Workspace size in bytes (fixed 4 MiB heuristic).

    Raises:
        ValueError: if num_items or k exceeds the range covered by the
            heuristic, which would otherwise hand CUB an undersized buffer
            and risk silent out-of-bounds writes on the GPU.
    """
    max_num_items, max_k = 5_000_000, 100_000
    if num_items > max_num_items or k > max_k:
        raise ValueError(
            "CUB top-k heuristic workspace only covers num_items <="
            f" {max_num_items} and k <= {max_k}, got num_items={num_items}, k={k}"
        )
    return 4 * 1024 * 1024
Comment on lines +17 to +24
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Hardcoded workspace size may silently corrupt memory for large inputs

get_cub_topk_workspace_bytes() always returns a fixed 4 MiB and the docstring itself acknowledges this only covers "N up to 5,000,000 and K up to 100,000." However, there is no validation in the Python or C++ layer that the user's actual N and K do not exceed these limits.

If a caller passes N > 5_000_000 or K > 100_000, cub::DeviceTopK::MaxPairs will be given an undersized workspace buffer and will write out-of-bounds on the GPU — a silent CUDA memory corruption with no error raised back to the caller.

The correct approach is to call cub::DeviceTopK::MaxPairs(nullptr, temp_storage_bytes, ...) with a null workspace pointer to query the required size at runtime, then allocate that exact amount. The current heuristic should at minimum be accompanied by a runtime guard that raises an error when the inputs exceed the documented limits.



class CubTopkPrimitive(BasePrimitive):
    """
    CUB Topk Primitive

    JAX primitive wrapping the ``te_cub_topk_ffi`` custom call, which computes
    the k largest (key, value) pairs of a 1D input on the GPU via CUB.

    NOTE(review): batcher/partition/sharding-rule methods from BasePrimitive
    are not overridden here, so vmap and multi-device sharding support is
    presumably unavailable for this primitive — confirm before relying on it.
    """

    # Must match the key registered in the C++ Registrations() dict.
    name = "te_cub_topk_ffi"
    # The primitive returns multiple outputs: (keys, values[, workspace]).
    multiple_results = True
    impl_static_args = (2,)  # k_value
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
        in_keys_aval,
        in_values_aval,
        *,
        k_value,
    ):
        """Inner abstract eval: output avals including the workspace buffer.

        Accepts float32/float16/bfloat16 keys and int32 values; returns
        (k,)-shaped keys/values plus a uint8 workspace aval whose size is the
        fixed heuristic from get_cub_topk_workspace_bytes().
        """
        keys_dtype = dtypes.canonicalize_dtype(in_keys_aval.dtype)
        values_dtype = dtypes.canonicalize_dtype(in_values_aval.dtype)
        assert keys_dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert values_dtype == jnp.int32

        # The workspace is exposed as an extra output so XLA allocates it.
        workspace_bytes = get_cub_topk_workspace_bytes()
        out_keys_aval = jax.core.ShapedArray(shape=(k_value,), dtype=keys_dtype)
        out_values_aval = jax.core.ShapedArray(shape=(k_value,), dtype=jnp.int32)
        workspace_aval = jax.core.ShapedArray(shape=(workspace_bytes,), dtype=jnp.uint8)
        return (out_keys_aval, out_values_aval, workspace_aval)

    @staticmethod
    def outer_abstract(*args, **kwargs):
        """Outer abstract eval: same as abstract() but hides the workspace."""
        out_keys_aval, out_values_aval, _workspace_aval = CubTopkPrimitive.abstract(*args, **kwargs)
        return (out_keys_aval, out_values_aval)

    @staticmethod
    def lowering(
        ctx,
        in_keys,
        in_values,
        k_value,
    ):
        """Lower to the FFI custom call, passing k and the workspace size as
        static int64 attributes (consumed by CubTopkFFI on the C++ side)."""
        workspace_bytes = get_cub_topk_workspace_bytes()
        return ffi.ffi_lowering(
            CubTopkPrimitive.name,
        )(
            ctx,
            in_keys,
            in_values,
            k_value=k_value,
            workbuf_bytes=workspace_bytes,
        )

    @staticmethod
    def impl(
        in_keys,
        in_values,
        k_value,
    ):
        """Bind the inner primitive and drop the workspace from the results."""
        assert CubTopkPrimitive.inner_primitive is not None
        out_keys, out_values, _workspace = CubTopkPrimitive.inner_primitive.bind(
            in_keys,
            in_values,
            k_value=k_value,
        )
        return (out_keys, out_values)


register_primitive(CubTopkPrimitive)
Comment on lines +27 to +94
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Missing batcher, partition, and shardy_sharding_rule methods

CubTopkPrimitive extends BasePrimitive, which declares batcher(), partition(), and shardy_sharding_rule() as abstract methods. CubTopkPrimitive does not implement any of them.

When register_primitive(CubTopkPrimitive) is called, base.py does:

batching.primitive_batchers[outer_p] = cls.batcher  # resolves to abstract method → returns NotImplemented
outer_p_lower.def_partition(partition=cls.partition, ...)  # same

This means:

  • Any attempt to use vmap over cub_topk will fail at runtime because the registered batcher is the abstract method (which returns NotImplemented, not a callable that returns batched results).
  • Multi-device / sharding via custom_partitioning will similarly fail when partition is invoked.

Every other primitive in the codebase (e.g., in router.py) implements all three of these methods. If sharding and batching are intentionally unsupported for now, the methods should at minimum raise a clear NotImplementedError (rather than silently returning NotImplemented), and this limitation should be documented.



def cub_topk(
    x: jnp.ndarray,
    k_value: int,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """
    CUB Topk max pairs

    Return the ``k_value`` largest entries of the 1D array ``x`` together
    with their indices, computed by the CUB-backed custom call. Unlike
    ``jax.lax.top_k``, the returned pairs are NOT guaranteed to be sorted.

    Args:
        x: 1D array of float32/float16/bfloat16 keys.
        k_value: number of largest entries to return; must satisfy
            0 <= k_value <= x.shape[0].

    Returns:
        A ``(values, indices)`` tuple, each of shape ``(k_value,)``.

    Raises:
        ValueError: if ``x`` is not 1D or ``k_value`` is out of range.
    """
    if x.ndim != 1:
        raise ValueError(f"cub_topk expects a 1D input, got shape {x.shape}")
    # CUB's DeviceTopK requires k <= num_items; shapes are static under jit,
    # so validate here on the host instead of hitting undefined behavior on
    # the device.
    if not 0 <= k_value <= x.shape[0]:
        raise ValueError(f"k_value must satisfy 0 <= k_value <= {x.shape[0]}, got {k_value}")
    keys = x
    # Values carry the original positions so the output values are the
    # indices of the selected keys.
    values = jnp.arange(x.shape[0], dtype=jnp.int32)
    out_keys, out_values = CubTopkPrimitive.outer_primitive.bind(
        keys,
        values,
        k_value=k_value,
    )
    return out_keys, out_values
3 changes: 3 additions & 0 deletions transformer_engine/jax/csrc/extensions.h
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,9 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedTopkWithScoreFunctionBackwardHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedMoEAuxLossForwardHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedMoEAuxLossBackwardHandler);

// Cub Topk
XLA_FFI_DECLARE_HANDLER_SYMBOL(CubTopkHandler);

} // namespace jax
} // namespace transformer_engine

Expand Down
73 changes: 73 additions & 0 deletions transformer_engine/jax/csrc/extensions/cub.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
/*************************************************************************
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/

#include "transformer_engine/cub.h"

#include "../extensions.h"
#include "xla/ffi/api/c_api.h"

namespace transformer_engine {
namespace jax {

// XLA FFI entry point for the CUB top-k custom call.
// Validates dtypes and shapes of the (keys, values) buffers, then forwards to
// nvte_cub_topk. k_value and workbuf_bytes are static attributes set during
// JAX lowering.
Error_Type CubTopkFFI(cudaStream_t stream, Buffer_Type keys_in_buf, Buffer_Type values_in_buf,
                      Result_Type keys_out_buf, Result_Type values_out_buf,
                      Result_Type workspace_buf, int64_t k_value, int64_t workbuf_bytes) {
  auto keys_in_dtype = convert_ffi_datatype_to_te_dtype(keys_in_buf.element_type());
  auto values_in_dtype = convert_ffi_datatype_to_te_dtype(values_in_buf.element_type());
  auto keys_out_dtype = convert_ffi_datatype_to_te_dtype(keys_out_buf->element_type());
  auto values_out_dtype = convert_ffi_datatype_to_te_dtype(values_out_buf->element_type());
  NVTE_CHECK(keys_in_dtype == keys_out_dtype, "Input and output keys must have the same datatype");
  NVTE_CHECK(values_in_dtype == values_out_dtype,
             "Input and output values must have the same datatype");
  NVTE_CHECK(values_in_dtype == DType::kInt32, "CubTopkFFI() only supports int32 values for now");

  auto keys_in_shape = keys_in_buf.dimensions();
  auto values_in_shape = values_in_buf.dimensions();
  auto keys_out_shape = keys_out_buf->dimensions();
  auto values_out_shape = values_out_buf->dimensions();
  NVTE_CHECK(keys_in_shape.size() == 1, "Keys input must have 1 dimension");
  NVTE_CHECK(values_in_shape.size() == 1, "Values input must have 1 dimension");
  NVTE_CHECK(keys_out_shape.size() == 1, "Keys output must have 1 dimension");
  NVTE_CHECK(values_out_shape.size() == 1, "Values output must have 1 dimension");
  NVTE_CHECK(keys_in_shape[0] == values_in_shape[0],
             "Keys and values input must have the same number of items");
  NVTE_CHECK(keys_out_shape[0] == values_out_shape[0],
             "Keys and values output must have the same number of items");
  // num_items is handed to CUB as a 32-bit int; reject inputs that would
  // silently overflow the cast below.
  NVTE_CHECK(keys_in_shape[0] <= 2147483647, "Number of input items must fit in int32");
  int num_items = static_cast<int>(keys_in_shape[0]);
  int k = static_cast<int>(k_value);
  // cub::DeviceTopK::MaxPairs requires 0 <= k <= num_items; violating this is
  // undefined behavior on the device, so validate here on the host.
  NVTE_CHECK(k >= 0 && k <= num_items, "k (", k, ") must satisfy 0 <= k <= num_items (", num_items,
             ")");
  NVTE_CHECK(keys_out_shape[0] == k_value, "Keys output must have exactly k items");

  auto input_shape = std::vector<size_t>{keys_in_shape[0]};
  auto output_shape = std::vector<size_t>{keys_out_shape[0]};
  auto workspace_shape = std::vector<size_t>{workbuf_bytes};

  auto keys_in_tensor = TensorWrapper(keys_in_buf.untyped_data(), input_shape, keys_in_dtype);
  auto values_in_tensor = TensorWrapper(values_in_buf.untyped_data(), input_shape, values_in_dtype);
  auto keys_out_tensor = TensorWrapper(keys_out_buf->untyped_data(), output_shape, keys_out_dtype);
  auto values_out_tensor =
      TensorWrapper(values_out_buf->untyped_data(), output_shape, values_out_dtype);
  auto workspace_tensor =
      TensorWrapper(workspace_buf->untyped_data(), workspace_shape, DType::kByte);

  nvte_cub_topk(stream, keys_in_tensor.data(), values_in_tensor.data(), keys_out_tensor.data(),
                values_out_tensor.data(), workspace_tensor.data(), num_items, k, workbuf_bytes);

  return ffi_with_cuda_error_check();
}

// Binds CubTopkFFI as the XLA FFI handler "CubTopkHandler". The workspace is
// declared as a Ret so XLA allocates it; k_value and workbuf_bytes arrive as
// static call attributes (set by the Python lowering in cub.py).
XLA_FFI_DEFINE_HANDLER_SYMBOL(CubTopkHandler, CubTopkFFI,
                              FFI::Bind()
                                  .Ctx<FFI_Stream_Type>()  // stream
                                  .Arg<Buffer_Type>()      // keys_buf
                                  .Arg<Buffer_Type>()      // values_buf
                                  .Ret<Buffer_Type>()      // topk_buf
                                  .Ret<Buffer_Type>()      // indices_buf
                                  .Ret<Buffer_Type>()      // workspace_buf
                                  .Attr<int64_t>("k_value")
                                  .Attr<int64_t>("workbuf_bytes"),
                              FFI_CudaGraph_Traits);

} // namespace jax
} // namespace transformer_engine
3 changes: 3 additions & 0 deletions transformer_engine/jax/csrc/extensions/pybind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,9 @@ pybind11::dict Registrations() {
dict["te_fused_moe_aux_loss_forward_ffi"] = EncapsulateFFI(FusedMoEAuxLossForwardHandler);
dict["te_fused_moe_aux_loss_backward_ffi"] = EncapsulateFFI(FusedMoEAuxLossBackwardHandler);

// Cub Topk
dict["te_cub_topk_ffi"] = EncapsulateFFI(CubTopkHandler);

return dict;
}

Expand Down