Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 122 additions & 0 deletions tests/pytorch/test_numerics.py
Original file line number Diff line number Diff line change
Expand Up @@ -3078,6 +3078,128 @@ def test_grouped_gemm(shape, dtype, layout, accumulate, use_cutlass):
os.environ.pop("NVTE_USE_CUTLASS_GROUPED_GEMM", None)


@pytest.mark.skipif(
    torch.cuda.get_device_capability() != (9, 0) and not IS_HIP_EXTENSION,
    reason="Only enable CUTLASS/CK grouped gemm on Hopper or ROCm",
)
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16], ids=str)
@pytest.mark.parametrize("layout", ["TN", "NN"])
@pytest.mark.parametrize("accumulate", [False, True])
@pytest.mark.parametrize(
    "pad_dim",
    ["K", "M", "N"],
    ids=lambda d: f"pad{d}",
)
def test_grouped_gemm_unaligned(dtype, layout, accumulate, pad_dim):
    """Test CK grouped GEMM with M, N, or K not aligned to CK tile size.

    CK constraints for bf16/fp16:
    - Contiguous dim of A/B must be dword-aligned (even for 2-byte types).
      RowMajor: contiguous dim is cols (K for A, N for B).
      ColMajor: contiguous dim is rows (M for A, K for B).
    - N: must be multiple of 16 (GetVectorSizeC, no dword fallback), tile 128/256
    - K tile: 64, M tile: 256
    """
    torch.manual_seed(0)
    z = 8  # number of groups

    # Unaligned values per dimension (all satisfy CK vector-load constraints).
    # K: even but not multiple of tile (64). Same for all groups.
    # M: not multiples of tile (256), varies per group.
    # N: multiple of 16 but not multiple of tile (128).
    unaligned_k = 2026
    unaligned_m = [100, 300, 150, 200, 50, 350, 250, 180]
    unaligned_n = 2032

    # Aligned defaults.
    k_aligned = 2048
    m_aligned = 256
    n_aligned = 2048

    # Exactly one GEMM dimension is unaligned per test case; the others
    # stay tile-aligned so each padding path is exercised in isolation.
    k_val = unaligned_k if pad_dim == "K" else k_aligned
    m_vals = unaligned_m if pad_dim == "M" else [m_aligned] * z
    n_val = unaligned_n if pad_dim == "N" else n_aligned

    if layout == "TN":
        # TN GEMM: M=m_splits[i], N=A.rows, K=A.cols
        A = [torch.randn(n_val, k_val, dtype=dtype, device="cuda") for _ in range(z)]
        grad = False
    else:  # NN
        # NN GEMM: M=m_splits[i], N=A.cols, K=A.rows
        A = [torch.randn(k_val, n_val, dtype=dtype, device="cuda") for _ in range(z)]
        grad = True

    # B and the (single) output tensor have the same shapes in both layouts.
    B = [torch.randn(m, k_val, dtype=dtype, device="cuda") for m in m_vals]
    total_m = sum(m_vals)
    out = [torch.randn(total_m, n_val, dtype=dtype, device="cuda")]
    out_ref = [o.clone() for o in torch.split(out[0], m_vals)]
    m_splits = m_vals
    single_output = True

    os.environ["NVTE_USE_CUTLASS_GROUPED_GEMM"] = "1"
    # try/finally guarantees the env var is removed even when an assertion
    # fails, so later tests in the session are not silently forced onto the
    # CUTLASS/CK grouped-GEMM path.
    try:
        # Reference: individual GEMMs, one per group.
        for i in range(z):
            general_gemm(
                A[i],
                B[i],
                dtype,
                grad=grad,
                accumulate=accumulate,
                layout=layout,
                out=out_ref[i],
            )
        if single_output:
            out_ref = [torch.cat(out_ref)]

        general_grouped_gemm(
            A,
            B,
            out,
            [None] * z,
            dtype,
            m_splits=m_splits,
            grad=grad,
            accumulate=accumulate,
            layout=layout,
            single_output=single_output,
        )

        for o, o_ref in zip(out, out_ref):
            # bf16 accumulate on gfx94x needs a looser tolerance.
            if IS_HIP_EXTENSION and accumulate and dtype == torch.bfloat16 and get_device_compute_capability() == (9, 4):
                torch.testing.assert_close(o, o_ref, rtol=4e-2, atol=4e-2)
            else:
                torch.testing.assert_close(o, o_ref, rtol=1.5e-2, atol=1.5e-2)
    finally:
        os.environ.pop("NVTE_USE_CUTLASS_GROUPED_GEMM", None)

