[Common/PyTorch/JAX] make offset of ClampedSwiGLU configurable #2938
hxbai wants to merge 3 commits into NVIDIA:main from swiglu_offset
Conversation
Greptile Summary
This PR makes the glu_linear_offset of ClampedSwiGLU configurable across the CUDA, JAX, and PyTorch code paths, with a backward-compatible default of 1.0.
Confidence Score: 5/5
The change is numerically correct, backward-compatible (default offset 1.0), and consistently applied across all CUDA, JAX, and PyTorch code paths. All compute paths — vanilla CUDA, FP8, MXFP8, JAX FFI, and pybind11 — thread the new parameter correctly with a backward-compatible default. Forward and backward math is correct: the offset shifts the clamped linear component but is a constant w.r.t. the gate gradient, so only the activation gradient (not the gate gradient) carries it. The cuDNN fusion guard is properly tightened. No independent blocking issues were found.
Important Files Changed
transformer_engine/common/include/transformer_engine/activation.h — the public C API signature change was flagged in a prior review thread and warrants resolution before merge.
Sequence Diagram
sequenceDiagram
participant User as User / Python
participant PyBind as pybind11 (clamped_swiglu)
participant CppExt as activation.cpp
participant CUDA as CUDA Kernels
note over User,CUDA: Forward pass
User->>PyBind: clamped_swiglu(input, quantizer, limit, alpha, glu_linear_offset)
PyBind->>CppExt: activation_helper → nvte_clamped_swiglu(... glu_linear_offset, stream)
CppExt->>CUDA: ClampedSwiGLUParam{limit, alpha, glu_linear_offset}
CUDA-->>CppExt: output = (clamp(x_linear)+glu_linear_offset) · silu(x_glu)
CppExt-->>PyBind: output tensor
PyBind-->>User: result
note over User,CUDA: Backward pass
User->>PyBind: clamped_dswiglu(grad, input, quantizer, limit, alpha, glu_linear_offset)
PyBind->>CppExt: dactivation_helper → nvte_clamped_dswiglu(... glu_linear_offset, stream)
CppExt->>CUDA: ClampedSwiGLUParam{limit, alpha, glu_linear_offset}
CUDA-->>CppExt: d_act = dsilu(x_glu)·(clamp(x_linear)+offset)·grad, d_gate = silu(x_glu)·grad·∂clamp(x_linear)
CppExt-->>PyBind: gradient tensors
PyBind-->>User: gradients
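To make the gradient structure described in the summary and the diagram concrete, here is a minimal scalar reference sketch in PyTorch. It is not the TransformerEngine kernel code: the clamping ranges (gate clamped from above at limit, linear input clamped to [-limit, limit]) and the placement of alpha inside the sigmoid follow the GPT-OSS convention the PR description mentions and should be treated as assumptions; the limit and alpha values below are purely illustrative.

# Hedged reference sketch of ClampedSwiGLU with a configurable glu_linear_offset.
# Not the TransformerEngine kernel code; clamping ranges and alpha placement are assumed.
import torch

def clamped_swiglu_ref(x_glu, x_linear, limit, alpha, glu_linear_offset):
    g = x_glu.clamp(max=limit)                    # gate clamped from above (assumed)
    lin = x_linear.clamp(min=-limit, max=limit)   # linear part clamped symmetrically (assumed)
    silu = g * torch.sigmoid(alpha * g)
    # The offset only shifts the clamped linear component, so it appears in the
    # gradient w.r.t. x_glu but cancels out of the gradient w.r.t. x_linear.
    return silu * (lin + glu_linear_offset)

# Autograd exposes the gradient split described in the summary:
#   d/dx_glu    ~ dsilu(x_glu) * (clamp(x_linear) + offset) * grad   (carries the offset)
#   d/dx_linear ~ silu(x_glu) * dclamp(x_linear) * grad              (offset drops out)
x_glu = torch.randn(4, requires_grad=True)
x_linear = torch.randn(4, requires_grad=True)
out = clamped_swiglu_ref(x_glu, x_linear, limit=7.0, alpha=1.702, glu_linear_offset=1.0)
out.sum().backward()

Under these assumptions, glu_linear_offset=1.0 keeps the GPT-OSS behavior, while glu_linear_offset=0.0 (and presumably alpha=1.0) would correspond to the DeepSeek-V4 variant mentioned in the description.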
Reviews (4): Last reviewed commit: "Merge branch 'main' into swiglu_offset"
 * \param[in] glu_linear_offset  Offset added to the linear component after clamping (default 1.0).
 * \param[in] stream             CUDA stream used for the operation.
 */
nvte_clamped_swiglu and nvte_clamped_dswiglu are public symbols declared in a versioned public header. Inserting glu_linear_offset before cudaStream_t is an ABI-breaking change: any external binary or shared library compiled against the old header will silently pass the stream pointer as the offset and a garbage value as the stream, so the failure surfaces as undefined behavior at runtime rather than as a clean compile error. This should be acknowledged as a breaking change in the PR checklist, and — if this library follows semantic versioning or a compatibility guarantee — a deprecation/transition path or version bump is needed.
timmoon10 left a comment
The fused op for grouped MLP is hard-coded for GPT-OSS, so we should make sure not to fuse if glu_linear_offset != 1:
TransformerEngine/transformer_engine/pytorch/ops/_common.py
Lines 180 to 183 in df0025b
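Below is a minimal sketch of the kind of guard being requested; the function name is hypothetical and this is not the actual code at the _common.py lines referenced above.

# Hypothetical guard sketch (name invented for illustration; not the referenced _common.py code).
# The fused grouped-MLP op is hard-coded for the GPT-OSS form of ClampedSwiGLU, so only
# allow fusion when glu_linear_offset matches the hard-coded value of 1.0.
def can_use_fused_clamped_swiglu(glu_linear_offset: float) -> bool:
    return glu_linear_offset == 1.0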
/te-ci
Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com>
Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com>
 void nvte_clamped_swiglu(const NVTETensor input, NVTETensor output, float limit, float alpha,
-                         cudaStream_t stream) {
+                         float glu_linear_offset, cudaStream_t stream) {
Can we define new APIs named nvte_clamped_swiglu_v2 and nvte_clamped_dswiglu_v2
and deprecate this API here to not break backward compatibility?
Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>
Description
The previous ClampedSwiGLU follows GPT-OSS, which hard-codes the offset to 1.0.
DeepSeek-V4 uses ClampedSwiGLU without the alpha scaling and without the offset.
This PR makes the offset of ClampedSwiGLU configurable to support DeepSeek-V4.
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: