[PyTorch] Add grouped linear op and experimental fusion for grouped MLP #2622
base: main
Conversation
Refactor fusion functions to remove index bookkeeping. Refactor fused ops to use consistent operation order.
The test is too permissive: it should still be failing because the weights are not properly interleaved yet.
* Expose option for custom op fusions: refactor fusion functions to remove index bookkeeping and refactor fused ops to use consistent operation order.
* Add tests for custom ops.
* Apply pre-commit.ci auto fixes.
* Fix linter warnings and numerical test failures.
* Tweak pattern matching logic with fixed window sizes.
* Use TF32 tolerances in fused op tests.
* Apply review suggestion from @greptile-apps.
* Backpropagate fixes from #2622.
/te-ci pytorch L1
Greptile Overview

Greptile Summary

This PR adds grouped linear operations and an experimental fused grouped MLP kernel for Mixture-of-Experts models. The implementation includes a new GroupedLinear op, a ScaledSwiGLU activation, and a fused ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8 operation (see the sequence diagram below).

Critical Issues Found:
Changes:
Confidence Score: 2/5
Important Files Changed
Sequence Diagram

sequenceDiagram
participant User
participant GroupedLinear1 as GroupedLinear (FC1)
participant ScaledSwiGLU
participant GroupedLinear2 as GroupedLinear (FC2)
participant FusedKernel as ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8
Note over User,FusedKernel: Forward Pass - Grouped MLP Block
User->>GroupedLinear1: input tensor + split_sizes
Note over GroupedLinear1: Split input by split_sizes
Note over GroupedLinear1: Quantize inputs (FP8/MXFP8)
Note over GroupedLinear1: Quantize weights (FP8/MXFP8)
GroupedLinear1->>GroupedLinear1: general_grouped_gemm()
Note over GroupedLinear1: Compute y = xW^T for each group
GroupedLinear1->>ScaledSwiGLU: FC1 output (interleaved gate/linear)
ScaledSwiGLU->>ScaledSwiGLU: Remove gate interleaving
ScaledSwiGLU->>ScaledSwiGLU: tex.swiglu(x)
Note over ScaledSwiGLU: SiLU(gate) * linear
ScaledSwiGLU->>ScaledSwiGLU: Multiply by scales
ScaledSwiGLU->>GroupedLinear2: Scaled SwiGLU output
GroupedLinear2->>GroupedLinear2: Split input by split_sizes
Note over GroupedLinear2: Quantize inputs (FP8/MXFP8)
GroupedLinear2->>GroupedLinear2: general_grouped_gemm()
Note over GroupedLinear2: Compute y = xW^T for each group
GroupedLinear2->>User: Final output
Note over FusedKernel: Alternative: Fused Path (SM100+, MXFP8)
User->>FusedKernel: input + split_sizes + scales
Note over FusedKernel: Pack MXFP8 tensors for kernel
FusedKernel->>FusedKernel: grouped_gemm_swiglu_kernel()
Note over FusedKernel: Fused FC1 + SwiGLU + scale
FusedKernel->>FusedKernel: Unpack MXFP8 outputs
FusedKernel->>FusedKernel: general_grouped_gemm() for FC2
FusedKernel->>User: Final output
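For orientation, here is a minimal sketch of the standard (unfused) path in the diagram above, written as a plain composition. The callables fc1, swiglu, and fc2 stand in for the GroupedLinear and ScaledSwiGLU ops named in the diagram; the function and its signature are illustrative assumptions, not the PR's actual API.

```python
def grouped_mlp_forward(x, split_sizes, fc1, swiglu, fc2, scales):
    """Standard path from the diagram above (illustrative, not the PR's API).

    x           -- routed tokens, grouped contiguously by expert
    split_sizes -- number of tokens belonging to each expert/group
    fc1, fc2    -- grouped linear ops holding one weight per group
    swiglu      -- scaled SwiGLU activation
    scales      -- per-token scales applied after the activation
    """
    h = fc1(x, split_sizes)      # per-group GEMM, interleaved gate/linear output
    h = swiglu(h, scales)        # de-interleave, SiLU(gate) * linear, apply scales
    return fc2(h, split_sizes)   # per-group GEMM back to the model width
```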
3 files reviewed, 3 comments
if not with_quantized_compute:
    w = maybe_dequantize(w, dtype)
elif with_quantized_compute and not is_quantized_tensor(w):
    quantizer = weight_quantizers[group_idx]
group_idx is undefined in this scope - loop variable is w, quantizer from the zip on line 415. Should use quantizer directly (already assigned) or add enumeration to the loop.
Suggested change:
- quantizer = weight_quantizers[group_idx]
+ # quantizer is already assigned from the zip, use it directly
if not self.is_supported():
    self.grouped_gemm_swiglu_kernel()  # Try triggering import error
    raise RuntimeError(f"{self.__class__.__name__} is not supported on this system.")
if fc1.in_features % 256 != 0 or fc1.in_features % 256 != 0:
checking fc1.in_features % 256 != 0 twice instead of checking both in_features and out_features
Suggested change:
- if fc1.in_features % 256 != 0 or fc1.in_features % 256 != 0:
+ if fc1.in_features % 256 != 0 or fc1.out_features % 256 != 0:
| f"Unsupported dims for FC1 (group_size={fc1.group_size}, " | ||
| f"in_features={fc1.in_features}, out_features={fc1.out_features})." | ||
| ) | ||
| if fc2.in_features % 256 != 0 or fc2.in_features % 256 != 0: |
checking fc2.in_features % 256 != 0 twice instead of checking both in_features and out_features
Suggested change:
- if fc2.in_features % 256 != 0 or fc2.in_features % 256 != 0:
+ if fc2.in_features % 256 != 0 or fc2.out_features % 256 != 0:
Greptile Overview

Greptile Summary

This PR adds grouped linear operations and SwiGLU activation variants to support Mixture-of-Experts (MoE) models with grouped MLP blocks.

Key Changes:
Issues Found:
Confidence Score: 2/5
Important Files Changed
Sequence Diagram

sequenceDiagram
participant User
participant GroupedLinear_FC1
participant ScaledSwiGLU
participant GroupedLinear_FC2
participant CuTeGEMM as CuTe Fused Kernel
User->>GroupedLinear_FC1: Input (batched tokens)
Note over GroupedLinear_FC1: Split input by group sizes
GroupedLinear_FC1->>GroupedLinear_FC1: Quantize inputs (FP8)
GroupedLinear_FC1->>GroupedLinear_FC1: Quantize weights (FP8)
alt MXFP8 + SM100 + no bias
GroupedLinear_FC1->>CuTeGEMM: Fused path
Note over CuTeGEMM: GroupedGEMM + SwiGLU + Scale
CuTeGEMM->>GroupedLinear_FC2: FC2 input (quantized)
else Standard path
GroupedLinear_FC1->>GroupedLinear_FC1: general_grouped_gemm
GroupedLinear_FC1->>ScaledSwiGLU: FC1 output
ScaledSwiGLU->>ScaledSwiGLU: Remove gate interleaving
ScaledSwiGLU->>ScaledSwiGLU: SwiGLU(x) = silu(x1) * x2
ScaledSwiGLU->>ScaledSwiGLU: Multiply by scales
ScaledSwiGLU->>GroupedLinear_FC2: Scaled activation
end
GroupedLinear_FC2->>GroupedLinear_FC2: Quantize inputs (FP8)
GroupedLinear_FC2->>GroupedLinear_FC2: general_grouped_gemm
GroupedLinear_FC2->>User: Output (batched tokens)
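The alt branch above gates the fused path on an MXFP8 recipe, SM100-class hardware, and no bias. A hedged sketch of such a capability check follows; the helper name and the uses_mxfp8 flag are assumptions for illustration, not the PR's actual code.

```python
import torch

def can_use_fused_grouped_mlp(uses_mxfp8: bool, bias) -> bool:
    """Rough gate for the fused CuTe path in the diagram: MXFP8 quantization,
    an SM100-or-newer GPU, and no bias. Illustrative only."""
    if not torch.cuda.is_available():
        return False
    sm_major, _ = torch.cuda.get_device_capability()
    return uses_mxfp8 and sm_major >= 10 and bias is None
```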
4 files reviewed, 4 comments
    dimension, applying a separate ``torch.nn.Linear`` to each split,
    and concatenating along the first dimension.

    Paramters
typo: "Paramters" should be "Parameters"
Suggested change:
- Paramters
+ Parameters
if not with_quantized_compute:
    w = maybe_dequantize(w, dtype)
elif with_quantized_compute and not is_quantized_tensor(w):
    quantizer = weight_quantizers[group_idx]
group_idx is undefined in this scope - loop uses w, quantizer from zip. Should use quantizer directly
Suggested change:
- quantizer = weight_quantizers[group_idx]
+ quantizer.set_usage(rowwise=True, columnwise=input_requires_grad)
if fc1.in_features % 256 != 0 or fc1.in_features % 256 != 0:
    raise ValueError(
        f"Unsupported dims for FC1 (group_size={fc1.group_size}, "
        f"in_features={fc1.in_features}, out_features={fc1.out_features})."
    )
duplicate condition checks fc1.in_features % 256 twice instead of checking fc1.out_features
Suggested change:
- if fc1.in_features % 256 != 0 or fc1.in_features % 256 != 0:
+ if fc1.in_features % 256 != 0 or fc1.out_features % 256 != 0:
      raise ValueError(
          f"Unsupported dims for FC1 (group_size={fc1.group_size}, "
          f"in_features={fc1.in_features}, out_features={fc1.out_features})."
      )
if fc2.in_features % 256 != 0 or fc2.in_features % 256 != 0:
    raise ValueError(
        f"Unsupported dims for FC2 (group_size={fc2.group_size}, "
        f"in_features={fc2.in_features}, out_features={fc2.out_features})."
    )
duplicate condition checks fc2.in_features % 256 twice instead of checking fc2.out_features
Suggested change:
- if fc2.in_features % 256 != 0 or fc2.in_features % 256 != 0:
+ if fc2.in_features % 256 != 0 or fc2.out_features % 256 != 0:
      raise ValueError(
          f"Unsupported dims for FC2 (group_size={fc2.group_size}, "
          f"in_features={fc2.in_features}, out_features={fc2.out_features})."
      )
Description
This PR adds a grouped linear op, which can be used in the grouped MLP block in Mixture-of-Experts models. It also adds an experimental fused operation for a grouped MLP block, using a CuTe DSL kernel that computes an MXFP8 grouped GEMM and SwiGLU.
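As a plain-precision reference for what the fused kernel computes (a grouped GEMM followed by SwiGLU and per-token scaling), here is a short sketch. The layout assumption that the gate and linear channels are the two halves of the FC1 output, and all names in the snippet, are illustrative rather than the PR's actual code.

```python
import torch
import torch.nn.functional as F

def grouped_gemm_swiglu_reference(x, split_sizes, fc1_weights, scales):
    """For each group g: h = x_g @ W1_g^T, split h into gate/linear halves,
    return scales_g * SiLU(gate) * linear, concatenated over groups."""
    outputs = []
    x_groups = torch.split(x, split_sizes, dim=0)
    s_groups = torch.split(scales, split_sizes, dim=0)
    for x_g, s_g, w_g in zip(x_groups, s_groups, fc1_weights):
        h = F.linear(x_g, w_g)                    # one GEMM of the grouped GEMM
        gate, lin = h.chunk(2, dim=-1)            # assumed de-interleaved layout
        outputs.append(s_g * F.silu(gate) * lin)  # scaled SwiGLU
    return torch.cat(outputs, dim=0)
```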
Type of change
Changes
Checklist: