
Conversation

@cliffburdick
Collaborator

Add cuBLASDx support for JIT-compiled matrix multiplication

Integrate cuBLASDx for fused and accelerated matrix multiplication of
small matrices that fit in shared memory. This enables significantly
faster GEMM operations for sizes up to ~200 elements per dimension
(varies by data type and compute capability).

Key changes:

  • Add matmul_cublasdx.h with cuBLASDxHelper class for managing GEMM
    parameters, size validation, and device code generation
  • Extend MatMulOp with JIT storage support and cuBLASDx-specific code
    generation (get_jit_class_name, get_jit_op_str)
  • Add PASS_THROUGH_THREADS capability for operators where all threads
    must invoke operator() with bounds checking at the tensor level (see
    the sketch after this list)
  • Update JIT executor to handle 2D block launch configuration for
    cuBLASDx operators with fixed block dimensions
  • Add Block2D kernel variants (matxOpT{2,3,4}KernelBlock2D) for
    pass-through thread execution model

Supported types: half, bfloat16, float, double, and their complex
variants. Size limits are architecture-dependent, ranging from 36 to 196
elements per dimension depending on compute capability (SM 7.0 through SM 11.0).

Requires MATX_EN_MATHDX to be enabled at compile time.
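
A minimal usage sketch follows, using the standard MatX tensor and matmul API. Whether the cuBLASDx JIT path is actually taken depends on the executor added in this PR (shown as CUDAJITExecutor in the sequence diagram below) and on the size/type checks, so treat the executor choice here as illustrative rather than exact:

    #include "matx.h"

    int main() {
      // Small square matrices that fit in shared memory and stay within the
      // architecture-dependent size limits described above.
      auto A = matx::make_tensor<float>({64, 64});
      auto B = matx::make_tensor<float>({64, 64});
      auto C = matx::make_tensor<float>({64, 64});

      // Illustrative executor: with MATX_EN_MATHDX enabled, a JIT-capable
      // executor can route this matmul through the cuBLASDx code path.
      matx::cudaExecutor exec{};
      (C = matx::matmul(A, B)).run(exec);

      cudaDeviceSynchronize();
      return 0;
    }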

Note that this is in early development. cuBLASDx has limitations that affect the MatX code base as a whole, such as dictating what the block size must be. This PR provides early-adopter support, and more features will be added over time.

@copy-pr-bot

copy-pr-bot bot commented Jan 23, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@cliffburdick
Collaborator Author

/build

@greptile-apps

greptile-apps bot commented Jan 23, 2026

Greptile Summary

  • Integrates cuBLASDx library for JIT-compiled matrix multiplication of small matrices that fit in shared memory, enabling significantly faster GEMM operations for sizes up to ~200 elements per dimension
  • Introduces PASS_THROUGH_THREADS capability to support operators requiring all threads to execute with bounds checking at the tensor level rather than per-thread
  • Extends JIT compilation infrastructure to handle 2D block kernels with fixed dimensions and adds comprehensive cuBLASDx helper class for parameter management and device code generation

Important Files Changed

  • include/matx/core/error.h: Added a new error code but no corresponding error string mapping, so lookups fall through to the generic "Unknown" string
  • include/matx/transforms/matmul/matmul_cublasdx.h: New cuBLASDx integration file with potential issues in shared memory calculations and output indexing logic
  • include/matx/operators/matmul.h: Major refactor adding dual compilation paths and a complex capability system for cuBLASDx support
  • include/matx/core/tensor_impl.h: Added pass-through bounds checking using static variables in device code, which could cause multi-GPU issues
  • include/matx/executors/jit_cuda.h: Extended JIT executor with a missing initialization in the ND kernel path that could lead to inconsistent cache state

Confidence score: 2/5

  • This PR requires careful review due to several critical bugs and architectural risks that could cause runtime failures
  • Score lowered due to the missing error string mapping, potential memory access issues in bounds checking, inconsistent shared memory calculations, and uninitialized variables in cache logic
  • Pay close attention to error handling in include/matx/core/error.h, bounds checking logic in include/matx/core/tensor_impl.h, and the cuBLASDx helper implementation in include/matx/transforms/matmul/matmul_cublasdx.h

Sequence Diagram

sequenceDiagram
    participant User
    participant MatMulOp
    participant CUDAJITExecutor
    participant cuBLASDxHelper
    participant NVRTC
    participant CUDARuntime

    User->>MatMulOp: "matmul(A, B)"
    MatMulOp->>cuBLASDxHelper: "Initialize with matrix dimensions (m,n,k)"
    MatMulOp->>cuBLASDxHelper: "Set compute capability"
    
    User->>CUDAJITExecutor: "Exec(matmul_op)"
    CUDAJITExecutor->>MatMulOp: "get_capability<SUPPORTS_JIT>()"
    MatMulOp->>cuBLASDxHelper: "CheckJITSizeAndTypeRequirements()"
    cuBLASDxHelper-->>MatMulOp: "supported = true/false"
    MatMulOp-->>CUDAJITExecutor: "JIT supported"
    
    CUDAJITExecutor->>MatMulOp: "get_capability<PASS_THROUGH_THREADS>()"
    MatMulOp-->>CUDAJITExecutor: "true (block-level cooperation needed)"
    
    CUDAJITExecutor->>MatMulOp: "get_capability<BLOCK_DIM>()"
    MatMulOp->>cuBLASDxHelper: "GetBlockDim()"
    cuBLASDxHelper-->>MatMulOp: "block dimensions"
    MatMulOp-->>CUDAJITExecutor: "fixed block size"
    
    CUDAJITExecutor->>CUDAJITExecutor: "get_grid_dims_block_2d()"
    
    CUDAJITExecutor->>MatMulOp: "get_capability<JIT_CLASS_QUERY>()"
    MatMulOp->>MatMulOp: "get_jit_op_str()"
    MatMulOp-->>CUDAJITExecutor: "JIT class definitions"
    
    CUDAJITExecutor->>MatMulOp: "get_capability<GENERATE_LTOIR>()"
    MatMulOp->>cuBLASDxHelper: "GenerateLTOIR()"
    cuBLASDxHelper->>cuBLASDxHelper: "GeneratePlan() with cuBLASDx"
    cuBLASDxHelper->>cuBLASDxHelper: "Generate LTOIR binary"
    cuBLASDxHelper-->>MatMulOp: "LTOIR symbols added"
    MatMulOp-->>CUDAJITExecutor: "LTOIR generated"
    
    CUDAJITExecutor->>NVRTC: "nvrtcCreateProgram() with JIT kernel source"
    CUDAJITExecutor->>NVRTC: "nvrtcCompileProgram() with includes"
    NVRTC-->>CUDAJITExecutor: "Compiled LTOIR"
    
    CUDAJITExecutor->>CUDAJITExecutor: "nvJitLinkCreate() and add LTOIR files"
    CUDAJITExecutor->>CUDAJITExecutor: "nvJitLinkComplete() to generate cubin"
    
    CUDAJITExecutor->>CUDARuntime: "cuModuleLoadDataEx() with linked cubin"
    CUDARuntime-->>CUDAJITExecutor: "CUDA module"
    
    CUDAJITExecutor->>CUDARuntime: "cuFuncSetAttribute() for dynamic shared memory"
    CUDAJITExecutor->>CUDARuntime: "cuLaunchKernel() with 2D block configuration"
    
    CUDARuntime->>CUDARuntime: "Execute kernel with block-level GEMM cooperation"
    CUDARuntime-->>User: "Matrix multiplication result"


@greptile-apps bot left a comment


Additional Comments (2)

  1. include/matx/core/error.h, lines 95-96

    logic: Missing case for matxLibMathdxError in the switch statement. This will cause the function to fall through to the default case and return "Unknown" instead of a proper error string (see the sketch after this list).

  2. include/matx/executors/jit_cuda.h, line 319

    logic: Missing initialization of pass_through_threads in ND kernel cache parameters - should be set to false for consistency
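
For item 1, a hedged sketch of the kind of fix being suggested; the function name and message text below are illustrative, since the actual switch lives in include/matx/core/error.h:

    // Illustrative only: add a case for the new error code so the lookup no
    // longer falls through to the default "Unknown" branch.
    const char *matxErrorString(matxError_t err) {  // hypothetical signature
      switch (err) {
        // ... existing cases ...
        case matxLibMathdxError:
          return "MathDx/cuBLASDx library error";   // hypothetical message text
        default:
          return "Unknown";
      }
    }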

11 files reviewed, 11 comments


Comment on lines +329 to +337
else if constexpr (RANK > 4) {
  // For higher ranks, flatten batch dimensions into available grid dims
  // This may need stride handling for very large batches
  index_t total_batches = 1;
  for (int i = 0; i < RANK - 2; i++) {
    total_batches *= sizes[i];
  }
  blocks.x = static_cast<int>(total_batches);
}


logic: For RANK > 4, all batch dimensions are flattened into blocks.x only, ignoring blocks.y and blocks.z. This could exceed CUDA grid limits (65535) for large batch sizes.
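
One possible direction, sketched here reusing the names from the quoted snippet (RANK, sizes, blocks, index_t) and not taken from the PR: spill batches across blocks.y and blocks.z before resorting to a grid-stride loop. Note that on current architectures the x-dimension grid limit is 2^31-1, while y and z are capped at 65535.

    // Illustrative sketch: spread flattened batch dimensions across x/y/z so no
    // single grid dimension exceeds its hardware limit.
    constexpr index_t GRID_X_MAX  = 2147483647;  // 2^31 - 1
    constexpr index_t GRID_YZ_MAX = 65535;

    index_t total_batches = 1;
    for (int i = 0; i < RANK - 2; i++) {
      total_batches *= sizes[i];
    }

    blocks.x = static_cast<unsigned int>(total_batches < GRID_X_MAX ? total_batches : GRID_X_MAX);
    index_t remaining = (total_batches + blocks.x - 1) / blocks.x;
    blocks.y = static_cast<unsigned int>(remaining < GRID_YZ_MAX ? remaining : GRID_YZ_MAX);
    remaining = (remaining + blocks.y - 1) / blocks.y;
    blocks.z = static_cast<unsigned int>(remaining < GRID_YZ_MAX ? remaining : GRID_YZ_MAX);
    // If remaining still exceeds blocks.z, a grid-stride loop (returning true
    // from this helper) would be needed; the kernel would then re-linearize blockIdx.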

Comment on lines +341 to +342
// No stride needed for now - could be extended for very large batches
return false;


logic: Function always returns false, but the comment suggests stride handling may be needed for large batches. This could cause issues if grid limits are exceeded.

" static_assert(sizeof...(Is) == M, \"Number of indices of operator() must match rank of tensor\");\n" +
" constexpr int EPT_int = static_cast<int>(CapType::ept);\n" +
" if constexpr (CapType::pass_through_threads) {\n" +
" static cuda::std::conditional_t<CapType::ept == detail::ElementsPerThread::ONE, T, detail::Vector<T, EPT_int>> dummy_;\n" +


logic: Static variable in device code can cause issues with CUDA compilation and runtime behavior across different GPU contexts. Have you tested this static variable behavior across different CUDA contexts and GPU devices?
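
One alternative to the static, sketched under the assumption that the auto-deduced return type can be a value rather than a reference for the pass-through case; the InBounds helper is hypothetical and the remaining names come from the quoted snippet:

    // Illustrative alternative: give out-of-range threads a function-local,
    // value-initialized dummy instead of referencing a static __device__
    // variable, so no storage is shared across threads, streams, or devices.
    if constexpr (CapType::pass_through_threads) {
      using dummy_t = cuda::std::conditional_t<CapType::ept == detail::ElementsPerThread::ONE,
                                               T, detail::Vector<T, EPT_int>>;
      if (!InBounds(indices...)) {  // hypothetical bounds-check helper
        return dummy_t{};           // each thread owns its own dummy value
      }
    }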

" template <typename CapType, int M = RANK, typename... Is,\n" +
" cuda::std::enable_if_t<cuda::std::conjunction_v<cuda::std::is_integral<Is>...>, bool> = true>\n" +
" __MATX_INLINE__ __MATX_DEVICE__ decltype(auto) operator()(Is... indices) const noexcept" + "{\n" +
" __MATX_INLINE__ __MATX_DEVICE__ auto operator()(Is... indices) const noexcept" + "{\n" +


style: Return type changed from decltype(auto) to auto which may affect template argument deduction in some contexts

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
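
For reference, a small self-contained example of the distinction this comment points at; it is generic C++ and not tied to the MatX code:

    #include <vector>

    std::vector<int> v{1, 2, 3};

    // auto deduces a value type: the int& returned by operator[] decays to int,
    // so callers receive a copy.
    auto get_copy(int i) { return v[i]; }           // deduced return type: int

    // decltype(auto) preserves the reference, so callers can write through it.
    decltype(auto) get_ref(int i) { return v[i]; }  // deduced return type: int&

    int main() {
      get_ref(0) = 42;      // modifies v[0]
      return get_copy(0);   // returns a copy (42); assigning to it would not compile
    }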

Comment on lines +211 to +216
#if defined(CUDA_VERSION)
  symbol_name += "_CUDA";
  symbol_name += std::to_string(CUDART_VERSION);
#else
  symbol_name += "_CUDAUNKNOWN";
#endif


logic: Using CUDA_VERSION check but referencing CUDART_VERSION in the string. If CUDA_VERSION is undefined, CUDART_VERSION may also be undefined, potentially causing compilation issues.

Suggested change

-  #if defined(CUDA_VERSION)
-    symbol_name += "_CUDA";
-    symbol_name += std::to_string(CUDART_VERSION);
-  #else
-    symbol_name += "_CUDAUNKNOWN";
-  #endif
+  #if defined(CUDART_VERSION)
+    symbol_name += "_CUDA";
+    symbol_name += std::to_string(CUDART_VERSION);
+  #else
+    symbol_name += "_CUDAUNKNOWN";
+  #endif

Comment on lines +423 to +425
if constexpr (CapType::ept == ElementsPerThread::ONE) {
  const int output_idx = threadIdx.x;
  return smem_c[output_idx];


logic: Output indexing logic assumes threadIdx.x maps directly to output elements, but this may not align with the 2D block dimensions suggested by cuBLASDx. The logic doesn't account for threadIdx.y or threadIdx.z. How should the output indexing work with 2D/3D block dimensions from cuBLASDx?
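
A hedged sketch of one way the indexing could account for a 2D block, assuming the block covers the output tile in row-major order; the linearization is an assumption, not the PR's actual layout, and smem_c/CapType come from the quoted snippet:

    // Illustrative only: linearize the 2D thread coordinate so each thread in a
    // blockDim.x-by-blockDim.y block maps to a distinct output element.
    if constexpr (CapType::ept == ElementsPerThread::ONE) {
      const int output_idx = threadIdx.y * blockDim.x + threadIdx.x;
      return smem_c[output_idx];
    }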

Comment on lines +414 to +415
result += "value_type alpha_val = static_cast<value_type>(" + std::to_string(alpha) + ");\n";
result += "value_type beta_val = static_cast<value_type>(" + std::to_string(beta) + ");\n";


logic: Float literals for alpha/beta values may cause precision loss when InputType is double or half precision. The conversion should preserve the original precision.
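
One way to preserve precision when emitting these literals, sketched with standard C++ facilities; the helper name is hypothetical, and alpha/beta/result follow the quoted snippet:

    #include <limits>
    #include <sstream>
    #include <string>

    // Hypothetical helper: format a floating-point scalar with enough digits to
    // round-trip exactly, instead of std::to_string's fixed six-digit precision.
    static std::string FormatScalar(double v) {
      std::ostringstream ss;
      ss.precision(std::numeric_limits<double>::max_digits10);
      ss << v;
      return ss.str();
    }

    // Usage in the code generator (sketch):
    //   result += "value_type alpha_val = static_cast<value_type>(" + FormatScalar(alpha) + ");\n";
    //   result += "value_type beta_val = static_cast<value_type>(" + FormatScalar(beta) + ");\n";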

@cliffburdick
Collaborator Author

/build
