Draft
67 commits
8de5bb5
init einsum
phu0ngng Dec 3, 2025
1f02cf4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 3, 2025
bf3ebc2
code drop
pggPL Dec 10, 2025
76293d4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 10, 2025
296d773
Add FP8 scale support and fix alignment for grouped GEMM
pggPL Dec 10, 2025
785df34
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 10, 2025
1329b37
fix
pggPL Dec 10, 2025
47c58be
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 10, 2025
a155a8a
Grouped GEMM: code cleanup and NULL C support
pggPL Dec 11, 2025
3b2fcdf
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 11, 2025
5b0582b
Grouped GEMM: per-matrix alpha/beta support
pggPL Dec 11, 2025
101766b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 11, 2025
1167f75
Fix alpha/beta numel - use SimpleTensor::numel()
pggPL Dec 11, 2025
a5ee92f
Merge branch 'main' into einsum
jberchtold-nvidia Dec 16, 2025
00eb186
Einsum WIP 1
jberchtold-nvidia Dec 17, 2025
38defb8
Test
jberchtold-nvidia Dec 18, 2025
e4a80a3
Refactor: move grouped GEMM to separate file and cleanup API
pggPL Dec 19, 2025
db1e177
Merge branch 'main' into grouped_gemm
pggPL Dec 19, 2025
047a9f9
fix
pggPL Dec 19, 2025
c490e06
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 19, 2025
e397845
batching working correctly for quant and gemm but slow
jberchtold-nvidia Dec 19, 2025
59145cc
fix
pggPL Dec 22, 2025
77b422a
Require Blackwell (SM100) and cuBLAS 13.1+ for grouped GEMM
pggPL Dec 22, 2025
9c8158e
fix
pggPL Dec 22, 2025
b1e0893
fix
jberchtold-nvidia Dec 22, 2025
f70f376
Merge remote-tracking branch 'github-upstream/main' into einsum
jberchtold-nvidia Dec 23, 2025
fb2067b
move einsum logic into TE
jberchtold-nvidia Dec 23, 2025
30716a6
einsum unit tests
jberchtold-nvidia Dec 23, 2025
349c315
fwd bwd einsum test
jberchtold-nvidia Dec 23, 2025
57ab3b0
unit tests passed with grouped gemm in bf16
jberchtold-nvidia Dec 23, 2025
ab98852
grouped quantization working for single gpu
jberchtold-nvidia Dec 23, 2025
1184796
Merge remote-tracking branch 'pawel/grouped_gemm' into einsum
jberchtold-nvidia Dec 23, 2025
ed540c8
fixes
pggPL Dec 30, 2025
359a9f5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 30, 2025
a702426
fixes
pggPL Dec 30, 2025
fb027d0
fix
pggPL Dec 30, 2025
ae85415
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 30, 2025
f1fc31c
wip
jberchtold-nvidia Jan 5, 2026
43f7e60
Update transformer_engine/common/gemm/config.h
pggPL Jan 7, 2026
30468af
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 7, 2026
2ccaee5
changed
pggPL Jan 7, 2026
bd8fa30
suggestions
pggPL Jan 7, 2026
f0df80e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 7, 2026
301874d
fix
pggPL Jan 7, 2026
c8cf763
with many hacks grouped gemm with new api works for a particular hard…
jberchtold-nvidia Jan 7, 2026
21e7002
progress
jberchtold-nvidia Jan 7, 2026
1ae08dd
more tests pass
jberchtold-nvidia Jan 7, 2026
fe39e39
einsum tests pass
jberchtold-nvidia Jan 7, 2026
5e47d57
more progress, works in maxtext single-gpu and is closer to bf16 batc…
jberchtold-nvidia Jan 8, 2026
bc6cf66
attempt at passing thru stateful args for DS
jberchtold-nvidia Jan 8, 2026
bcbe864
Revert "attempt at passing thru stateful args for DS"
jberchtold-nvidia Jan 8, 2026
b40353f
batch gemm specialization for CS amax calc
jberchtold-nvidia Jan 8, 2026
6c5d969
fix
pggPL Jan 9, 2026
c91cd8f
fix
pggPL Jan 9, 2026
0319e79
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 9, 2026
a14d5bc
refactored hopper tensor selection
pggPL Jan 13, 2026
ee8f3ef
Merge remote-tracking branch 'origin/grouped_gemm' into grouped_gemm
pggPL Jan 13, 2026
c5c2fbf
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 13, 2026
ee71c96
multi-GPU grouped quantize working now in shard_map (with hack to use…
jberchtold-nvidia Jan 15, 2026
9856862
reduce size of zero'ing memset to only uninitialized part of quantiza…
jberchtold-nvidia Jan 15, 2026
f2ada5a
Merge remote-tracking branch 'pawel/grouped_gemm' into einsum
jberchtold-nvidia Jan 15, 2026
23b5de3
fix TE/JAX to work compile with latest nvte_grouped_gemm API changes
jberchtold-nvidia Jan 15, 2026
179aab6
some tests starting to work
jberchtold-nvidia Jan 21, 2026
6a54ff8
wip
jberchtold-nvidia Jan 21, 2026
8c86a86
wip
jberchtold-nvidia Jan 21, 2026
d8247da
wip
jberchtold-nvidia Jan 21, 2026
d799a29
wip
jberchtold-nvidia Jan 22, 2026
74 changes: 74 additions & 0 deletions test_einsum.py
@@ -0,0 +1,74 @@
from enum import Enum

import jax
import jax.numpy as jnp
import numpy as np
import transformer_engine.jax as te
from transformer_engine.common.recipe import Recipe, Float8CurrentScaling, MXFP8BlockScaling, DelayedScaling, NVFP4BlockScaling
from flax import linen as nn

def make_einsum_cls(quantization_recipe):
    def te_einsum(generate_quantizer_set, s, x, kernel, **kwargs):
        def dot_general(x, kernel, dims, *args, **kwargs):
            contracting_dims, batch_dims = dims
            assert batch_dims == ((), ()), "Batch dims not supported in TE/JAX yet"

            quantizer_set = generate_quantizer_set("quantizer_set_for_einsum")
            return te.dense.dense(
                x,
                kernel,
                contracting_dims=contracting_dims,
                quantizer_set=quantizer_set,
            )
        return jnp.einsum(s, x, kernel, _dot_general=dot_general, **kwargs)

    return te.flax.wrap_function_in_te_state_module(te_einsum, quantization_recipe, "einsum")()
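# Illustrative note (not part of the test): for the "ij,jk->ik" contraction used below,
# jnp.einsum lowers to a single dot_general call whose `dims` argument is
# (((1,), (0,)), ((), ())), i.e. contract axis 1 of x with axis 0 of kernel and no batch
# axes, which is exactly the (contracting_dims, batch_dims) pair unpacked in dot_general above.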

class EinsumType(Enum):
    JAX = 'jax'
    TE = 'te'

def main():

    class SimpleModel(nn.Module):

        einsum_type: EinsumType
        quantization_recipe: Recipe = None

        def _einsum(self, *args, **kwargs):
            if self.einsum_type == EinsumType.JAX:
                return jnp.einsum(*args, **kwargs)
            elif self.einsum_type == EinsumType.TE:
                # It is important that we call make_einsum_cls(recipe) here each time
                # einsum is called. If we were to call make_einsum_cls only once and
                # re-use it, the state for some recipes such as DelayedScaling would
                # become incorrectly shared instead of each call having its own state.
                return make_einsum_cls(self.quantization_recipe)(*args, **kwargs)
            else:
                raise ValueError(f"Unsupported einsum type: {self.einsum_type}")

        @nn.compact
        def __call__(self, x):
            kernel = self.param('kernel', jax.nn.initializers.lecun_normal(), (32, 32), jnp.bfloat16)
            return self._einsum("ij,jk->ik", x, kernel)


    def test_model(einsum_type: EinsumType, quantization_recipe: Recipe = None):
        model = SimpleModel(einsum_type=einsum_type, quantization_recipe=quantization_recipe)
        x = jax.random.uniform(jax.random.PRNGKey(2), (32, 32), jnp.bfloat16)
        var_collect = model.init(jax.random.PRNGKey(3), x)
        # It is important to use var_collect here to ensure all state (e.g., quantizer
        # states) is properly handled. If you use var_collect['params'] only, TE's state
        # management will not work correctly for recipes that require state
        # (e.g. DelayedScaling).
        y = model.apply(var_collect, x)
        return y

    # einsum_type=EinsumType.JAX, so standard JAX einsum is used as the reference
    ref_out = test_model(einsum_type=EinsumType.JAX)

    # einsum using Transformer Engine's Float8CurrentScaling recipe
    te_out = test_model(einsum_type=EinsumType.TE, quantization_recipe=Float8CurrentScaling())

    # Compare outputs
    atol = float(jnp.finfo(jnp.float8_e4m3fn).eps)
    np.testing.assert_allclose(ref_out, te_out, atol=atol)
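    # Illustrative note: jnp.finfo(jnp.float8_e4m3fn).eps is 2**-3 = 0.125, so the
    # assertion above allows FP8-e4m3-sized element-wise error between the TE
    # (Float8CurrentScaling) output and the pure-JAX bf16 reference.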


if __name__ == "__main__":
    main()
1 change: 1 addition & 0 deletions tests/cpp/operator/CMakeLists.txt
@@ -30,6 +30,7 @@ add_executable(test_operator
test_causal_softmax.cu
test_swizzle.cu
test_swap_first_dims.cu
test_grouped_gemm.cu
../test_common.cu)

# Find required packages
308 changes: 308 additions & 0 deletions tests/cpp/operator/test_grouped_gemm.cu
@@ -0,0 +1,308 @@
/*************************************************************************
 * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include <cublasLt.h>
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <gtest/gtest.h>

#include <algorithm>
#include <memory>
#include <numeric>
#include <optional>
#include <random>
#include <tuple>
#include <vector>

#include <transformer_engine/cast.h>
#include <transformer_engine/gemm.h>
#include <transformer_engine/recipe.h>
#include <transformer_engine/transformer_engine.h>

#include "../test_common.h"

using namespace transformer_engine;
using namespace test;

namespace {

enum class InputCase {
  kFP8Current,
  kBF16,
};

enum class ShapeCase {
  kAllSame,
  kSameFirst,
  kSameLast,
  kAllDifferent,
};

size_t grouped_setup_workspace_size(const size_t num_tensors) {
  const size_t ptr_bytes = num_tensors * sizeof(void*);
  const size_t int_bytes = num_tensors * sizeof(int);
  // Layout: 6 pointer arrays (A, B, C, D, alpha, beta) +
  // 6 int arrays (a_rows, a_cols, b_rows, b_cols, d_rows, d_cols)
  size_t size = 6 * ptr_bytes + 6 * int_bytes;
  const size_t alignment = 256;
  size = ((size + alignment - 1) / alignment) * alignment;
  return size;
}
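// Worked example for grouped_setup_workspace_size above (illustrative): for
// num_tensors = 3 on a 64-bit build, ptr_bytes = 3 * 8 = 24 and int_bytes = 3 * 4 = 12,
// so the raw size is 6 * 24 + 6 * 12 = 216 bytes, which the 256-byte alignment rounds
// up to 256 bytes of setup workspace.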

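// Builds an FP8 (e4m3) operand via the current-scaling flow used by the kFP8Current cases:
// fill an FP32 buffer with random data, compute its amax, derive the FP8 scale from that
// amax, then quantize into a delayed-tensor-scaling FP8 tensor.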
Tensor make_fp8_operand(const std::string& name, const std::vector<size_t>& shape) {
  Tensor input_fp32(name + "_fp32", shape, DType::kFloat32);
  fillUniform(&input_fp32);

  Tensor fp8(name, shape, TypeInfo<fp8e4m3>::dtype, true, true, NVTE_DELAYED_TENSOR_SCALING);

  nvte_compute_amax(input_fp32.data(), fp8.data(), 0);
  QuantizationConfigWrapper config;
  nvte_compute_scale_from_amax(fp8.data(), config, 0);
  nvte_quantize(input_fp32.data(), fp8.data(), 0);
  return fp8;
}

Tensor make_bf16_operand(const std::string& name, const std::vector<size_t>& shape) {
  Tensor t(name, shape, DType::kBFloat16);
  const size_t numel = shape[0] * shape[1];
  std::vector<__nv_bfloat16> ones(numel, __float2bfloat16(1.0f));
  NVTE_CHECK_CUDA(cudaMemcpy(t.rowwise_dptr(), ones.data(),
                             numel * sizeof(__nv_bfloat16), cudaMemcpyHostToDevice));
  return t;
}

struct TestParams {
  InputCase input_case;
  bool transa;
  bool transb;
  ShapeCase shape_case;
  bool use_null_c = false;  // When true, pass nullptr for C (valid when beta=0)
};

// Returns a vector of (M, N, K) tuples for each GEMM in the group.
// M - number of rows in output D
// N - number of columns in output D
// K - reduction dimension shared between A and B
std::vector<std::tuple<size_t, size_t, size_t>> make_shapes(ShapeCase scase) {
  switch (scase) {
    case ShapeCase::kAllSame:
      return {{64, 64, 32}, {64, 64, 32}, {64, 64, 32}};
    case ShapeCase::kSameFirst:
      // Same M (first dim), varying N and K
      return {{64, 80, 32}, {64, 96, 48}, {64, 112, 64}};
    case ShapeCase::kSameLast:
      // Same N (last dim), varying M and K
      return {{64, 80, 32}, {80, 80, 48}, {96, 80, 64}};
    case ShapeCase::kAllDifferent:
    default:
      return {{64, 96, 32}, {80, 112, 48}, {96, 128, 64}};
  }
}
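// Illustrative example: with (M, N, K) = (64, 96, 32) from kAllDifferent, transa = true and
// transb = false, run_grouped_gemm_case below allocates A with shape {M, K} = {64, 32} and
// B with shape {N, K} = {96, 32}; with both flags false it would use A = {K, M} = {32, 64}
// and B = {N, K} = {96, 32}.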

void run_grouped_gemm_case(const TestParams& params) {
#if CUBLAS_VERSION < 130100
  GTEST_SKIP() << "Grouped GEMM requires cuBLAS 13.1+, but compile-time cuBLAS version is "
               << CUBLAS_VERSION << ".";
#else
  if (getDeviceComputeCapability() < blackwellComputeCapability) {
    GTEST_SKIP() << "Grouped GEMM requires Blackwell (SM100) or newer.";
  }

  const std::vector<std::tuple<size_t, size_t, size_t>> shapes = make_shapes(params.shape_case);

  const size_t num_gemms = shapes.size();
  std::vector<Tensor> A_tensors;
  std::vector<Tensor> B_tensors;
  std::vector<Tensor> D_multi;

  A_tensors.reserve(num_gemms);
  B_tensors.reserve(num_gemms);
  D_multi.reserve(num_gemms);

  for (size_t i = 0; i < num_gemms; ++i) {
    const auto [M, N, K] = shapes[i];
    const std::vector<size_t> a_shape = params.transa ? std::vector<size_t>{M, K}
                                                      : std::vector<size_t>{K, M};
    const std::vector<size_t> b_shape = params.transb ? std::vector<size_t>{K, N}
                                                      : std::vector<size_t>{N, K};
    switch (params.input_case) {
      case InputCase::kFP8Current: {
        A_tensors.emplace_back(make_fp8_operand("A" + std::to_string(i), a_shape));
        B_tensors.emplace_back(make_fp8_operand("B" + std::to_string(i), b_shape));
        break;
      }
      case InputCase::kBF16: {
        A_tensors.emplace_back(make_bf16_operand("A" + std::to_string(i), a_shape));
        B_tensors.emplace_back(make_bf16_operand("B" + std::to_string(i), b_shape));
        break;
      }
    }
    D_multi.emplace_back(Tensor("D_multi" + std::to_string(i),
                                std::vector<size_t>{M, N},
                                DType::kBFloat16));
  }

  std::vector<NVTETensor> A_ptrs(num_gemms);
  std::vector<NVTETensor> B_ptrs(num_gemms);
  std::vector<NVTETensor> D_ptrs(num_gemms);
  std::vector<Tensor> workspaces(num_gemms);
  std::vector<NVTETensor> workspace_ptrs(num_gemms, nullptr);
  std::vector<Tensor*> A_views;
  std::vector<Tensor*> B_views;
  A_views.reserve(num_gemms);
  B_views.reserve(num_gemms);

  // Empty bias/gelu arrays for nvte_multi_tensor_gemm (no epilogues)
  std::vector<NVTETensor> bias_ptrs(num_gemms, nullptr);
  std::vector<NVTETensor> gelu_ptrs(num_gemms, nullptr);

  const size_t cublas_ws_bytes = 32ull * 1024 * 1024;

  for (size_t i = 0; i < num_gemms; ++i) {
    A_ptrs[i] = A_tensors[i].data();
    B_ptrs[i] = B_tensors[i].data();
    D_ptrs[i] = D_multi[i].data();
    workspaces[i] = Tensor("workspace" + std::to_string(i), std::vector<size_t>{cublas_ws_bytes}, DType::kByte);
    workspace_ptrs[i] = workspaces[i].data();
    A_views.push_back(&A_tensors[i]);
    B_views.push_back(&B_tensors[i]);
  }

  nvte_multi_tensor_gemm(A_ptrs.data(),
                         B_ptrs.data(),
                         D_ptrs.data(),
                         bias_ptrs.data(),
                         gelu_ptrs.data(),
                         static_cast<int>(num_gemms),
                         params.transa,
                         params.transb,
                         false,  // grad
                         workspace_ptrs.data(),
                         false,  // accumulate
                         false,  // use_split_accumulator
                         0,      // sm_count
                         0);
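
  // Reference path: nvte_multi_tensor_gemm above runs one independent GEMM per (A_i, B_i)
  // pair and writes the results into D_multi[i]; the grouped-GEMM path below is validated
  // by reproducing these outputs within the usual bf16 tolerances.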

  GroupedBuffers grouped_A = build_grouped_tensor(A_views, A_tensors[0].scaling_mode());
  GroupedBuffers grouped_B = build_grouped_tensor(B_views, B_tensors[0].scaling_mode());

  std::vector<Tensor> C_tensors;
  std::vector<Tensor> D_group_tensors;
  C_tensors.reserve(num_gemms);
  D_group_tensors.reserve(num_gemms);
  for (size_t i = 0; i < num_gemms; ++i) {
    const auto [M, N, K] = shapes[i];
    (void)K;
    if (!params.use_null_c) {
      C_tensors.emplace_back(Tensor("C" + std::to_string(i),
                                    std::vector<size_t>{static_cast<size_t>(M), static_cast<size_t>(N)},
                                    DType::kBFloat16));
    }
    D_group_tensors.emplace_back(Tensor("D_group" + std::to_string(i),
                                        std::vector<size_t>{static_cast<size_t>(M), static_cast<size_t>(N)},
                                        DType::kBFloat16));
    NVTE_CHECK_CUDA(cudaMemset(D_group_tensors.back().rowwise_dptr(), 0,
                               bytes(D_group_tensors.back().rowwise_shape(), D_group_tensors.back().dtype())));
  }

  std::vector<Tensor*> C_views, D_views;
  for (size_t i = 0; i < num_gemms; ++i) {
    if (!params.use_null_c) {
      C_views.push_back(&C_tensors[i]);
    }
    D_views.push_back(&D_group_tensors[i]);
  }

  std::optional<GroupedBuffers> grouped_C;
  if (!params.use_null_c) {
    grouped_C = build_grouped_tensor(C_views, NVTE_DELAYED_TENSOR_SCALING);
  }
  GroupedBuffers grouped_D = build_grouped_tensor(D_views, NVTE_DELAYED_TENSOR_SCALING);

  // Per-matrix alpha/beta (all 1.0 and 0.0 respectively)
  Tensor alpha_tensor("alpha", std::vector<size_t>{num_gemms}, DType::kFloat32);
  Tensor beta_tensor("beta", std::vector<size_t>{num_gemms}, DType::kFloat32);
  std::vector<float> alpha_vals(num_gemms, 1.f);
  std::vector<float> beta_vals(num_gemms, 0.f);
  NVTE_CHECK_CUDA(cudaMemcpy(alpha_tensor.rowwise_dptr(), alpha_vals.data(),
                             num_gemms * sizeof(float), cudaMemcpyHostToDevice));
  NVTE_CHECK_CUDA(cudaMemcpy(beta_tensor.rowwise_dptr(), beta_vals.data(),
                             num_gemms * sizeof(float), cudaMemcpyHostToDevice));

  const size_t setup_ws_bytes = grouped_setup_workspace_size(num_gemms);
  Tensor setup_ws("setup_ws", std::vector<size_t>{setup_ws_bytes}, DType::kByte);
  Tensor cublas_ws("cublas_ws", std::vector<size_t>{cublas_ws_bytes}, DType::kByte);

  nvte_grouped_gemm(grouped_A.get_handle(),
                    params.transa,
                    grouped_B.get_handle(),
                    params.transb,
                    params.use_null_c ? nullptr : grouped_C->get_handle(),
                    grouped_D.get_handle(),
                    alpha_tensor.data(),
                    beta_tensor.data(),
                    setup_ws.data(),
                    cublas_ws.data(),
                    nullptr,  // config (use defaults)
                    0);
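
  // Grouped path: a single nvte_grouped_gemm call over the packed A/B/(optional C)/D
  // buffers with per-matrix alpha = 1 and beta = 0, i.e. D_i = op(A_i) * op(B_i); the zero
  // beta is also what makes the NULL-C variant of the test valid. The packed D is split
  // back into per-GEMM tensors below using the recorded offsets and compared against D_multi.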

  for (size_t i = 0; i < num_gemms; ++i) {
    Tensor grouped_split("grouped_D" + std::to_string(i),
                         std::vector<size_t>{static_cast<size_t>(std::get<0>(shapes[i])),
                                             static_cast<size_t>(std::get<1>(shapes[i]))},
                         D_multi[i].dtype());
    const size_t offset_bytes = static_cast<size_t>(grouped_D.offsets_host[i]) * grouped_D.elem_size;
    NVTE_CHECK_CUDA(cudaMemcpy(grouped_split.rowwise_dptr(),
                               static_cast<char*>(grouped_D.get_data()) + offset_bytes,
                               grouped_D.tensor_bytes[i],
                               cudaMemcpyDeviceToDevice));
    grouped_split.to_cpu();
    D_multi[i].to_cpu();
    auto [atol, rtol] = getTolerances(D_multi[i].dtype());
    compareResults("grouped_vs_multi",
                   grouped_split,
                   D_multi[i].rowwise_cpu_dptr<bf16>(),
                   true,
                   atol,
                   rtol);
  }
#endif  // CUBLAS_VERSION >= 130100
}

class GroupedGemmTest : public ::testing::TestWithParam<TestParams> {};

TEST_P(GroupedGemmTest, CompareWithMultiTensorGemm) {
  run_grouped_gemm_case(GetParam());
}

std::string MakeGroupedGemmTestName(const testing::TestParamInfo<GroupedGemmTest::ParamType>& info) {
  constexpr const char* kInputNames[] = {"FP8Current", "BF16"};
  constexpr const char* kShapeNames[] = {"AllSame", "SameM", "SameN", "AllDiff"};
  const std::string layout = std::string("ta") + (info.param.transa ? "T" : "N") +
                             "tb" + (info.param.transb ? "T" : "N");
  const std::string null_c = info.param.use_null_c ? "_NullC" : "";
  return std::string(kInputNames[static_cast<int>(info.param.input_case)]) + "_" +
         kShapeNames[static_cast<int>(info.param.shape_case)] + "_" + layout + null_c;
}
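// Example (illustrative): the first entry of kTestParams below,
// {kFP8Current, transa = true, transb = false, kAllDifferent, use_null_c = false},
// is named "FP8Current_AllDiff_taTtbN" by MakeGroupedGemmTestName above.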

// TestParams: {input_case, transa, transb, shape_case, use_null_c}
const std::vector<TestParams> kTestParams = {
    // Basic tests
    {InputCase::kFP8Current, true, false, ShapeCase::kAllDifferent, false},
    {InputCase::kFP8Current, false, true, ShapeCase::kAllDifferent, false},
    {InputCase::kFP8Current, false, false, ShapeCase::kAllSame, false},
    {InputCase::kBF16, true, false, ShapeCase::kSameFirst, false},
    {InputCase::kBF16, false, true, ShapeCase::kSameLast, false},
    {InputCase::kBF16, false, false, ShapeCase::kAllSame, false},
    {InputCase::kBF16, true, true, ShapeCase::kAllDifferent, false},
    // Test NULL C (valid when beta=0)
    {InputCase::kBF16, false, false, ShapeCase::kAllSame, true},
};

INSTANTIATE_TEST_SUITE_P(OperatorTest,
                         GroupedGemmTest,
                         ::testing::ValuesIn(kTestParams),
                         MakeGroupedGemmTestName);
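
// Usage note (illustrative, exact path depends on the build directory): with the
// test_operator target from the CMakeLists change above, these cases can be run via, e.g.,
//   ./test_operator --gtest_filter='OperatorTest/GroupedGemmTest.CompareWithMultiTensorGemm/*'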

} // namespace