cuBLASDx support #1122
Conversation
/build
Greptile Summary
Important Files Changed
Confidence score: 2/5
Sequence Diagram

```mermaid
sequenceDiagram
    participant User
    participant MatMulOp
    participant CUDAJITExecutor
    participant cuBLASDxHelper
    participant NVRTC
    participant CUDARuntime
    User->>MatMulOp: "matmul(A, B)"
    MatMulOp->>cuBLASDxHelper: "Initialize with matrix dimensions (m,n,k)"
    MatMulOp->>cuBLASDxHelper: "Set compute capability"
    User->>CUDAJITExecutor: "Exec(matmul_op)"
    CUDAJITExecutor->>MatMulOp: "get_capability<SUPPORTS_JIT>()"
    MatMulOp->>cuBLASDxHelper: "CheckJITSizeAndTypeRequirements()"
    cuBLASDxHelper-->>MatMulOp: "supported = true/false"
    MatMulOp-->>CUDAJITExecutor: "JIT supported"
    CUDAJITExecutor->>MatMulOp: "get_capability<PASS_THROUGH_THREADS>()"
    MatMulOp-->>CUDAJITExecutor: "true (block-level cooperation needed)"
    CUDAJITExecutor->>MatMulOp: "get_capability<BLOCK_DIM>()"
    MatMulOp->>cuBLASDxHelper: "GetBlockDim()"
    cuBLASDxHelper-->>MatMulOp: "block dimensions"
    MatMulOp-->>CUDAJITExecutor: "fixed block size"
    CUDAJITExecutor->>CUDAJITExecutor: "get_grid_dims_block_2d()"
    CUDAJITExecutor->>MatMulOp: "get_capability<JIT_CLASS_QUERY>()"
    MatMulOp->>MatMulOp: "get_jit_op_str()"
    MatMulOp-->>CUDAJITExecutor: "JIT class definitions"
    CUDAJITExecutor->>MatMulOp: "get_capability<GENERATE_LTOIR>()"
    MatMulOp->>cuBLASDxHelper: "GenerateLTOIR()"
    cuBLASDxHelper->>cuBLASDxHelper: "GeneratePlan() with cuBLASDx"
    cuBLASDxHelper->>cuBLASDxHelper: "Generate LTOIR binary"
    cuBLASDxHelper-->>MatMulOp: "LTOIR symbols added"
    MatMulOp-->>CUDAJITExecutor: "LTOIR generated"
    CUDAJITExecutor->>NVRTC: "nvrtcCreateProgram() with JIT kernel source"
    CUDAJITExecutor->>NVRTC: "nvrtcCompileProgram() with includes"
    NVRTC-->>CUDAJITExecutor: "Compiled LTOIR"
    CUDAJITExecutor->>CUDAJITExecutor: "nvJitLinkCreate() and add LTOIR files"
    CUDAJITExecutor->>CUDAJITExecutor: "nvJitLinkComplete() to generate cubin"
    CUDAJITExecutor->>CUDARuntime: "cuModuleLoadDataEx() with linked cubin"
    CUDARuntime-->>CUDAJITExecutor: "CUDA module"
    CUDAJITExecutor->>CUDARuntime: "cuFuncSetAttribute() for dynamic shared memory"
    CUDAJITExecutor->>CUDARuntime: "cuLaunchKernel() with 2D block configuration"
    CUDARuntime->>CUDARuntime: "Execute kernel with block-level GEMM cooperation"
    CUDARuntime-->>User: "Matrix multiplication result"
```
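The capability handshake in the diagram can be modeled host-side. The sketch below uses invented stand-ins (`MockMatMulOp`, `select_block_dim`, `BlockDim`) to show the core decision — an operator that supports JIT and needs block-level cooperation dictates its own block size — and is not the actual MatX API:

```cpp
#include <cassert>

// Hypothetical stand-ins for the capability queries in the diagram above;
// names are illustrative, not the real MatX interfaces.
struct BlockDim { int x, y, z; };

struct MockMatMulOp {
  bool supports_jit() const { return true; }           // SUPPORTS_JIT
  bool pass_through_threads() const { return true; }   // PASS_THROUGH_THREADS
  BlockDim block_dim() const { return {128, 2, 1}; }   // BLOCK_DIM (fixed by cuBLASDx)
};

// The executor launches with the operator's fixed block size only when the
// operator both supports JIT and requires block-level cooperation;
// otherwise it falls back to its own default launch configuration.
template <typename Op>
BlockDim select_block_dim(const Op &op, bool &use_fixed) {
  use_fixed = op.supports_jit() && op.pass_through_threads();
  return use_fixed ? op.block_dim() : BlockDim{256, 1, 1};
}
```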
Additional Comments (2)
- `include/matx/core/error.h`, lines 95-96 (link) — logic: Missing case for `matxLibMathdxError` in the switch statement. This will cause the function to fall through to the default case and return "Unknown" instead of a proper error string.
- `include/matx/executors/jit_cuda.h`, line 319 (link) — logic: Missing initialization of `pass_through_threads` in ND kernel cache parameters; should be set to `false` for consistency.
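The first comment amounts to adding one case to the error-string switch. A minimal sketch under the assumption of a simplified enum (the real `matxError_t` has more values than shown here):

```cpp
#include <string>

// Simplified enum; matxLibMathdxError is the value the review says lacks a case.
enum matxError_t { matxSuccess, matxCudaError, matxLibMathdxError };

// With an explicit case, the lookup no longer falls through to "Unknown".
inline std::string error_to_str(matxError_t e) {
  switch (e) {
    case matxSuccess:        return "Success";
    case matxCudaError:      return "CUDA error";
    case matxLibMathdxError: return "libmathdx error";  // previously missing
    default:                 return "Unknown";
  }
}
```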
11 files reviewed, 11 comments
```cpp
else if constexpr (RANK > 4) {
  // For higher ranks, flatten batch dimensions into available grid dims
  // This may need stride handling for very large batches
  index_t total_batches = 1;
  for (int i = 0; i < RANK - 2; i++) {
    total_batches *= sizes[i];
  }
  blocks.x = static_cast<int>(total_batches);
}
```
logic: For RANK > 4, all batch dimensions are flattened into blocks.x only, ignoring blocks.y and blocks.z. This could exceed CUDA grid limits (2^31 − 1 for grid.x; 65,535 for grid.y/z) for large batch sizes.
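One common remedy for the overflow risk described here is to split the flat batch count across all three grid dimensions and bounds-check inside the kernel. A host-side sketch (`split_batches` and `GridDim` are illustrative names, not the PR's code):

```cpp
#include <cstdint>

struct GridDim { unsigned x, y, z; };

inline std::int64_t ceil_div(std::int64_t a, std::int64_t b) { return (a + b - 1) / b; }

// Spread `total` batches across x, then y, then z so no dimension exceeds
// its CUDA limit (2^31 - 1 for x, 65,535 for y and z). Because the product
// may overshoot, the kernel must still guard: linear batch id < total.
inline GridDim split_batches(std::int64_t total) {
  const std::int64_t kXMax  = 2147483647;  // 2^31 - 1
  const std::int64_t kYZMax = 65535;
  std::int64_t x = total < kXMax ? total : kXMax;
  std::int64_t rem = ceil_div(total, x);
  std::int64_t y = rem < kYZMax ? rem : kYZMax;
  std::int64_t z = ceil_div(rem, y);
  return GridDim{static_cast<unsigned>(x), static_cast<unsigned>(y),
                 static_cast<unsigned>(z)};
}
```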
```cpp
// No stride needed for now - could be extended for very large batches
return false;
```
logic: Function always returns false, but the comment suggests stride handling may be needed for large batches. This could cause issues if grid limits are exceeded.
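The stride handling hinted at in the code comment is usually a grid-stride loop: each block advances by the total grid size until all batches are covered. A host-side model of that indexing (illustrative, not the PR's implementation):

```cpp
#include <vector>

// With fewer blocks than batches, block `block_id` processes batches
// block_id, block_id + grid_size, block_id + 2*grid_size, ... so every
// batch index is visited exactly once across the grid.
inline std::vector<int> batches_for_block(int block_id, int grid_size,
                                          int total_batches) {
  std::vector<int> out;
  for (int b = block_id; b < total_batches; b += grid_size) {
    out.push_back(b);
  }
  return out;
}
```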
```cpp
" static_assert(sizeof...(Is) == M, \"Number of indices of operator() must match rank of tensor\");\n" +
" constexpr int EPT_int = static_cast<int>(CapType::ept);\n" +
" if constexpr (CapType::pass_through_threads) {\n" +
"   static cuda::std::conditional_t<CapType::ept == detail::ElementsPerThread::ONE, T, detail::Vector<T, EPT_int>> dummy_;\n" +
```
logic: Static variable in device code can cause issues with CUDA compilation and runtime behavior across different GPU contexts. Have you tested this static variable behavior across different CUDA contexts and GPU devices?
```diff
  " template <typename CapType, int M = RANK, typename... Is,\n" +
  "   cuda::std::enable_if_t<cuda::std::conjunction_v<cuda::std::is_integral<Is>...>, bool> = true>\n" +
- " __MATX_INLINE__ __MATX_DEVICE__ decltype(auto) operator()(Is... indices) const noexcept" + "{\n" +
+ " __MATX_INLINE__ __MATX_DEVICE__ auto operator()(Is... indices) const noexcept" + "{\n" +
```
style: Return type changed from decltype(auto) to auto which may affect template argument deduction in some contexts
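The distinction behind the style note: a plain `auto` return type decays to a value, while `decltype(auto)` preserves the reference-ness of the returned expression. A self-contained illustration (not MatX code):

```cpp
#include <type_traits>

int storage = 42;

// Parenthesized lvalue: decltype(auto) deduces int&, so callers can write
// through the returned reference.
decltype(auto) get_ref() { return (storage); }

// Plain auto decays to int: the caller gets a copy.
auto get_val() { return (storage); }

static_assert(std::is_same_v<decltype(get_ref()), int&>);
static_assert(std::is_same_v<decltype(get_val()), int>);
```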
```cpp
#if defined(CUDA_VERSION)
symbol_name += "_CUDA";
symbol_name += std::to_string(CUDART_VERSION);
#else
symbol_name += "_CUDAUNKNOWN";
#endif
```
logic: Using CUDA_VERSION check but referencing CUDART_VERSION in the string. If CUDA_VERSION is undefined, CUDART_VERSION may also be undefined, potentially causing compilation issues.
Suggested change:

```diff
-#if defined(CUDA_VERSION)
+#if defined(CUDART_VERSION)
 symbol_name += "_CUDA";
 symbol_name += std::to_string(CUDART_VERSION);
 #else
 symbol_name += "_CUDAUNKNOWN";
 #endif
```
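The fix is simply to guard on the same macro that is later stringified. The pattern can be exercised with a stand-in macro (`DEMO_RT_VERSION` is illustrative; the real code uses `CUDART_VERSION` from the CUDA runtime headers):

```cpp
#include <string>

#define DEMO_RT_VERSION 12040  // stand-in for CUDART_VERSION

// Guard and use the same macro, so the #else branch is the only path
// taken when that macro is genuinely undefined.
inline std::string versioned_symbol(std::string base) {
#if defined(DEMO_RT_VERSION)
  base += "_CUDA" + std::to_string(DEMO_RT_VERSION);
#else
  base += "_CUDAUNKNOWN";
#endif
  return base;
}
```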
```cpp
if constexpr (CapType::ept == ElementsPerThread::ONE) {
  const int output_idx = threadIdx.x;
  return smem_c[output_idx];
```
logic: Output indexing logic assumes threadIdx.x maps directly to output elements, but this may not align with the 2D block dimensions suggested by cuBLASDx. The logic doesn't account for threadIdx.y or threadIdx.z. How should the output indexing work with 2D/3D block dimensions from cuBLASDx?
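A common answer to the question above is to linearize the full 2D/3D thread index instead of reading threadIdx.x alone. A host-side model of that mapping (illustrative; cuBLASDx also provides its own suggested layouts):

```cpp
// Flatten a 3D thread index within a block of dims (bdx, bdy, bdz) into a
// linear offset, so 2D/3D block shapes map onto a flat output array.
// Matches CUDA's convention: x varies fastest, then y, then z.
inline int linear_tid(int tx, int ty, int tz, int bdx, int bdy) {
  return (tz * bdy + ty) * bdx + tx;
}
```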
```cpp
result += "value_type alpha_val = static_cast<value_type>(" + std::to_string(alpha) + ");\n";
result += "value_type beta_val = static_cast<value_type>(" + std::to_string(beta) + ");\n";
```
logic: Float literals for alpha/beta values may cause precision loss when InputType is double or half precision. The conversion should preserve the original precision.
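The precision loss comes from `std::to_string`, which formats doubles with six fixed decimals. Emitting the literal with `max_digits10` instead round-trips the value exactly; a minimal sketch:

```cpp
#include <iomanip>
#include <limits>
#include <sstream>
#include <string>

// std::to_string(0.1000000000000001) yields "0.100000", discarding digits.
// max_digits10 (17 for double) guarantees an exact round-trip.
inline std::string full_precision(double v) {
  std::ostringstream os;
  os << std::setprecision(std::numeric_limits<double>::max_digits10) << v;
  return os.str();
}
```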
/build
Add cuBLASDx support for JIT-compiled matrix multiplication
Integrate cuBLASDx for fusion and accelerated matrix
multiplication of small matrices that fit in shared memory. This enables
significantly faster GEMM operations for sizes up to ~200 elements per
dimension (varies by data type and compute capability).
Key changes:
- cuBLASDx helper class covering GEMM parameters, size validation, and device code generation
- JIT class and operator string generation (get_jit_class_name, get_jit_op_str)
- PASS_THROUGH_THREADS capability for operators that must invoke operator() with bounds checking at the tensor level
- BLOCK_DIM capability for cuBLASDx operators with fixed block dimensions
- Kernel launch support for the pass-through thread execution model
Supported types: half, bfloat16, float, double, and their complex
variants. Size limits are architecture-dependent, ranging from 36-196
elements per dimension based on compute capability (SM 7.0 - SM 11.0).
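Those per-dimension limits follow from the A, B, and C tiles all residing in shared memory at once. A rough feasibility check under that assumption (the 164 KB budget below is illustrative of SM 8.0-class parts; real limits vary by architecture and cuBLASDx's own heuristics):

```cpp
#include <cstddef>

// A single-block GEMM needs an m x k A tile, a k x n B tile, and an
// m x n C tile resident in shared memory simultaneously.
inline bool fits_in_smem(std::size_t m, std::size_t n, std::size_t k,
                         std::size_t elem_bytes, std::size_t smem_bytes) {
  const std::size_t needed = (m * k + k * n + m * n) * elem_bytes;
  return needed <= smem_bytes;
}
```

This is why the practical ceiling is higher for half precision than for double or complex types: the same shape needs 2 bytes per element instead of 8 or 16.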
Requires MATX_EN_MATHDX to be enabled at compile time.
Note that this is in early development. cuBLASDx has limitations that affect the MatX code base as a whole, such as dictating what the block size should be. This PR is for early adopter support and we will add more features over time.