Commits (114)
dfdd382
Changed VERSION to 2.13.0.dev0
ptrendx Jan 20, 2026
27fc168
[Common] Enable determinism for cuDNN >= 9.18.1 on Blackwell (#2584)
cyanguwa Jan 20, 2026
fbb16f4
[Common] Tuned NVFP4 cast kernel (#2412)
Oleg-Goncharov Jan 21, 2026
36f4e45
Fixed the year to 2026 (#2611)
Oleg-Goncharov Jan 21, 2026
605786f
[pyTorch] CPU performance optimizations (#2439)
ptrendx Jan 21, 2026
8bf37f0
[JAX] Fix cb.CUDAOptions usage for Triton 3.6.0 (#2610)
jberchtold-nvidia Jan 22, 2026
3d46bf6
Permutation to always return group_size/tokens_per_expert (#2613)
tdophung Jan 22, 2026
0f0e229
[PyT] Update THD sink attention logic for cudnn >=9.18.0 (#2568)
cuichenx Jan 22, 2026
c6a92a4
Add support for SWA (left, right) with FusedAttention (#2477)
sudhakarsingh27 Jan 22, 2026
52ee5ea
Fix bugs in permutation custom partitioning (#2617)
tdophung Jan 23, 2026
a0a89a8
[Common] Disabled the tuned NVFP4 kernels (#2615)
Oleg-Goncharov Jan 23, 2026
7259276
[PyTorch] Support user-defined op fusions (#2597)
timmoon10 Jan 25, 2026
2dbfbc7
fix(examples): te_llama compatibility with transformers >= 4.57 (#2572)
sbhavani Jan 26, 2026
2104e4c
[JAX] Use "nyu-mll/glue" instead of "glue" for encoder datasets to fi…
jberchtold-nvidia Jan 27, 2026
f04b094
[PyTorch] ONNX test fix + export for FP8 attention (#2598)
pggPL Jan 28, 2026
b9f4013
[common] Add support for cuBLASLt GEMM for GroupedTensor (#2502)
pggPL Jan 28, 2026
f8cca8b
[Pytorch] Fix wheel test (#2635)
pggPL Jan 29, 2026
c3769cb
Fix minimum version of cublas for grouped gemm (#2631)
pggPL Jan 30, 2026
3ceb248
More detailed documentation for recipes (#2343)
pggPL Feb 2, 2026
94ba75d
Support building with headers from nvidia wheels (#2623)
vmarkovtsev Feb 3, 2026
29b84c1
[Common] Fix NVFP4 tuned-kernel numerics (#2639)
Oleg-Goncharov Feb 3, 2026
74faf7e
[PyTorch Debug] NVFP4 debug stats support (#2296)
pggPL Feb 3, 2026
59f6f38
[JAX] Update JAX container in readme (#2648)
jberchtold-nvidia Feb 4, 2026
71971e3
Fix exp2f_rcp to properly handle nan and 0xFE cases (#2647)
kainzhong Feb 6, 2026
7393947
[Common] MXFP8 kernel for grouped tensors (#2586)
Oleg-Goncharov Feb 6, 2026
dccf67e
[Common] Bucket batch size with higher granularity for THD (#2653)
cyanguwa Feb 7, 2026
c1a0c97
[PyTorch][Core][JAX] Expand troubleshooting docs (#2602)
jberchtold-nvidia Feb 9, 2026
b841243
[PyTorch Debug] Skip logging stats if unsupported (#2652)
pggPL Feb 9, 2026
2894e49
[Pytorch] Add get_backward_dw_params api for TE module (#2614)
Wohox Feb 9, 2026
b09ff7e
[pyTorch] Fix the compilation warnings (#2663)
ptrendx Feb 10, 2026
01ac7f8
[Pytorch] Make test script generate checkpoints if they don't exist (…
kainzhong Feb 10, 2026
8d15258
Fix Broken Quickstart Links (#2641)
faradawn Feb 11, 2026
8ebb47e
Fix on TE to support Mcore Vision Encoder CUDA Graph (#2657)
tomlifu Feb 11, 2026
ac81c85
[PyTorch] Python `GroupedTensor` (#2654)
ksivaman Feb 11, 2026
402ea54
[C] NVFP4 quantization for `GroupedTensor` (#2655)
ksivaman Feb 11, 2026
c4175fc
fix(build): Handle namespace packages for PyPI CUDA detection (#2580)
sbhavani Feb 12, 2026
93d51c8
[Common] Fuse pre-swizzling into grouped MXFP8 quantization kernel (#…
Oleg-Goncharov Feb 12, 2026
3774aa3
[PyTorch] Add ops for MoE grouped MLP (#2664)
timmoon10 Feb 12, 2026
33ca615
Add sigmoid GLU (#2656)
singleheart Feb 12, 2026
cd098e4
fix: correct FusedAdam copy-paste in FusedSGD error messages (#2675)
Mr-Neutr0n Feb 12, 2026
496620a
Get rid of nvshmem dependency for cuBLASMp integration (#2661)
vcherepanov-nv Feb 12, 2026
f844905
[PyTorch] Make grouped weights opt-in (#2678)
ksivaman Feb 13, 2026
5d112e3
[JAX] TE Permutation integration to Maxtext (#2672)
tdophung Feb 13, 2026
fa68781
Fix `build_tools` missing from sdist causing `uv` cached installs to …
hemildesai Feb 17, 2026
7e48fa1
[JAX] Debugging inspect utility (#2651)
jberchtold-nvidia Feb 17, 2026
f122b07
Changed VERSION to 2.14.0.dev0
ptrendx Feb 18, 2026
2d0d276
[PyT] Plumbing correct bias dims from TE to cudnn, while adding suppo…
KshitijLakhani Feb 18, 2026
63defea
Update cudnn-frontend to v1.18 (#2689)
cyanguwa Feb 20, 2026
e583222
[PyTorch] Documentation for op fuser API (#2447)
timmoon10 Feb 20, 2026
57b5b60
Fix race condition in RHT amax kernels (#2695)
ksivaman Feb 21, 2026
e8f7c5a
Add and verify support for `deterministic` fp8 dpa/mha on SM100 (#2621)
sudhakarsingh27 Feb 24, 2026
39b6dd9
[PyTorch Debug] Custom feature tutorial. (#2216)
pggPL Feb 24, 2026
7d1de30
Fix vermin pre-commit hook (#2699)
pstjohn Feb 24, 2026
459e7cf
[Common][PyTorch] Fuse scaling and unscaling of bf16 momentums into k…
yaox12 Feb 24, 2026
9eb982e
Fix incorrect MNNVL fabric check (#2626)
nvcastet Feb 24, 2026
f8b271f
[JAX] Fix FSDP when FSDP+EP is active (#2649)
jberchtold-nvidia Feb 24, 2026
7222d87
[PyTorch Debug] Support precision debug tools for fp8 model parameter…
pggPL Feb 25, 2026
df0ef6e
remove deprecated qkv/kv_packed apis (#2696)
sudhakarsingh27 Feb 25, 2026
842b770
[Common] Remove volatile keyword in fused router kernel utils (#2683)
denera Feb 26, 2026
ad56283
[CI] Cancel on concurrency (#2708)
yaox12 Feb 27, 2026
b345941
[PyTorch] `GroupedTensor` integration (#2600)
ksivaman Feb 27, 2026
a9a9b3a
[Common][PyTorch] Enhance the fused router and unify the precision (#…
yaox12 Feb 27, 2026
3ecb5bf
[PyTorch] Fix L3 FA tests (#2709)
cyanguwa Feb 28, 2026
f508e66
[PyTorch] Remove `is_first_microbatch` setting after cudagraph warmup…
buptzyb Mar 2, 2026
537f134
[Common][PyTorch] Fix normalization for `fused_score_for_moe_aux_loss…
Autumn1998 Mar 2, 2026
bba7bf6
[PyTorch] Support cuda graph capturing offloading module (#2435)
lhb8125 Mar 2, 2026
3275e1a
[JAX] CGEMM with Shardy (#2714)
phu0ngng Mar 2, 2026
9dac78e
CPU Overhead Optimizations (#2559)
vthumbe1503 Mar 3, 2026
c68ec31
Add fast_set_attr to modules not inheriting from base.py (#2724)
vthumbe1503 Mar 3, 2026
39d249b
[JAX] Remove GSPMD tests + adding guards and warning msg for GSPMD ru…
phu0ngng Mar 3, 2026
a3bc040
NVFP4 primary weight support (#2691)
WanZzzzzz Mar 3, 2026
bf3201a
[PyTorch] Support single parameter for `GroupedLinear` (#2731)
ksivaman Mar 4, 2026
00ba0b4
pass params_dtype to qk_norm creation (#2718)
pstjohn Mar 4, 2026
505b896
[JAX] GSPMD Deprecation Warning - Only trigger when the primitive is …
phu0ngng Mar 4, 2026
139c863
Add fused_adam, quantized_model_init, and fsdp2 example (#2698)
pstjohn Mar 4, 2026
56c2fa6
[JAX] Support calling MOE router kernels from JAX side (#2711)
tdophung Mar 4, 2026
d2e4755
[PyTorch] Skip `test_nvfp4_partial_cast_matches_full` test when NVFP4…
ksivaman Mar 5, 2026
145e88c
Add multi-precision training support to FSDP script (#2662)
aagallo Mar 5, 2026
d9152b0
[PyTorch] Support `GroupedTensor` torch ops for DDP and distributed o…
ksivaman Mar 5, 2026
d226ce2
[JAX] Integrate BF16 Grouped GEMM with on-device group sizes (#2680)
jberchtold-nvidia Mar 5, 2026
d40b9de
WAR sort_chunks_by_index intermittent failures in L0 JAX unittest pa…
tdophung Mar 5, 2026
5fd5c35
Fix FP8 block scaling with sequence parallel (#2637)
cuichenx Mar 8, 2026
ab9d60e
[PyTorch] Zero-initialize learnable softmax_offset in DotProductAtten…
fjosw Mar 8, 2026
e9ea352
docs: update cuDNN sliding window attention support (#2624)
sbhavani Mar 8, 2026
6638fef
[JAX] GEMM tex and FFI cleanup (#2739)
phu0ngng Mar 8, 2026
34a6c0a
Fix Flash Attention 3 API compatibility for window size parameters (#…
jhvmhg Mar 9, 2026
6e0085a
[Common] Remove redundant grad_logits zero-initialization in fused ro…
roycho96 Mar 9, 2026
f64941a
Enable dequantization from MXFP8 tensor with only columnwise data (#2…
ptrendx Mar 10, 2026
e6d97ff
[PyTorch] Fix cross_entropy_forward stride guard for non-contiguous i…
Bias92 Mar 10, 2026
7c2aa2c
[Common] MOE Split dBias (#2674)
Oleg-Goncharov Mar 10, 2026
3846bf7
Fix deploy nightly docs issue (#2636)
pggPL Mar 10, 2026
d32f9e4
[JAX] Fix get_seqlens_and_offsets() to accept vmapped seg ids and non…
KshitijLakhani Mar 11, 2026
61d5865
[NVFP4][MOE] Add unfused quantization fallback when input shape is no…
zhongbozhu Mar 11, 2026
7545d8c
[PyTorch debug] Fix issue with tp_group=None (#2733)
pggPL Mar 11, 2026
107f558
Documentation for cpu offloading (#2520)
pggPL Mar 11, 2026
d5ce416
Add guard at lowest JAX version that still supports triton kernel cal…
tdophung Mar 11, 2026
f6001c4
Support configurable number of philox rounds for stochastic rounding …
ksivaman Mar 11, 2026
61f9594
[All] Added better error messages (#2705)
ptrendx Mar 11, 2026
c021e7e
[PyTorch] Fix fuser so it releases tensors properly (#2750)
kainzhong Mar 11, 2026
7fb10d3
[PyTorch] Add dtype information to QuantizedTensorStorage class (#2676)
ptrendx Mar 11, 2026
4c5b1a2
[JAX] Change dtype of intermediate result aval of fused_topk_and_scor…
tdophung Mar 11, 2026
06a23e3
Initial commit to pass scale as Tensor for multi_tensor_scale op (#2594)
vasunvidia Mar 12, 2026
ef703e5
[Core] MXFP8 grouped GEMM + tensor-scaled FP8 fixes (#2748)
jberchtold-nvidia Mar 12, 2026
67898a7
Cherry pick "Adds dst.dtype information in copy_ method of quantized …
ptrendx Mar 12, 2026
134304e
Fused kernel for calculating offsets from first dim splits (#2755)
ksivaman Mar 12, 2026
a5d7464
Added new users to CI (#2756)
ptrendx Mar 12, 2026
6a68c73
[PyTorch] Error out if constructing `LayerNormLinear` with row tensor…
timmoon10 Mar 12, 2026
14c29da
[JAX] Collective GEMM with FP8 and MXFP8 support (#2740)
phu0ngng Mar 13, 2026
fcceeb9
[Pytorch] Add QuantizedTensor support in FusedAdam.step for MXFP8Bloc…
jomitchellnv Mar 13, 2026
306e853
add .claude to gitignore (#2762)
pstjohn Mar 13, 2026
b7214fd
Fix for async dcp checkpointing with Float8Tensors (#2721)
pstjohn Mar 15, 2026
708d7c1
Pytorch binding for cublas grouped gemm + Grouped Bias Support + Grou…
vthumbe1503 Mar 16, 2026
2156e61
Merge commit '708d7c160ad6b2bf44c9c597083d4cbb4860f068' from upstream
ipanfilo Apr 14, 2026
11ab82a
Resolve merging errors, fixed build and load, restore codepaths missi…
ipanfilo Apr 21, 2026
4 changes: 4 additions & 0 deletions .github/workflows/build.yml
@@ -7,6 +7,10 @@ name: 'Build'
on:
pull_request:
workflow_dispatch:
concurrency:
# Group by workflow name + PR number (for PRs) or ref (for branch/tag pushes)
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
cancel-in-progress: true
jobs:
core:
name: 'Core'
7 changes: 4 additions & 3 deletions .github/workflows/deploy_nightly_docs.yml
@@ -7,6 +7,7 @@ name: Deploy nightly docs
on:
push:
branches: [ "main" ]
workflow_dispatch:
jobs:
build:
uses: ./.github/workflows/docs.yml
@@ -21,9 +22,8 @@ jobs:
name: "te_docs"
path: "html"
- name: Prepare for pages
uses: actions/upload-pages-artifact@v1.0.7
uses: actions/upload-pages-artifact@v3
with:
name: github-pages
path: "html"
deploy:
needs: prepare
@@ -36,4 +36,5 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Deploy
uses: actions/deploy-pages@v2.0.0
id: deployment
uses: actions/deploy-pages@v4
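With ``workflow_dispatch`` added to the triggers above, the nightly docs deployment can also be started manually, e.g. with the GitHub CLI (repo slug assumed):

    gh workflow run deploy_nightly_docs.yml --repo NVIDIA/TransformerEngine --ref main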
4 changes: 4 additions & 0 deletions .github/workflows/docs.yml
@@ -8,6 +8,10 @@ on:
pull_request:
workflow_dispatch:
workflow_call:
concurrency:
# Group by workflow name + PR number (for PRs) or ref (for branch/tag pushes)
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
cancel-in-progress: true
jobs:
build_docs:
name: 'Build'
4 changes: 4 additions & 0 deletions .github/workflows/lint.yml
@@ -7,6 +7,10 @@ name: 'Lint'
on:
pull_request:
workflow_dispatch:
concurrency:
# Group by workflow name + PR number (for PRs) or ref (for branch/tag pushes)
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
cancel-in-progress: true
jobs:
pytorch_cpplint:
name: 'PyTorch C++'
2 changes: 2 additions & 0 deletions .github/workflows/trigger-ci.yml
@@ -58,6 +58,8 @@ jobs:
|| github.actor == 'vthumbe1503'
|| github.actor == 'shengfangd'
|| github.actor == 'kainzhong'
|| github.actor == 'cspades'
|| github.actor == 'jomitchellnv'
)
steps:
- name: Check if comment is issued by authorized person
6 changes: 5 additions & 1 deletion .gitignore
@@ -56,4 +56,8 @@ artifacts/
**/times.csv
transformer_engine/build_info.txt
transformer_engine/common/util/hip_nvml.*
*.DS_Store
.DS_Store
.rsync-filter
.codex/
.cline_storage/
.claude/
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -43,4 +43,4 @@ repos:
rev: c75aca72f4e85c6e47252139e8695f1c8b5f9ae3
hooks:
- id: vermin
args: ['-t=3.10', '--violations']
args: ['-t=3.10-', '--violations']
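The trailing ``-`` tells vermin to treat 3.10 as a minimum rather than an exact target, so code whose detected minimum is newer than 3.10 no longer fails the hook. A local spot check (file path illustrative):

    pip install vermin
    vermin -t=3.10- --violations transformer_engine/pytorch/__init__.py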
2 changes: 1 addition & 1 deletion 3rdparty/cudnn-frontend
Submodule cudnn-frontend updated 221 files
1 change: 1 addition & 0 deletions MANIFEST.in
@@ -1 +1,2 @@
recursive-include transformer_engine/common/include *.*
recursive-include build_tools *.py *.txt
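Since ``build_tools`` is now shipped in the sdist (addressing the cached ``uv`` install failure fixed in commit fa68781), one way to verify locally (commands illustrative):

    python -m build --sdist
    tar tzf dist/transformer_engine-*.tar.gz | grep build_tools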
46 changes: 42 additions & 4 deletions README.rst
@@ -458,7 +458,7 @@ Flax
for _ in range(10):
loss, (param_grads, other_grads) = fwd_bwd_fn(params, other_variables, inp)

For a more comprehensive tutorial, check out our `Quickstart Notebook <https://github.com/NVIDIA/TransformerEngine/blob/main/docs/examples/quickstart.ipynb>`_.
For a more comprehensive tutorial, check out our `Getting Started Guide <https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/getting_started.html>`_.

.. overview-end-marker-do-not-remove

@@ -496,15 +496,22 @@ For example to use the NGC PyTorch container interactively,

.. code-block:: bash

docker run --gpus all -it --rm nvcr.io/nvidia/pytorch:25.08-py3
docker run --gpus all -it --rm nvcr.io/nvidia/pytorch:26.01-py3

For example to use the NGC JAX container interactively,

.. code-block:: bash

docker run --gpus all -it --rm nvcr.io/nvidia/jax:25.08-py3
docker run --gpus all -it --rm nvcr.io/nvidia/jax:26.01-py3

Where 25.08 (corresponding to August 2025 release) is the container version.
Where 26.01 (corresponding to January 2026 release) is the container version.

We recommend updating to the latest NGC containers, available here:

* https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch
* https://catalog.ngc.nvidia.com/orgs/nvidia/containers/jax

If you run any examples, please ensure you are using a matching version of Transformer Engine. Transformer Engine is pre-built and packaged inside the containers, with examples available at ``/opt/transformerengine`` or ``/opt/transformer-engine``. If you would like to use examples from the TE main branch and are running into import errors, please try the latest pip package or build from source, although NGC containers are recommended for ease of use for most users.

**Benefits of using NGC containers:**

@@ -628,6 +635,37 @@ Troubleshooting
cd transformer_engine
pip install -v -v -v --no-build-isolation .

**Problems using UV or Virtual Environments:**

1. **Import Error:**

* **Symptoms:** Cannot import ``transformer_engine``
* **Solution:** Ensure your UV environment is active and that you have used ``uv pip install --no-build-isolation <te_pypi_package_or_wheel_or_source_dir>`` instead of a regular pip install to your system environment.

2. **cuDNN Sublibrary Loading Failed:**

* **Symptoms:** Errors at runtime with ``CUDNN_STATUS_SUBLIBRARY_LOADING_FAILED``
* **Solution:** This can occur when TE is built against the container's system installation of cuDNN while the virtual environment pulls in the ``nvidia-cudnn-cu12/cu13`` pip packages. To resolve this, when building TE from source, set the following environment variables to point to the cuDNN inside your virtual environment.


.. code-block:: bash

export CUDNN_PATH=$(pwd)/.venv/lib/python3.12/site-packages/nvidia/cudnn
export CUDNN_HOME=$CUDNN_PATH
export LD_LIBRARY_PATH=$CUDNN_PATH/lib:$LD_LIBRARY_PATH

3. **Building Wheels:**

* **Symptoms:** Regular TE installs work correctly but UV wheel builds fail at runtime.
* **Solution:** Ensure that ``uv build --wheel --no-build-isolation -v`` is used for the wheel build and that ``--no-build-isolation`` is also passed when installing the wheel; see the sketch below. Use ``-v`` for verbose output to verify that TE is not pulling in a version of PyTorch or JAX that differs from the UV environment's.
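For example, a minimal build-and-install flow under those flags (wheel filename illustrative):

.. code-block:: bash

    # Build the wheel without build isolation so it compiles against the
    # environment's PyTorch/JAX instead of an isolated build environment.
    uv build --wheel --no-build-isolation -v

    # Install the built wheel, again without build isolation.
    uv pip install --no-build-isolation dist/transformer_engine-*.whl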

**JAX-specific Common Issues and Solutions:**

1. **FFI Issues:**

* **Symptoms:** ``No registered implementation for custom call to <some_te_ffi> for platform CUDA``
* **Solution:** Ensure ``--no-build-isolation`` is used during installation. If pre-building wheels, ensure that the wheel is both built and installed with ``--no-build-isolation``. See "Problems using UV or Virtual Environments" above if using UV.
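If the error persists, a clean reinstall without build isolation sometimes helps; a sketch (the ``[jax]`` extra is an assumption here, matching the framework in use):

.. code-block:: bash

    pip uninstall -y transformer_engine
    pip install --no-build-isolation transformer_engine[jax]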

.. troubleshooting-end-marker-do-not-remove

Breaking Changes
2 changes: 1 addition & 1 deletion build_tools/VERSION.txt
@@ -1 +1 @@
2.12.0.dev0
2.14.0.dev0
1 change: 1 addition & 0 deletions build_tools/hipify/custom_map.json
@@ -12,6 +12,7 @@
"__nv_fp8_e4m3" : "te_hip_fp8_e4m3",
"cuda::getCurrentCUDAStream" : "hip::getCurrentHIPStreamMasqueradingAsCUDA",
"at::cuda::CUDAGuard" : "at::hip::HIPGuardMasqueradingAsCUDA",
"c10::cuda::" : "c10::hip::",
"__nv_fp4_e2m1" : "__hip_fp4_e2m1",
"__nv_fp4x2_e2m1" : "__hip_fp4x2_e2m1",
"__nv_fp4x4_e2m1" : "__hip_fp4x4_e2m1",
7 changes: 1 addition & 6 deletions build_tools/jax.py
@@ -59,12 +59,7 @@ def xla_path() -> str:
Throws FileNotFoundError if XLA source is not found."""

try:
import jax
from packaging import version
if version.parse(jax.__version__) >= version.parse("0.5.0"):
from jax import ffi
else:
from jax.extend import ffi
from jax import ffi
except ImportError:
if os.getenv("XLA_HOME"):
xla_home = Path(os.getenv("XLA_HOME"))
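If JAX cannot be imported at build time, ``xla_path()`` falls back to the ``XLA_HOME`` environment variable; a sketch of that path (checkout location hypothetical):

    export XLA_HOME=/opt/xla    # local checkout of openxla/xla sources
    pip install --no-build-isolation .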
8 changes: 6 additions & 2 deletions build_tools/utils.py
@@ -306,9 +306,10 @@ def nvcc_path() -> Tuple[str, str]:
def get_cuda_include_dirs() -> Tuple[str, str]:
"""Returns the CUDA header directory."""

force_wheels = bool(int(os.getenv("NVTE_BUILD_USE_NVIDIA_WHEELS", "0")))
# If cuda is installed via toolkit, all necessary headers
# are bundled inside the top level cuda directory.
if cuda_toolkit_include_path() is not None:
if not force_wheels and cuda_toolkit_include_path() is not None:
return [cuda_toolkit_include_path()]

# Use pip wheels to include all headers.
@@ -317,7 +318,10 @@ def get_cuda_include_dirs() -> Tuple[str, str]:
except ModuleNotFoundError as e:
raise RuntimeError("CUDA not found.")

cuda_root = Path(nvidia.__file__).parent
if nvidia.__file__ is not None:
cuda_root = Path(nvidia.__file__).parent
else:
cuda_root = Path(nvidia.__path__[0]) # namespace
return [
subdir / "include"
for subdir in cuda_root.iterdir()
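Per the diff above, setting ``NVTE_BUILD_USE_NVIDIA_WHEELS=1`` forces header discovery from the ``nvidia-*`` pip wheels even when a CUDA toolkit is present; for example (invocation illustrative):

    NVTE_BUILD_USE_NVIDIA_WHEELS=1 pip install --no-build-isolation .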
9 changes: 6 additions & 3 deletions ci/pytorch.sh
@@ -12,7 +12,7 @@ TEST_DIR=${TE_PATH}tests/pytorch
#: ${TEST_WORKERS:=4}

install_prerequisites() {
pip install 'numpy>=1.22.4' pandas
pip install 'numpy>=1.22.4' pandas safetensors
rc=$?
if [ $rc -ne 0 ]; then
script_error "Failed to install test prerequisites"
@@ -100,8 +100,11 @@ run_test_config_mgpu(){
run_default_fa 2 distributed/test_numerics.py
run_default_fa 1 distributed/test_torch_fsdp2.py
run_default_fa 2 distributed/test_torch_fsdp2_fp8.py
run_default_fa_lbl "flash" 3 attention/test_attention_with_cp.py -k "with_flash"
run_default_fa_lbl "fused" 2 attention/test_attention_with_cp.py -k "with_fused"
if [ $_fus_attn = ck ]; then
run 2 attention/test_attention_with_cp.py -k "with_fused"
elif [ $_fus_attn = flash ]; then
run 3 attention/test_attention_with_cp.py -k "with_flash"
fi
}

run_benchmark() {
134 changes: 134 additions & 0 deletions docs/_static/css/diagram-colors.css
@@ -0,0 +1,134 @@
/* Diagram color definitions for Transformer Engine documentation */

/* High precision (BF16/FP16) elements */
.hp {
fill: #ede7f6;
stroke: #673ab7;
stroke-width: 2;
}

/* FP8 precision elements */
.fp8 {
fill: #fff8e1;
stroke: #ffa726;
stroke-width: 2;
}

/* GEMM/computation operations */
.gemm {
fill: #ffe0b2;
stroke: #fb8c00;
stroke-width: 2.5;
}

/* Quantization operations */
.quantize {
fill: #e8f5e9;
stroke: #66bb6a;
stroke-width: 2;
}

/* Amax computation operations */
.amax {
fill: #e1f5fe;
stroke: #039be5;
stroke-width: 2;
}

/* Text styles */
.text {
font-family: 'Segoe UI', Arial, sans-serif;
font-size: 14px;
text-anchor: middle;
fill: #212121;
}

.small-text {
font-family: 'Segoe UI', Arial, sans-serif;
font-size: 14px;
text-anchor: middle;
fill: #757575;
}

.label {
font-family: 'Segoe UI', Arial, sans-serif;
font-size: 14px;
text-anchor: middle;
fill: #424242;
}

.title {
font-family: 'Segoe UI', Arial, sans-serif;
font-size: 18px;
font-weight: 600;
text-anchor: middle;
fill: #212121;
}

.section-title {
font-family: 'Segoe UI', Arial, sans-serif;
font-size: 15px;
font-weight: 600;
text-anchor: middle;
}

/* Arrows */
/* Note: marker-end references #arrowhead marker which must be defined in each SVG's <defs> section */
.arrow {
stroke: #616161;
stroke-width: 2;
fill: none;
marker-end: url(#arrowhead);
}

/* Additional box and element styles */
.box-blue {
fill: #e3f2fd;
stroke: #1976d2;
stroke-width: 2;
}

.box-orange {
fill: #fff3e0;
stroke: #f57c00;
stroke-width: 2;
}

.box-green {
fill: #c8e6c9;
stroke: #388e3c;
stroke-width: 2;
}

.box-dashed {
stroke-dasharray: 5,5;
}

/* LayerNorm specific */
.layernorm {
fill: #b3e5fc;
stroke: #0277bd;
stroke-width: 2.5;
}

/* Fused layers */
.fused {
fill: #b2dfdb;
stroke: #00695c;
stroke-width: 3;
}

/* Generic computation blocks */
.computation {
fill: #f5f5f5;
stroke: #757575;
stroke-width: 2;
}

/* FP32 precision (alternative red) */
.fp32 {
fill: #ffcdd2;
stroke: #d32f2f;
stroke-width: 2.5;
}
