Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
177b2ec
cuBlasMp backend logic added to TE/common with connections to framewo…
denera Dec 2, 2025
7d46b0b
added use_cublasmp flags to CollectiveGemm bootstrapping to avoid UB …
denera Dec 2, 2025
6d4a141
added cuBLASMp backend option to JAX unit tests for CollectiveGEMM
denera Dec 16, 2025
35d0f19
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 16, 2025
dd8eaf3
added pytorch unit tests for comm+GEMM overlap with cuBLASMp backend
denera Dec 16, 2025
d79bf21
greptile fixes
denera Dec 17, 2025
ee517d3
linting
denera Dec 17, 2025
51b64fb
function argument call order fixes
denera Dec 17, 2025
9be771c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 17, 2025
4cec043
JAX collective GEMM modified to inherit cublasmp usage from global bo…
denera Jan 16, 2026
898cf30
Merge remote-tracking branch 'upstream/main' into common/tp-overlap-c…
denera Jan 16, 2026
422a654
typos and style fixes
pre-commit-ci[bot] Jan 16, 2026
6e42235
documentation and build fixes
denera Jan 27, 2026
626dd1d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 27, 2026
d44cfc4
Merge remote-tracking branch 'upstream/main' into common/tp-overlap-c…
denera Mar 13, 2026
e341a8b
fixed default SM margin option and JAX cgemm test runner cleanup
denera Mar 13, 2026
6942d20
cublasmp running with TE/PyTorch
denera Mar 16, 2026
bef5c7e
cublasmp working with TE/JAX
denera Mar 16, 2026
81d6383
Merge remote-tracking branch 'upstream/main' into common/tp-overlap-c…
denera Mar 16, 2026
6c6cc4d
cublasmp working with TE/JAX (JAX container is missing cuBLASMp insta…
denera Mar 16, 2026
9ed2adf
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 16, 2026
ca913b9
added arch suffixes for CUBLASMP lib lookup in CMAKE
denera Mar 16, 2026
c55626d
Merge branch 'common/tp-overlap-cublasmp' of github.com:denera/Transf…
denera Mar 16, 2026
f863ba8
fixed TE/JAX collective gemm test runner
denera Mar 16, 2026
5a8c7ae
TE/JAX CGEMM test runner script fix
denera Mar 17, 2026
5b9df92
fixed the cublasmp option in the pytest runners
denera Mar 17, 2026
775df95
Merge remote-tracking branch 'upstream/main' into common/tp-overlap-c…
denera Mar 20, 2026
441472a
Merge remote-tracking branch 'upstream/main' into common/tp-overlap-c…
denera Apr 7, 2026
3df11fc
Merge remote-tracking branch 'upstream/main' into common/tp-overlap-c…
denera Apr 17, 2026
58f1e68
cuBLASMp passing tests with TE/PyTorch
denera Apr 21, 2026
f05f849
Merge remote-tracking branch 'upstream/main' into common/tp-overlap-c…
denera Apr 21, 2026
f95f229
updated cuBLASMp C++ tests to also test local chunks instead of globa…
denera Apr 21, 2026
f84e8f9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 21, 2026
c67c183
cuBLASmp C++ tests switched to NCCL comms for reference results, now …
denera Apr 22, 2026
e9c79a3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 22, 2026
1b8fb1e
[JAX] Fix cuBLASMp collective GEMM tests and document XLA command buf…
denera Apr 24, 2026
caa741e
changed cuBLASMp call sizing to use flat first/last dims
denera May 1, 2026
9cca8a9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 1, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions build_tools/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import os
from pathlib import Path
from importlib import metadata

import setuptools

Expand Down Expand Up @@ -88,6 +89,15 @@ def setup_pytorch_extension(
libraries.append("nvshmem_host")
cxx_flags.append("-DNVTE_ENABLE_NVSHMEM")

if bool(int(os.getenv("NVTE_WITH_CUBLASMP", 0))):
# Creating a cuBlasMp context requires direct access to the underlying NCCL
# communicator in a tensor-parallel process group. The header for ProcessGroupNCCL
# needs this CPP directive to be included properly.
cxx_flags.append("-DNVTE_WITH_CUBLASMP")
torch_lib_path = metadata.distribution("torch").locate_file("torch/lib")
library_dirs.append(torch_lib_path)
libraries.append("torch_cuda")

# Construct PyTorch CUDA extension
sources = [str(path) for path in sources]
include_dirs = [str(path) for path in include_dirs]
Expand Down
36 changes: 34 additions & 2 deletions examples/jax/collective_gemm/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,14 @@
"""Shared functions for the collective GEMM tests"""

import argparse
import glob
import os

import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import mesh_utils
from jax.experimental.multihost_utils import sync_global_devices

from transformer_engine.jax.cpp_extensions.gemm import collective_gemm_bootstrap

Expand Down Expand Up @@ -56,9 +59,9 @@ def assert_allclose(actual, desired, rtol=None, atol=None, dtype=None, **kwargs)
tols["atol"] = atol

if not isinstance(actual, float):
actual = actual.astype(jnp.float32)
actual = np.asarray(actual, dtype=np.float32)
if not isinstance(desired, float):
desired = desired.astype(jnp.float32)
desired = np.asarray(desired, dtype=np.float32)

np.testing.assert_allclose(actual, desired, **tols, **kwargs)

Expand Down Expand Up @@ -96,6 +99,14 @@ def _initialize_distributed(args):

assert args.num_devices_per_process == 1, "Only single process single GPU is supported!"

# Collective GEMM with communication overlap (Userbuffers or cuBLASMp) uses internal
# CUDA streams for overlapping NCCL collectives with compute. XLA command buffers
# (CUDA graph capture) cannot record work that spans multiple streams, so we must
# disable them when running collective GEMM with overlap.
xla_flags = os.environ.get("XLA_FLAGS", "")
if "--xla_gpu_enable_command_buffer" not in xla_flags:
os.environ["XLA_FLAGS"] = xla_flags + " --xla_gpu_enable_command_buffer="

print(
f"Initializing JAX distributed with coordinator={args.coordinator_address}, "
f"num_processes={args.num_processes}, process_id={args.process_id}"
Expand All @@ -118,6 +129,20 @@ def _initialize_distributed(args):
devices_per_process = 1
num_total_devices = args.num_processes

# Remove stale NCCL unique ID files from previous (possibly crashed) runs.
# These files are used for one-time coordination during bootstrap; stale files
# cause non-leader processes to read an old unique ID, breaking NCCL init.
# Only process 0 performs the cleanup; a global barrier ensures all processes
# wait for the cleanup to complete before any TP leader writes a fresh file.
nccl_base_path = os.environ.get("NVTE_JAX_NCCL_FILE_PATH", "/tmp")
if args.process_id == 0:
for f in glob.glob(os.path.join(nccl_base_path, "nccl_*_unique_id_*.bin")):
try:
os.remove(f)
except OSError:
pass
sync_global_devices("nccl_id_cleanup")

print(
f"Initializing CGEMM communicator with num_total_devices={num_total_devices},"
f" devices_per_process={devices_per_process}, process_id={args.process_id}"
Expand All @@ -128,6 +153,7 @@ def _initialize_distributed(args):
num_devices_per_process=devices_per_process,
process_id=args.process_id,
tensor_parallel_size=args.tensor_parallel_size,
use_cublasmp=args.use_cublasmp,
)


Expand Down Expand Up @@ -224,5 +250,11 @@ def cgemm_parser(description="Collective GEMM test on multi-GPU with tensor para
parser.add_argument(
"--enable-result-check", action="store_true", default=True, help="Enable result checking"
)
parser.add_argument(
"--use-cublasmp",
action="store_true",
default=False,
help="Use the cuBLASMp backend for overlapping collective operations with GEMM computation",
)

return parser
11 changes: 11 additions & 0 deletions examples/jax/collective_gemm/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,29 @@
"""config for collective_gemm tests"""
import pytest

import transformer_engine.jax # noqa: F401 - must load libtransformer_engine.so before transformer_engine_jax
from transformer_engine_jax import nvte_built_with_cublasmp


def pytest_addoption(parser):
    """Register the command-line options used by the collective_gemm tests.

    Options cover distributed bootstrap (coordinator address, process
    count/rank, local device ids) plus the cuBLASMp backend toggle.
    """
    # (flag, addoption keyword arguments) — registered in a fixed order.
    option_specs = (
        ("--coordinator-address", {"action": "store", "default": "localhost:12345"}),
        ("--num-processes", {"action": "store", "default": 1}),
        ("--process-id", {"action": "store", "default": 0}),
        ("--local-device-ids", {"action": "store", "default": None}),
        ("--use-cublasmp", {"action": "store_true", "default": False}),
    )
    for flag, kwargs in option_specs:
        parser.addoption(flag, **kwargs)


@pytest.fixture(autouse=True)
def distributed_args(request):
"""Fixture for querying distributed initialization arguments"""
if request.cls:
use_cublasmp = request.config.getoption("--use-cublasmp")
if use_cublasmp and not nvte_built_with_cublasmp():
pytest.skip(
"Collective GEMM cuBLASMp backend tests require Transformer Engine to be built "
"with NVTE_WITH_CUBLASMP=1."
)
request.cls.coordinator_address = request.config.getoption("--coordinator-address")
request.cls.num_processes = int(request.config.getoption("--num-processes"))
request.cls.process_id = int(request.config.getoption("--process-id"))
Expand All @@ -27,3 +37,4 @@ def distributed_args(request):
if request.cls.local_device_ids is None
else len(request.cls.local_device_ids.split(","))
)
request.cls.use_cublasmp = use_cublasmp
118 changes: 77 additions & 41 deletions examples/jax/collective_gemm/run_test_cgemm.sh
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,30 @@ else
echo "NVLINK support detected"
fi

echo "*** Checking cuBLASMp support in TE build ***"
CUBLASMP_SUPPORT=$(python3 - <<'PY'
try:
import transformer_engine.jax
from transformer_engine_jax import nvte_built_with_cublasmp
except Exception as exc:
print(f"error:{exc}")
raise SystemExit(0)

print("1" if nvte_built_with_cublasmp() else "0")
PY
)

if [[ "$CUBLASMP_SUPPORT" == "1" ]]; then
echo "cuBLASMp backend support detected"
BACKENDS=("cublasmp" "userbuffers")
elif [[ "$CUBLASMP_SUPPORT" == "0" ]]; then
echo "cuBLASMp backend support not detected; skipping cuBLASMp backend tests"
BACKENDS=("userbuffers")
else
echo "Failed to query cuBLASMp support from transformer_engine_jax: $CUBLASMP_SUPPORT"
exit 1
fi

# Define individual test cases to run (file::class::method)
# DelayedScalingFP8 and CurrentScalingFP8 use the same GEMM so we don't need to test both cases all
# the time.
Expand Down Expand Up @@ -93,50 +117,62 @@ for TEST_CASE in "${TEST_CASES[@]}"; do
# Clear PIDs array for this test case
PIDS=()

for i in $(seq 0 $(($NUM_GPUS - 1))); do
# Define output file for logs
LOG_FILE="${TEST_NAME}_gpu_${i}.log"

if [ $i -eq 0 ]; then
# For process 0: show live output AND save to log file using tee
echo "=== Live output from process 0 ==="
pytest -s -c "$TE_PATH/tests/jax/pytest.ini" \
-vs --junitxml=$XML_LOG_DIR/collective_gemm_${TEST_NAME}.xml \
"$TE_PATH/examples/jax/collective_gemm/$TEST_CASE" \
--num-processes=$NUM_GPUS \
--process-id=$i 2>&1 | tee "$LOG_FILE" &
PID=$!
PIDS+=($PID)
for BACKEND in "${BACKENDS[@]}"; do
echo "Setting backend to $BACKEND for test $TEST_NAME"

for i in $(seq 0 $(($NUM_GPUS - 1))); do
# Define output file for logs
LOG_FILE="${TEST_NAME}_gpu_${i}_${BACKEND}.log"

test_args=(
"--num-processes=$NUM_GPUS"
"--process-id=$i"
)
if [ "$BACKEND" == "cublasmp" ]; then
test_args+=("--use-cublasmp")
fi

if [ $i -eq 0 ]; then
# For process 0: show live output AND save to log file using tee
echo "=== Live output from process 0 ==="
pytest -s -c "${TE_PATH}/tests/jax/pytest.ini" -vs \
"--junitxml=${XML_LOG_DIR}/${TEST_NAME}_gpu_${i}_${BACKEND}.xml" \
"${TE_PATH}/examples/jax/collective_gemm/${TEST_CASE}" \
"${test_args[@]}" 2>&1 | tee "$LOG_FILE" &
PID=$!
PIDS+=($PID)
else
# For other processes: redirect to log files only
pytest -s -c "${TE_PATH}/tests/jax/pytest.ini" -vs \
"${TE_PATH}/examples/jax/collective_gemm/${TEST_CASE}" \
"${test_args[@]}" > "$LOG_FILE" 2>&1 &
PID=$!
PIDS+=($PID)
fi
done

# Wait for all processes to finish
wait

# Check and print the log content from process 0
if grep -q "SKIPPED" "${TEST_NAME}_gpu_0_${BACKEND}.log"; then
echo "... $TEST_CASE SKIPPED"
elif grep -q "FAILED" "${TEST_NAME}_gpu_0_${BACKEND}.log"; then
echo "... $TEST_CASE FAILED"
HAS_FAILURE=1
elif grep -q "PASSED" "${TEST_NAME}_gpu_0_${BACKEND}.log"; then
echo "... $TEST_CASE PASSED"
else
# For other processes: redirect to log files only
pytest -s -c "$TE_PATH/tests/jax/pytest.ini" \
-vs "$TE_PATH/examples/jax/collective_gemm/$TEST_CASE" \
--num-processes=$NUM_GPUS \
--process-id=$i > "$LOG_FILE" 2>&1 &
PID=$!
PIDS+=($PID)
echo "... $TEST_CASE INVALID"
HAS_FAILURE=1
fi
done

# Wait for all processes to finish
wait

# Check and print the log content from process 0
if grep -q "SKIPPED" "${TEST_NAME}_gpu_0.log"; then
echo "... $TEST_CASE SKIPPED"
elif grep -q "FAILED" "${TEST_NAME}_gpu_0.log"; then
echo "... $TEST_CASE FAILED"
HAS_FAILURE=1
elif grep -q "PASSED" "${TEST_NAME}_gpu_0.log"; then
echo "... $TEST_CASE PASSED"
else
echo "... $TEST_CASE INVALID"
HAS_FAILURE=1
fi

# Remove the log files after processing them
wait
rm ${TEST_NAME}_gpu_*.log

# Remove the log files after processing them
wait
rm ${TEST_NAME}_gpu_*_${BACKEND}.log

done
done

wait
Expand Down
1 change: 1 addition & 0 deletions examples/jax/collective_gemm/test_dense_grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,7 @@ def setUp(self):
self.args.process_id = self.process_id
self.args.local_device_ids = self.local_device_ids
self.args.num_devices_per_process = self.num_devices_per_process
self.args.use_cublasmp = self.use_cublasmp
self.args.enable_data_parallel = True
self.args.tensor_parallel_size = _get_dp_and_tp_sizes(self.args)[1]
_initialize_distributed(self.args)
Expand Down
35 changes: 21 additions & 14 deletions examples/jax/collective_gemm/test_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import os
from functools import partial

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec, NamedSharding
Expand Down Expand Up @@ -151,20 +152,25 @@ def run_gemm_tests(args, mesh=None):
jax.block_until_ready(gathered_output)

if args.enable_result_check and args.process_id == 0:
# CGEMM + RS + BF16 uses TE's reduce_bf16 kernel (sequential left-to-right in FP32).
# With catastrophic cancellation the output is near zero while the absolute diff can
# reach 1 ULP of the partial GEMM magnitude (~0.0625 for typical transformer
# activations at O(8) scale), which exceeds the previous atol=1e-5. The 2x
# margin (0.125) covers this worst-case 1-ULP absolute difference.
is_cgemm_rs_bf16 = collective_op == CollectiveOp.REDUCE_SCATTER and not use_quantization
rtol = 1e-2 if is_cgemm_rs_bf16 else None
atol = 0.125 if is_cgemm_rs_bf16 else None
assert_allclose(
gathered_ref_output,
gathered_output,
dtype=get_tolerance_dtype(quantizer_set),
rtol=rtol,
atol=atol,
if use_quantization:
rtol, atol = 0.125, 0.0625
else:
rtol, atol = 0.02, 0.002
# Use NumPy (not JAX) for the result check to avoid triggering new XLA compilations
# on process 0 only, which would deadlock in multi-process JAX because XLA compilation
# of distributed arrays requires collective synchronization across all processes.
actual = np.asarray(gathered_output, dtype=np.float32)
desired = np.asarray(gathered_ref_output, dtype=np.float32)
diff = np.abs(actual - desired)
abs_desired = np.abs(desired)
failures = (diff > atol) & (diff > rtol * abs_desired)
num_failures = int(np.sum(failures))
assert num_failures == 0, (
f"NUMERICAL CHECK FAILED: {num_failures}/{diff.size} elements "
f"({100 * num_failures / diff.size:.4f}%) exceed tolerances "
f"(rtol={rtol}, atol={atol}). "
f"Max abs error: {float(np.max(diff)):.6f}, "
f"max rel error: {float(np.max(diff / np.maximum(abs_desired, 1e-5))):.6f}"
)


Expand All @@ -180,6 +186,7 @@ def setUp(self):
self.args.process_id = self.process_id
self.args.local_device_ids = self.local_device_ids
self.args.num_devices_per_process = self.num_devices_per_process
self.args.use_cublasmp = self.use_cublasmp
self.args.enable_data_parallel = True
self.args.tensor_parallel_size = _get_dp_and_tp_sizes(self.args)[1]
_initialize_distributed(self.args)
Expand Down
1 change: 1 addition & 0 deletions examples/jax/collective_gemm/test_layernorm_mlp_grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,7 @@ def setUp(self):
self.args.process_id = self.process_id
self.args.local_device_ids = self.local_device_ids
self.args.num_devices_per_process = self.num_devices_per_process
self.args.use_cublasmp = self.use_cublasmp
self.args.enable_data_parallel = True
self.args.tensor_parallel_size = _get_dp_and_tp_sizes(self.args)[1]
_initialize_distributed(self.args)
Expand Down
Loading