
[Common/PyTorch/JAX] make offset of ClampedSwiGLU configurable#2938

Open
hxbai wants to merge 3 commits into NVIDIA:main from hxbai:swiglu_offset

Conversation

@hxbai
Contributor

@hxbai hxbai commented Apr 28, 2026

Description

The previous ClampedSwiGLU implementation follows GPT-OSS, which hard-codes the offset to 1.0.
DeepSeek-V4 uses ClampedSwiGLU without the alpha scaling and without the offset.
This PR makes the offset of ClampedSwiGLU configurable to support DeepSeek-V4.
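For reference, a minimal pure-Python sketch of the intended semantics (the function name, default `limit`/`alpha` values, and clamping conventions below are assumptions based on the GPT-OSS recipe, not code from this PR): with `glu_linear_offset=1.0` it reproduces the legacy behavior, and with `0.0` the offset-free DeepSeek-V4 variant.

```python
import math

def clamped_swiglu_ref(x_glu, x_linear, limit=7.0, alpha=1.702, glu_linear_offset=1.0):
    """Scalar reference for ClampedSwiGLU with a configurable offset.

    Hypothetical sketch: glu_linear_offset=1.0 matches the legacy GPT-OSS
    behavior; glu_linear_offset=0.0 gives the offset-free variant.
    """
    g = min(x_glu, limit)                  # clamp the gate input from above
    l = max(-limit, min(x_linear, limit))  # clamp the linear input symmetrically
    silu = g / (1.0 + math.exp(-alpha * g))
    return silu * (l + glu_linear_offset)

legacy = clamped_swiglu_ref(0.5, 2.0)                            # GPT-OSS mode
offset_free = clamped_swiglu_ref(0.5, 2.0, glu_linear_offset=0.0)  # DeepSeek-V4 mode
```

The offset is an additive constant on the clamped linear branch, so the two modes differ by exactly one `silu(x_glu)` term.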

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@greptile-apps
Contributor

greptile-apps Bot commented Apr 28, 2026

Greptile Summary

This PR makes the glu_linear_offset of ClampedSwiGLU configurable (default 1.0, preserving legacy GPT-OSS behavior) so that DeepSeek-V4's offset-free variant is supported. The parameter is threaded consistently through every compute path: vanilla CUDA kernels (vectorized_pointwise.h), FP8/MXFP8 cast kernels, JAX XLA FFI structs, PyTorch pybind11 bindings, and the LayerNormMLP ONNX export path.

  • The cuDNN grouped-GEMM fusion guard in _common.py is correctly extended to block fusion when glu_linear_offset != 1.0, since the cuDNN kernel cannot express a configurable offset.
  • Both forward and backward gradient computations correctly incorporate the offset (the offset is a constant shift, so it drops out of the gate gradient but must remain in the activation gradient).
  • Tests for ClampedSwiGLU and ScaledClampedQGeGLU are parametrized over {0.0, 1.0} to cover both configurations.
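The gradient claim in the second bullet can be checked numerically with a small finite-difference sketch (a hypothetical scalar model of the activation, not the kernel code itself): the gradient w.r.t. the linear input is independent of the offset, while the gradient w.r.t. the glu input carries it.

```python
import math

def act(x_glu, x_linear, offset, limit=7.0, alpha=1.702):
    # Scalar model: silu(clamp(x_glu)) * (clamp(x_linear) + offset)
    g = min(x_glu, limit)
    l = max(-limit, min(x_linear, limit))
    return g / (1.0 + math.exp(-alpha * g)) * (l + offset)

def fd_grad(fn, x, eps=1e-6):
    # Central finite difference
    return (fn(x + eps) - fn(x - eps)) / (2 * eps)

# d/d(x_linear) is offset-independent: the offset is an additive constant there.
dlin_off0 = fd_grad(lambda l: act(0.3, l, 0.0), 0.8)
dlin_off1 = fd_grad(lambda l: act(0.3, l, 1.0), 0.8)

# d/d(x_glu) does carry the offset, via the (clamp(x_linear) + offset) factor.
dglu_off0 = fd_grad(lambda g: act(g, 0.8, 0.0), 0.3)
dglu_off1 = fd_grad(lambda g: act(g, 0.8, 1.0), 0.3)
```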

Confidence Score: 5/5

The change is numerically correct, backward-compatible (default offset 1.0), and consistently applied across all CUDA, JAX, and PyTorch code paths.

All compute paths — vanilla CUDA, FP8, MXFP8, JAX FFI, and pybind11 — thread the new parameter correctly with a backward-compatible default. Forward and backward math is correct: the offset shifts the clamped linear component but is a constant w.r.t. the gate gradient, so only the activation gradient (not the gate gradient) carries it. The cuDNN fusion guard is properly tightened. No independent blocking issues were found.

transformer_engine/common/include/transformer_engine/activation.h — the public C API signature change was flagged in a prior review thread and warrants resolution before merge.

Important Files Changed

Filename Overview
transformer_engine/common/include/transformer_engine/activation.h Public C API functions nvte_clamped_swiglu / nvte_clamped_dswiglu have glu_linear_offset inserted before cudaStream_t, which is an ABI-breaking signature change already flagged in a prior review thread.
transformer_engine/common/util/math.h Adds glu_linear_offset = 1.0f default to ClampedSwiGLUParam; backward-compatible default preserves existing behavior.
transformer_engine/common/util/vectorized_pointwise.h Both forward and backward kernels now use p.glu_linear_offset instead of the hard-coded 1/1.0f; gradients are mathematically correct since the offset is a constant shift that drops out of the gate gradient.
transformer_engine/common/cast/fp8/gated_fp8.cuh FP8 forward path replaces hard-coded +1 with +p.glu_linear_offset; only the backward clamp-derivative guard is unchanged (correct, offset doesn't affect it).
transformer_engine/common/cast/mxfp8/gated_mxfp8.cuh Both MXFP8 kernel variants consistently replace hard-coded 1.0f with p.glu_linear_offset.
transformer_engine/pytorch/ops/basic/swiglu.py Both ClampedSwiGLU and ScaledClampedQGeGLU correctly thread glu_linear_offset through forward and backward helpers with a default of 1.0.
transformer_engine/pytorch/ops/_common.py Fusion guard now also rejects non-default glu_linear_offset values, correctly preventing cuDNN grouped-GEMM fusion for configurations the fixed hardware kernel can't represent.
transformer_engine/jax/cpp_extensions/activation.py ClampedSwigluParams dataclass, its hash, to_ffi_lowering_dict, and the clamped_linear factory all correctly incorporate glu_linear_offset.
transformer_engine/jax/csrc/extensions.h XLA FFI struct decoding registration updated to include the new glu_linear_offset member.
transformer_engine/pytorch/csrc/extensions/pybind.cpp pybind11 bindings for clamped_swiglu and clamped_dswiglu expose the new glu_linear_offset keyword argument with default 1.0f, preserving Python API compatibility.
transformer_engine/pytorch/module/layernorm_mlp.py ONNX-export path correctly reads glu_linear_offset from activation_params; training path propagates through **act_params to the pybind11 extension with default 1.0.
tests/pytorch/test_fusible_ops.py Both test_clamped_swiglu and test_scaled_clamped_qgeglu are correctly parametrized over glu_linear_offset in {0.0, 1.0}, covering both the legacy GPT-OSS and new DeepSeek-V4 modes.

Sequence Diagram

sequenceDiagram
    participant User as User / Python
    participant PyBind as pybind11 (clamped_swiglu)
    participant CppExt as activation.cpp
    participant CUDA as CUDA Kernels
    note over User,CUDA: Forward pass
    User->>PyBind: clamped_swiglu(input, quantizer, limit, alpha, glu_linear_offset)
    PyBind->>CppExt: activation_helper → nvte_clamped_swiglu(... glu_linear_offset, stream)
    CppExt->>CUDA: ClampedSwiGLUParam{limit, alpha, glu_linear_offset}
    CUDA-->>CppExt: clamp(x_linear)+glu_linear_offset · silu(x_glu)
    CppExt-->>PyBind: output tensor
    PyBind-->>User: result
    note over User,CUDA: Backward pass
    User->>PyBind: clamped_dswiglu(grad, input, quantizer, limit, alpha, glu_linear_offset)
    PyBind->>CppExt: dactivation_helper → nvte_clamped_dswiglu(... glu_linear_offset, stream)
    CppExt->>CUDA: ClampedSwiGLUParam{limit, alpha, glu_linear_offset}
    CUDA-->>CppExt: d_act = dsilu(x_glu)·(clamp(x_linear)+offset)·grad, d_gate = silu(x_glu)·grad·∂clamp(x_linear)
    CppExt-->>PyBind: gradient tensors
    PyBind-->>User: gradients

Reviews (4): Last reviewed commit: "Merge branch 'main' into swiglu_offset"

Comment on lines +339 to 341
* \param[in] glu_linear_offset Offset added to the linear component after clamping (default 1.0).
* \param[in] stream CUDA stream used for the operation.
*/
Contributor


P1 Breaking public C API change

nvte_clamped_swiglu and nvte_clamped_dswiglu are public symbols declared in a versioned public header. Inserting glu_linear_offset before cudaStream_t is an ABI-breaking change: any external binary or shared library compiled against the old header will silently pass the stream pointer as the offset and a garbage value as the stream, leading to undefined behavior at runtime rather than a clean compile error if called via a pre-compiled library. This should be acknowledged as a breaking change in the PR checklist, and — if this library follows semantic versioning or a compatibility guarantee — a deprecation/transition path or version bump is needed.

Collaborator

@timmoon10 timmoon10 left a comment


The fused op for grouped MLP is hard-coded for GPT-OSS, so we should make sure not to fuse if glu_linear_offset != 1:

elif isinstance(window[1], ScaledClampedQGeGLU) and (
    abs(window[1]._clamped.alpha - 1.702) > 0.001
    or not _nvidia_cudnn_frontend_supports_scaled_clamped_qgeglu()
):
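The fix this comment asks for amounts to tightening the fusion predicate; a minimal sketch of the condition (the function and parameter names here are illustrative, not the PR's actual code):

```python
def can_fuse_grouped_mlp(alpha: float, glu_linear_offset: float) -> bool:
    # Fuse only when the activation matches what the cuDNN grouped-GEMM
    # kernel hard-codes: the GPT-OSS alpha constant and an offset of 1.0.
    return abs(alpha - 1.702) <= 0.001 and glu_linear_offset == 1.0
```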

@timmoon10
Collaborator

/te-ci

Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com>
@hxbai hxbai marked this pull request as draft April 29, 2026 00:28
Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com>
@hxbai hxbai marked this pull request as ready for review April 29, 2026 01:01

 void nvte_clamped_swiglu(const NVTETensor input, NVTETensor output, float limit, float alpha,
-                         cudaStream_t stream) {
+                         float glu_linear_offset, cudaStream_t stream) {
Collaborator


Can we define new APIs named nvte_clamped_swiglu_v2 and nvte_clamped_dswiglu_v2
and deprecate this API, so that backward compatibility is not broken?
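The suggested _v2 pattern, sketched in Python for brevity (the real change would live in the C header and implementation; the stub body below just returns its parameters so the forwarding behavior is visible):

```python
def nvte_clamped_swiglu_v2(inp, out, limit, alpha, glu_linear_offset, stream):
    # Stub: a real implementation would launch the kernel with
    # ClampedSwiGLUParam{limit, alpha, glu_linear_offset}.
    return (limit, alpha, glu_linear_offset)

def nvte_clamped_swiglu(inp, out, limit, alpha, stream):
    # Deprecated shim: signature unchanged, forwards the legacy offset 1.0,
    # so callers built against the old API keep working.
    return nvte_clamped_swiglu_v2(inp, out, limit, alpha, 1.0, stream)
```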

Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>

3 participants