@pytest.mark.parametrize("N", [32])
@pytest.mark.parametrize("datatype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,17 @@ static inline bool launch_grouped_gemm_kernel(const DescContainer& descs,

if (!Kernel::IsSupportedArgument(kargs)) {
NVTE_WARN("ck_tile_grouped_gemm: CK_Tile kernel arguments not supported for this config. "
"Falling back.");
"transA=", ctx.transA, " transB=", ctx.transB,
" accumulate=", ctx.accumulate, " groups=", ctx.group_num,
". Falling back. "
"CK_Tile constraints for bf16/fp16: "
"contiguous dim of A and B must be dword-aligned (even), "
"N must be multiple of 16 (GetVectorSizeC).");
for (size_t i = 0; i < descs.size(); ++i) {
NVTE_WARN(" group ", i, ": M=", descs[i].M, " N=", descs[i].N, " K=", descs[i].K,
" stride_A=", descs[i].stride_A, " stride_B=", descs[i].stride_B,
" stride_E=", descs[i].stride_E);
}
return false;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,11 @@ struct TileCfg_256x128x64 : TileCfg_256x256x64 {
static constexpr ck_tile::index_t N_Tile = 128;
};

struct TileCfg_256x128x64_padding : TileCfg_256x128x64 {
static constexpr bool kPadN = true;
// Mixin that derives from a tile configuration `Base` and overrides its
// padding flags. Each kPad* enables the CK-tile padded code path for that
// GEMM dimension (M/N/K) when the problem size is not tile-aligned.
template <typename Base, bool PadM_, bool PadN_, bool PadK_>
struct WithPadding : Base {
static constexpr bool kPadM = PadM_;
static constexpr bool kPadN = PadN_;
static constexpr bool kPadK = PadK_;
};

template <typename AType,
Expand Down Expand Up @@ -196,15 +199,15 @@ class GroupedGemmRunner : public RunnerInterface {
}
};

#define MAKE_RUNNER(TileCfg_) \
#define MAKE_RUNNER(BaseCfg_, kPadM_, kPadN_, kPadK_) \
TRANSFORMER_ENGINE_SWITCH_CONDITION(ctx.accumulate, accum_option, { \
using Runner = GroupedGemmRunner<AType, \
BType, \
CType, \
ALayout, \
BLayout, \
CLayout, \
TileCfg_, \
WithPadding<BaseCfg_, kPadM_, kPadN_, kPadK_>, \
accum_option>; \
runner = std::make_unique<Runner>(); \
})
Expand All @@ -216,6 +219,37 @@ bool ck_tile_grouped_gemm_fp16_dispatch(DType a_dtype,
const ck_tile::stream_config s{ctx.stream};
std::unique_ptr<RunnerInterface> runner = nullptr;

// Check M and K alignment across all groups.
// All tile configs share the same M_Tile (256) and K_Tile (64).
constexpr ck_tile::index_t M_Tile = TileCfg_256x256x64::M_Tile;
constexpr ck_tile::index_t K_Tile = TileCfg_256x256x64::K_Tile;

bool need_m_pad = false;
bool need_k_pad = false;

for (int i = 0; i < ctx.group_num; ++i) {
const transformer_engine::Tensor* A_te =
transformer_engine::convertNVTETensorCheck(ctx.A[i]);
int64_t Ad0 = 0, Ad1 = 0;
if (get_flat_2d_dims(*A_te, Ad0, Ad1)) {
const int64_t M = ctx.transA ? Ad1 : Ad0;
const int64_t K = ctx.transA ? Ad0 : Ad1;

if (M % M_Tile != 0)
need_m_pad = true;
if (K % K_Tile != 0)
need_k_pad = true;
if (need_m_pad && need_k_pad)
break;
}
}

// CK tile kernel produces incorrect results with kPadK + ColMajor B.
// Fall back to cuBLAS for this combination.
if (need_k_pad && ctx.transB) {
return false;
}

TRANSFORMER_ENGINE_SWITCH_CONDITION(ctx.transA, kTransA, {
using ALayout = std::conditional_t<kTransA, ColMajor, RowMajor>;

Expand All @@ -230,13 +264,17 @@ bool ck_tile_grouped_gemm_fp16_dispatch(DType a_dtype,
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(d_dtype, d_te_type, {
using CType = typename TETypeToCKType<d_te_type>::type;

if (ctx.N % 256 == 0) {
MAKE_RUNNER(TileCfg_256x256x64);
} else if (ctx.N % 128 == 0) {
MAKE_RUNNER(TileCfg_256x128x64);
} else {
MAKE_RUNNER(TileCfg_256x128x64_padding);
}
TRANSFORMER_ENGINE_SWITCH_CONDITION(need_m_pad, kPadM, {
TRANSFORMER_ENGINE_SWITCH_CONDITION(need_k_pad, kPadK, {
if (ctx.N % 256 == 0) {
MAKE_RUNNER(TileCfg_256x256x64, kPadM, false, kPadK);
} else if (ctx.N % 128 == 0) {
MAKE_RUNNER(TileCfg_256x128x64, kPadM, false, kPadK);
} else {
MAKE_RUNNER(TileCfg_256x128x64, kPadM, true, kPadK);
}
});
});
});
});
});
Expand Down