[DRAFT] [JAX] Grouped GEMM #2619
base: main
Conversation
Commits

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

- Add FP8 scale_inv pointer handling in nvte_grouped_gemm for proper FP8 GEMM
- Fix random padding in tests to ensure 16-byte alignment for all dtypes
- Reorder GroupedGemmSetupWorkspace members for natural alignment
- Remove debug prints

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

- Remove unused alignment parameter from GroupedGemmSetupWorkspace::from_buffers
- Simplify select_grouped_operand by removing dead code branches
- Add GroupedOperandSelection.tensor field to avoid passing the tensor separately
- Extract set_fp8_scale_pointers and init_matrix_layouts helpers
- Add safety check for the FP8 column-wise fallback on Hopper
- Support NULL C tensor when beta=0 (uses D as a placeholder)
- Remove unused get_scale_inv() from the test
- Add use_null_c test parameter and test case
- Fix documentation: alpha/beta are single-element tensors only

Signed-off-by: Piotr Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

- Change alpha/beta from single values to per-matrix arrays (a reference sketch of the per-matrix semantics follows this commit list)
- Validate that alpha/beta have exactly num_tensors elements
- Update the kernel to index alpha_ptr[idx] and beta_ptr[idx]
- Move alpha/beta validation to validate_grouped_gemm_inputs
- Update tests to use per-matrix alpha/beta arrays
- Update documentation

Signed-off-by: Piotr Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

… single-stream for multi tensor quantize)

The remaining commits carry only sign-off trailers (Pawel Gadzinski, Piotr Gadzinski, Jeremy Berchtold <jberchtold@nvidia.com>) or are pre-commit.ci auto-fixes (for more information, see https://pre-commit.ci).
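For reference, the per-matrix alpha/beta behavior introduced above can be sketched as an unfused computation: for each group i, D_i = alpha[i] * A_i @ B_i + beta[i] * C_i, where groups may have different shapes. This is a minimal illustration, not the PR's actual API; the function name and list-based calling convention below are assumptions.

```python
# Minimal sketch of per-matrix alpha/beta grouped-GEMM semantics (illustrative only).
import jax.numpy as jnp

def grouped_gemm_reference(a_list, b_list, c_list, alpha, beta):
    """Unfused reference: D_i = alpha[i] * A_i @ B_i + beta[i] * C_i for each group i."""
    num_tensors = len(a_list)
    # Mirrors the validation added in the PR: alpha/beta must have exactly num_tensors elements.
    assert len(alpha) == len(beta) == num_tensors
    return [
        alpha[i] * (a_list[i] @ b_list[i]) + beta[i] * c_list[i]
        for i in range(num_tensors)
    ]

# Two groups with different row counts, as grouped GEMM allows.
a_list = [jnp.ones((4, 8)), jnp.ones((6, 8))]
b_list = [jnp.ones((8, 16)), jnp.ones((8, 16))]
c_list = [jnp.zeros((4, 16)), jnp.zeros((6, 16))]
d_list = grouped_gemm_reference(a_list, b_list, c_list, alpha=[1.0, 0.5], beta=[0.0, 1.0])
```

The fused path performs the same computation in a single cuBLASLt grouped call, indexing alpha_ptr[idx] and beta_ptr[idx] per matrix as described in the commit above.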
Greptile Summary

This PR implements grouped GEMM support for JAX, enabling efficient batched matrix multiplications with varying shapes per group, a key optimization for Mixture-of-Experts (MoE) models with per-expert quantization.

Key Changes
Notes
Confidence Score: 3/5
Important Files Changed
Sequence Diagram

sequenceDiagram
participant User as JAX User Code
participant Einsum as einsum.py
participant Dense as dense.py
participant GemmPrim as gemm.py (Primitive)
participant FFI as gemm.cpp (FFI)
participant Kernel as cublaslt_grouped_gemm.cu
participant cuBLAS as cuBLASLt
User->>Einsum: einsum("EBCM,EMH->EBCH", x, w, quantizer_sets, quantizer_dim='E')
Einsum->>Einsum: Parse equation & validate quantizer_dim
Einsum->>Einsum: Stack quantizer_sets into pytree
Einsum->>Dense: vmap(dense_with_quantizer) over expert dimension
loop For each expert
Dense->>Dense: Quantize operands (if needed)
Dense->>GemmPrim: grouped_gemm.bind(lhs, rhs, ...)
GemmPrim->>GemmPrim: Shape inference & sharding rules
GemmPrim->>FFI: GroupedGemmFFI (XLA custom call)
FFI->>FFI: Validate inputs & allocate workspaces
FFI->>Kernel: nvte_grouped_gemm(A, B, C, D, alpha, beta, ...)
Kernel->>Kernel: setup_grouped_gemm_kernel (compute pointers/dims)
Kernel->>Kernel: Select operand storage (row/column-wise)
Kernel->>cuBLAS: cublasLtMatmul (grouped API)
cuBLAS-->>Kernel: Compute D = alpha*A@B + beta*C
Kernel-->>FFI: Return output
FFI-->>GemmPrim: Return result
GemmPrim-->>Dense: Return expert output
end
Dense-->>Einsum: Stack expert outputs
Einsum-->>User: Return final result
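The user-facing entry point at the top of the diagram can be exercised roughly as below. Only the call signature einsum("EBCM,EMH->EBCH", x, w, quantizer_sets, quantizer_dim='E') comes from the diagram; the import path, shapes, and dtypes are assumptions for illustration.

```python
# Hedged usage sketch of the einsum call shown in the sequence diagram.
import jax.numpy as jnp

E, B, C, M, H = 8, 4, 128, 256, 512   # experts, batch, capacity, model/hidden dims (assumed)
x = jnp.ones((E, B, C, M), dtype=jnp.bfloat16)   # per-expert activations
w = jnp.ones((E, M, H), dtype=jnp.bfloat16)      # per-expert weights

# quantizer_sets would be built from Transformer Engine's JAX quantization helpers
# (one set per expert); its construction is not shown in this excerpt.
# from transformer_engine.jax import einsum   # assumed import path
# y = einsum("EBCM,EMH->EBCH", x, w, quantizer_sets, quantizer_dim='E')
# assert y.shape == (E, B, C, H)
```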
29 files reviewed, 1 comment
| f"Got batched_args={[arg.shape for arg, bdim in zip(batched_args, batch_dims) if bdim is not None]}." | ||
| ) | ||
| assert batch_dim is not None and batch_size is not None, "Invalid batching config!" | ||
|
|
style: debug print statement left in production code
    # print(f"[{cls.__name__}] Batching with size {batch_size}")
This reverts commit 0e16fef.
0259184 to 1f6283f
Description
Please include a brief summary of the changes, relevant motivation and context.
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: