Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 15 additions & 3 deletions include/flydsl/Dialect/FlyROCDL/IR/MmaAtom.td
Original file line number Diff line number Diff line change
Expand Up @@ -102,13 +102,25 @@ def FlyROCDL_MmaOpGFX11_WMMA : FlyROCDL_MmaOp<"MmaOpGFX11_WMMA", "gfx11.wmma", [
"int32_t":$k,
"Type":$elemTyA,
"Type":$elemTyB,
"Type":$elemTyAcc
"Type":$elemTyAcc,
// Integer-WMMA controls, forwarded to the ROCDL iu8/iu4 intrinsic. Ignored
// (and required to be false) on fp16/bf16 paths — those intrinsics have no
// such operands. Always-printed in the assembly format for clarity.
"bool":$signA,
"bool":$signB,
"bool":$clamp
);
let assemblyFormat = "`<` custom<MNKDimensionList>($m, $n, $k) `,` `(` $elemTyA `,` $elemTyB `)` `->` $elemTyAcc `>`";
let assemblyFormat = "`<` custom<MNKDimensionList>($m, $n, $k) `,` `(` $elemTyA `,` $elemTyB `)` `->` $elemTyAcc `,` `signA` `=` $signA `,` `signB` `=` $signB `,` `clamp` `=` $clamp `>`";

let builders = [
// Legacy 6-arg builder: defaults signA/signB/clamp=false. Backward-compat
// for fp16/bf16 callers; integer callers should use the 9-arg form below.
TypeBuilderWithInferredContext<(ins "int32_t":$m, "int32_t":$n, "int32_t":$k, "Type":$elemTyA, "Type":$elemTyB, "Type":$elemTyAcc), [{
return $_get(elemTyA.getContext(), m, n, k, elemTyA, elemTyB, elemTyAcc);
return $_get(elemTyA.getContext(), m, n, k, elemTyA, elemTyB, elemTyAcc, false, false, false);
}]>,
// Explicit 9-arg builder for integer-WMMA callers.
TypeBuilderWithInferredContext<(ins "int32_t":$m, "int32_t":$n, "int32_t":$k, "Type":$elemTyA, "Type":$elemTyB, "Type":$elemTyAcc, "bool":$signA, "bool":$signB, "bool":$clamp), [{
return $_get(elemTyA.getContext(), m, n, k, elemTyA, elemTyB, elemTyAcc, signA, signB, clamp);
}]>
];
let genVerifyDecl = 1;
Expand Down
15 changes: 9 additions & 6 deletions lib/Bindings/Python/FlyROCDLExtension.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,15 +85,18 @@ struct PyMmaOpGFX11_WMMAType : PyConcreteType<PyMmaOpGFX11_WMMAType> {
c.def_static(
"get",
[](int32_t m, int32_t n, int32_t k, PyType &elemTyA, PyType &elemTyB, PyType &elemTyAcc,
DefaultingPyMlirContext context) {
return PyMmaOpGFX11_WMMAType(context->getRef(), wrap(MmaOpGFX11_WMMAType::get(
m, n, k, unwrap(elemTyA),
unwrap(elemTyB), unwrap(elemTyAcc))));
bool signA, bool signB, bool clamp, DefaultingPyMlirContext context) {
return PyMmaOpGFX11_WMMAType(
context->getRef(),
wrap(MmaOpGFX11_WMMAType::get(m, n, k, unwrap(elemTyA), unwrap(elemTyB),
unwrap(elemTyAcc), signA, signB, clamp)));
},
"m"_a, "n"_a, "k"_a, "elem_ty_a"_a, "elem_ty_b"_a, "elem_ty_acc"_a, nb::kw_only(),
"context"_a = nb::none(),
"sign_a"_a = false, "sign_b"_a = false, "clamp"_a = false, "context"_a = nb::none(),
"Create a MmaOpGFX11_WMMAType with m, n, k dimensions and element types "
"(RDNA3 / RDNA3.5 wave32 WMMA, v16 operand ABI)");
"(RDNA3 / RDNA3.5 wave32 WMMA, v16 operand ABI). "
"sign_a/sign_b/clamp are forwarded to the iu8/iu4 intrinsic for integer "
"paths; must be false for fp16/bf16.");
}
};

Expand Down
83 changes: 34 additions & 49 deletions lib/Dialect/FlyROCDL/GFX11/MmaAtom.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,54 +117,42 @@ Attribute MmaOpGFX11_WMMAType::getThrValLayoutC() const {

LogicalResult MmaOpGFX11_WMMAType::verify(function_ref<InFlightDiagnostic()> emitError, int32_t m,
int32_t n, int32_t k, Type elemTyA, Type elemTyB,
Type elemTyAcc) {
Type elemTyAcc, bool signA, bool signB, bool clamp) {
if (m != 16 || n != 16 || k != 16) {
return emitError() << "GFX11 WMMA requires M=N=K=16, got " << m << "x" << n << "x" << k;
}

bool valid = false;

// fp16/bf16 inputs, f32 accumulator. (16-bit accumulator variants exist on
// RDNA3 but require VGPR-pair packing/expansion around OPSEL; not yet
// implemented here.)
if (elemTyA.isF16() && elemTyB.isF16() && elemTyAcc.isF32())
valid = true;
if (elemTyA.isBF16() && elemTyB.isBF16() && elemTyAcc.isF32())
valid = true;

// Integer inputs: REQUIRE explicit unsigned signedness (ui8/ui4).
//
// The atom contract is unsigned-only because emitAtomCallSSA invokes the
// ROCDL iu8/iu4 intrinsics with signA=signB=false (unsigned interpretation
// of the packed operands).
auto isUI = [](Type t, unsigned width) {
// Determine which path this is. fp16/bf16 inputs go to the f32-accumulator
// intrinsics, which have no sign/clamp operands. iu8/iu4 inputs go to the
// i32-accumulator intrinsics, which take all three.
const bool isFp = (elemTyA.isF16() && elemTyB.isF16() && elemTyAcc.isF32()) ||
(elemTyA.isBF16() && elemTyB.isBF16() && elemTyAcc.isF32());

// For integer paths, accept any IntegerType width 8 or 4 regardless of
// signedness (signless/si/ui). The caller controls how the input bits are
// interpreted via signA/signB on the intrinsic.
auto isInt = [](Type t, unsigned width) {
auto it = dyn_cast<IntegerType>(t);
return it && it.getWidth() == width && it.isUnsigned();
return it && it.getWidth() == width;
};
const bool isI8x8 = isInt(elemTyA, 8) && isInt(elemTyB, 8) && elemTyAcc.isInteger(32);
const bool isI4x4 = isInt(elemTyA, 4) && isInt(elemTyB, 4) && elemTyAcc.isInteger(32);
const bool isInt8or4 = isI8x8 || isI4x4;

if (isUI(elemTyA, 8) && isUI(elemTyB, 8) && elemTyAcc.isInteger(32))
valid = true;
if (isUI(elemTyA, 4) && isUI(elemTyB, 4) && elemTyAcc.isInteger(32))
valid = true;

if (!valid) {
// Steer the caller to ui8/ui4 explicitly.
auto looksLikeInt = [](Type t, unsigned w) {
auto it = dyn_cast<IntegerType>(t);
return it && it.getWidth() == w;
};
if ((looksLikeInt(elemTyA, 8) || looksLikeInt(elemTyA, 4)) && elemTyAcc.isInteger(32)) {
return emitError() << "GFX11 WMMA integer inputs must be unsigned "
"(ui8/ui4); got A="
<< elemTyA << ", B=" << elemTyB
<< ". The lowered ROCDL iu8/iu4 intrinsic is invoked "
"with signA=signB=false, so signless/signed "
"operands would silently get unsigned semantics. "
"Signed-integer WMMA is not yet implemented.";
}
if (!isFp && !isInt8or4) {
return emitError() << "unsupported GFX11 WMMA configuration: " << m << "x" << n << "x" << k
<< " with A=" << elemTyA << ", B=" << elemTyB << ", Acc=" << elemTyAcc;
}

// fp16/bf16 intrinsics do not have signA/signB/clamp operands. Refuse to
// construct an atom that promises something the codegen cannot deliver.
if (isFp && (signA || signB || clamp)) {
return emitError() << "GFX11 WMMA fp16/bf16 path does not accept signA/signB/clamp "
"(the ROCDL fp WMMA intrinsics have no such operands); "
"got signA="
<< signA << ", signB=" << signB << ", clamp=" << clamp;
}

return success();
}

Expand Down Expand Up @@ -247,7 +235,6 @@ FailureOr<Value> MmaOpGFX11_WMMAType::emitAtomCallSSA(OpBuilder &builder, Locati
StringRef opName;
SmallVector<NamedAttribute, 3> attrs;
SmallVector<Value, 3> operands;
BoolAttr falseAttr = builder.getBoolAttr(false);

if (elemTyA.isF16() && elemTyB.isF16() && elemTyAcc.isF32()) {
opName = ROCDL::wmma_f32_16x16x16_f16::getOperationName();
Expand All @@ -256,21 +243,19 @@ FailureOr<Value> MmaOpGFX11_WMMAType::emitAtomCallSSA(OpBuilder &builder, Locati
opName = ROCDL::wmma_f32_16x16x16_bf16::getOperationName();
operands = {a, b, c};
} else if (elemTyA.isInteger(8) && elemTyB.isInteger(8) && elemTyAcc.isInteger(32)) {
// Unsigned-only by contract (see verify()). signA=signB=false matches the
// ui8 element type enforced there. clamp=false preserves wraparound on the
// i32 accumulator.
// Integer paths: signA/signB/clamp come from the type parameters so the
// caller controls whether each operand is interpreted as signed.
opName = ROCDL::wmma_i32_16x16x16_iu8::getOperationName();
operands = {a, b, c};
attrs.push_back({builder.getStringAttr("signA"), falseAttr});
attrs.push_back({builder.getStringAttr("signB"), falseAttr});
attrs.push_back({builder.getStringAttr("clamp"), falseAttr});
attrs.push_back({builder.getStringAttr("signA"), builder.getBoolAttr(getSignA())});
attrs.push_back({builder.getStringAttr("signB"), builder.getBoolAttr(getSignB())});
attrs.push_back({builder.getStringAttr("clamp"), builder.getBoolAttr(getClamp())});
} else if (elemTyA.isInteger(4) && elemTyB.isInteger(4) && elemTyAcc.isInteger(32)) {
// Same unsigned-only contract as iu8; see verify().
opName = ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
operands = {a, b, c};
attrs.push_back({builder.getStringAttr("signA"), falseAttr});
attrs.push_back({builder.getStringAttr("signB"), falseAttr});
attrs.push_back({builder.getStringAttr("clamp"), falseAttr});
attrs.push_back({builder.getStringAttr("signA"), builder.getBoolAttr(getSignA())});
attrs.push_back({builder.getStringAttr("signB"), builder.getBoolAttr(getSignB())});
attrs.push_back({builder.getStringAttr("clamp"), builder.getBoolAttr(getClamp())});
} else {
return failure();
}
Expand Down
15 changes: 13 additions & 2 deletions python/flydsl/expr/rocdl/universal.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,16 @@ def MFMA(m, n, k, elem_ty_ab, elem_ty_acc=None):
return MmaOpCDNA3_MFMAType.get(m, n, k, ty_ab, ty_ab, ty_acc)


def WMMA(m, n, k, elem_ty_ab, elem_ty_acc=None):
def WMMA(m, n, k, elem_ty_ab, elem_ty_acc=None, *, sign_a=False, sign_b=False, clamp=False):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it be better to pack the new arguments into **kwargs? That would leave argument space for other new WMMA ops.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do not think moving those params into **kwargs is a good idea, since they control the numerical correctness of the intrinsic. A silently accepting **kwargs would mask typos in the flag names. I'd rather refactor them into an arch-specific dataclass as soon as we find the signature too polluted after adding a new architecture.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we later split it into a specific op class, can we still keep existing WMMA kernels backward compatible?
Would adding some comments be enough to make programmers notice that? :D

"""Create an arch-appropriate WMMA atom.

sign_a / sign_b / clamp are only meaningful on the gfx11 integer paths
(iu8 / iu4). They are forwarded to the ROCDL intrinsic so callers can pick
signed vs unsigned interpretation per-operand and request saturation. On
fp16/bf16 paths they must remain False (the fp intrinsics have no such
operands; verify() will reject otherwise). The gfx12 (RDNA4) path does not
expose these knobs yet.
"""
ty_ab = elem_ty_ab.ir_type if hasattr(elem_ty_ab, "ir_type") else elem_ty_ab
if elem_ty_acc is None:
ty_acc = ir.F32Type.get()
Expand All @@ -90,8 +99,10 @@ def WMMA(m, n, k, elem_ty_ab, elem_ty_acc=None):

arch = (get_rocm_arch() or "").lower()
if arch.startswith("gfx11"):
return MmaOpGFX11_WMMAType.get(m, n, k, ty_ab, ty_ab, ty_acc)
return MmaOpGFX11_WMMAType.get(m, n, k, ty_ab, ty_ab, ty_acc, sign_a=sign_a, sign_b=sign_b, clamp=clamp)
if arch.startswith("gfx12"):
if sign_a or sign_b or clamp:
raise ValueError("sign_a/sign_b/clamp are not supported on the gfx12 (RDNA4) WMMA path yet")
return MmaOpGFX1250_WMMAType.get(m, n, k, ty_ab, ty_ab, ty_acc)
raise ValueError(
f"WMMA is not available on target arch {arch!r}; " "supported: gfx11xx (RDNA3 / RDNA3.5) and gfx12xx (RDNA4). "
Expand Down
70 changes: 54 additions & 16 deletions tests/mlir/Conversion/wmma_gfx11.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@
// Wave32 fragment shapes per lane (16x16x16 bf16 -> f32):
// A, B : 16 bf16 elements (lowered to vector<16xi16> for the intrinsic)
// C, D : 8 f32 accumulator slots (vector<8xf32>)
//
// The atom type carries three always-printed bool parameters
// (signA, signB, clamp) — forwarded to the ROCDL iu8/iu4 intrinsic on
// integer paths, and required to be false on fp16/bf16 paths.

// CHECK-LABEL: @test_gfx11_wmma_atom_call_bf16
// CHECK-SAME: (%[[D:.*]]: !llvm.ptr<5>, %[[A:.*]]: !llvm.ptr<5>, %[[B:.*]]: !llvm.ptr<5>, %[[C:.*]]: !llvm.ptr<5>)
Expand All @@ -16,7 +20,7 @@ func.func @test_gfx11_wmma_atom_call_bf16(
%a: !fly.memref<bf16, register, 16:1>,
%b: !fly.memref<bf16, register, 16:1>,
%c: !fly.memref<f32, register, 8:1>) {
%atom = fly.make_mma_atom : !fly.mma_atom<!fly_rocdl.gfx11.wmma<16x16x16, (bf16, bf16) -> f32>>
%atom = fly.make_mma_atom : !fly.mma_atom<!fly_rocdl.gfx11.wmma<16x16x16, (bf16, bf16) -> f32, signA = false, signB = false, clamp = false>>
// Loads land directly in the i16 representation expected by the WMMA intrinsic
// (the bf16->i16 reinterpretation happens at type-conversion time, not via a
// separate llvm.bitcast like the SSA path below).
Expand All @@ -25,19 +29,19 @@ func.func @test_gfx11_wmma_atom_call_bf16(
// CHECK: %[[C_VAL:.*]] = llvm.load %[[C]] : !llvm.ptr<5> -> vector<8xf32>
// CHECK: %[[RES:.*]] = rocdl.wmma.f32.16x16x16.bf16 %[[A_VAL]], %[[B_VAL]], %[[C_VAL]]
// CHECK: llvm.store %[[RES]], %[[D]] : vector<8xf32>, !llvm.ptr<5>
fly.mma_atom_call(%atom, %d, %a, %b, %c) : (!fly.mma_atom<!fly_rocdl.gfx11.wmma<16x16x16, (bf16, bf16) -> f32>>, !fly.memref<f32, register, 8:1>, !fly.memref<bf16, register, 16:1>, !fly.memref<bf16, register, 16:1>, !fly.memref<f32, register, 8:1>) -> ()
fly.mma_atom_call(%atom, %d, %a, %b, %c) : (!fly.mma_atom<!fly_rocdl.gfx11.wmma<16x16x16, (bf16, bf16) -> f32, signA = false, signB = false, clamp = false>>, !fly.memref<f32, register, 8:1>, !fly.memref<bf16, register, 16:1>, !fly.memref<bf16, register, 16:1>, !fly.memref<f32, register, 8:1>) -> ()
return
}

// CHECK-LABEL: @test_gfx11_wmma_gemm_from_tiled_mma_arg
// CHECK: rocdl.wmma.f32.16x16x16.bf16
func.func @test_gfx11_wmma_gemm_from_tiled_mma_arg(
%tiled_mma: !fly.tiled_mma<!fly.mma_atom<!fly_rocdl.gfx11.wmma<16x16x16, (bf16, bf16) -> f32>>, <(2,4,1):(4,1,0)>>,
%tiled_mma: !fly.tiled_mma<!fly.mma_atom<!fly_rocdl.gfx11.wmma<16x16x16, (bf16, bf16) -> f32, signA = false, signB = false, clamp = false>>, <(2,4,1):(4,1,0)>>,
%d: !fly.memref<f32, register, 8:1>,
%a: !fly.memref<bf16, register, 16:1>,
%b: !fly.memref<bf16, register, 16:1>,
%c: !fly.memref<f32, register, 8:1>) {
fly.gemm(%tiled_mma, %d, %a, %b, %c) : (!fly.tiled_mma<!fly.mma_atom<!fly_rocdl.gfx11.wmma<16x16x16, (bf16, bf16) -> f32>>, <(2,4,1):(4,1,0)>>, !fly.memref<f32, register, 8:1>, !fly.memref<bf16, register, 16:1>, !fly.memref<bf16, register, 16:1>, !fly.memref<f32, register, 8:1>) -> ()
fly.gemm(%tiled_mma, %d, %a, %b, %c) : (!fly.tiled_mma<!fly.mma_atom<!fly_rocdl.gfx11.wmma<16x16x16, (bf16, bf16) -> f32, signA = false, signB = false, clamp = false>>, <(2,4,1):(4,1,0)>>, !fly.memref<f32, register, 8:1>, !fly.memref<bf16, register, 16:1>, !fly.memref<bf16, register, 16:1>, !fly.memref<f32, register, 8:1>) -> ()
return
}

Expand All @@ -47,11 +51,11 @@ func.func @test_gfx11_wmma_atom_call_ssa_bf16(
%a: vector<16xbf16>,
%b: vector<16xbf16>,
%c: vector<8xf32>) -> vector<8xf32> {
%atom = fly.make_mma_atom : !fly.mma_atom<!fly_rocdl.gfx11.wmma<16x16x16, (bf16, bf16) -> f32>>
%atom = fly.make_mma_atom : !fly.mma_atom<!fly_rocdl.gfx11.wmma<16x16x16, (bf16, bf16) -> f32, signA = false, signB = false, clamp = false>>
// CHECK: %[[A_CAST:.*]] = llvm.bitcast %[[A]] : vector<16xbf16> to vector<16xi16>
// CHECK: %[[B_CAST:.*]] = llvm.bitcast %[[B]] : vector<16xbf16> to vector<16xi16>
// CHECK: %[[RES:.*]] = rocdl.wmma.f32.16x16x16.bf16 %[[A_CAST]], %[[B_CAST]], %[[C]]
%res = fly.mma_atom_call_ssa(%atom, %a, %b, %c) : (!fly.mma_atom<!fly_rocdl.gfx11.wmma<16x16x16, (bf16, bf16) -> f32>>, vector<16xbf16>, vector<16xbf16>, vector<8xf32>) -> vector<8xf32>
%res = fly.mma_atom_call_ssa(%atom, %a, %b, %c) : (!fly.mma_atom<!fly_rocdl.gfx11.wmma<16x16x16, (bf16, bf16) -> f32, signA = false, signB = false, clamp = false>>, vector<16xbf16>, vector<16xbf16>, vector<8xf32>) -> vector<8xf32>
return %res : vector<8xf32>
}

Expand All @@ -61,26 +65,60 @@ func.func @test_gfx11_wmma_atom_call_ssa_f16(
%a: vector<16xf16>,
%b: vector<16xf16>,
%c: vector<8xf32>) -> vector<8xf32> {
%atom = fly.make_mma_atom : !fly.mma_atom<!fly_rocdl.gfx11.wmma<16x16x16, (f16, f16) -> f32>>
%atom = fly.make_mma_atom : !fly.mma_atom<!fly_rocdl.gfx11.wmma<16x16x16, (f16, f16) -> f32, signA = false, signB = false, clamp = false>>
// CHECK: %[[RES:.*]] = rocdl.wmma.f32.16x16x16.f16 %[[A]], %[[B]], %[[C]]
%res = fly.mma_atom_call_ssa(%atom, %a, %b, %c) : (!fly.mma_atom<!fly_rocdl.gfx11.wmma<16x16x16, (f16, f16) -> f32>>, vector<16xf16>, vector<16xf16>, vector<8xf32>) -> vector<8xf32>
%res = fly.mma_atom_call_ssa(%atom, %a, %b, %c) : (!fly.mma_atom<!fly_rocdl.gfx11.wmma<16x16x16, (f16, f16) -> f32, signA = false, signB = false, clamp = false>>, vector<16xf16>, vector<16xf16>, vector<8xf32>) -> vector<8xf32>
return %res : vector<8xf32>
}

// Unsigned integer (ui8) inputs must lower to rocdl.wmma.i32.16x16x16.iu8 with
// signA=false, signB=false, clamp=false — matching the unsigned-only contract
// documented in verify(). The A/B operands (vector<16xui8>) are bitcast to the
// packed representation (vector<4xi32>) expected by the intrinsic.
// Unsigned i8 inputs (signA=signB=false, clamp=false) lower to
// rocdl.wmma.i32.16x16x16.iu8 with the corresponding attrs. A/B operands
// (vector<16xui8>) are bitcast to the packed representation (vector<4xi32>)
// expected by the intrinsic.
//
// CHECK-LABEL: @test_gfx11_wmma_atom_call_ssa_iu8
func.func @test_gfx11_wmma_atom_call_ssa_iu8(
// CHECK-LABEL: @test_gfx11_wmma_atom_call_ssa_iu8_unsigned
func.func @test_gfx11_wmma_atom_call_ssa_iu8_unsigned(
%a: vector<16xui8>,
%b: vector<16xui8>,
%c: vector<8xi32>) -> vector<8xi32> {
%atom = fly.make_mma_atom : !fly.mma_atom<!fly_rocdl.gfx11.wmma<16x16x16, (ui8, ui8) -> i32>>
%atom = fly.make_mma_atom : !fly.mma_atom<!fly_rocdl.gfx11.wmma<16x16x16, (ui8, ui8) -> i32, signA = false, signB = false, clamp = false>>
// CHECK: llvm.bitcast {{.*}} : vector<16xui8> to vector<4xi32>
// CHECK: llvm.bitcast {{.*}} : vector<16xui8> to vector<4xi32>
// CHECK: rocdl.wmma.i32.16x16x16.iu8
%res = fly.mma_atom_call_ssa(%atom, %a, %b, %c) : (!fly.mma_atom<!fly_rocdl.gfx11.wmma<16x16x16, (ui8, ui8) -> i32>>, vector<16xui8>, vector<16xui8>, vector<8xi32>) -> vector<8xi32>
// signA=signB=clamp=false attrs are elided by the printer when at default.
%res = fly.mma_atom_call_ssa(%atom, %a, %b, %c) : (!fly.mma_atom<!fly_rocdl.gfx11.wmma<16x16x16, (ui8, ui8) -> i32, signA = false, signB = false, clamp = false>>, vector<16xui8>, vector<16xui8>, vector<8xi32>) -> vector<8xi32>
return %res : vector<8xi32>
}

// Mixed-sign i8 (signed A x unsigned B, no clamp) — the three knobs are
// independent type params, so any combination must round-trip and forward to
// the iu8 intrinsic. The printer elides false-valued attrs, so we only assert
// the non-default signA = true here.
//
// CHECK-LABEL: @test_gfx11_wmma_atom_call_ssa_iu8_mixed_sign
func.func @test_gfx11_wmma_atom_call_ssa_iu8_mixed_sign(
%a: vector<16xi8>,
%b: vector<16xi8>,
%c: vector<8xi32>) -> vector<8xi32> {
%atom = fly.make_mma_atom : !fly.mma_atom<!fly_rocdl.gfx11.wmma<16x16x16, (i8, i8) -> i32, signA = true, signB = false, clamp = false>>
// CHECK: rocdl.wmma.i32.16x16x16.iu8 {{.*}} {signA = true}
%res = fly.mma_atom_call_ssa(%atom, %a, %b, %c) : (!fly.mma_atom<!fly_rocdl.gfx11.wmma<16x16x16, (i8, i8) -> i32, signA = true, signB = false, clamp = false>>, vector<16xi8>, vector<16xi8>, vector<8xi32>) -> vector<8xi32>
return %res : vector<8xi32>
}

// Signed i8 inputs use signA=signB=true so the intrinsic treats the packed
// bytes as signed. clamp=true requests saturation on the i32 accumulator
// path. This exercises the signA/signB/clamp type parameters end-to-end.
//
// CHECK-LABEL: @test_gfx11_wmma_atom_call_ssa_iu8_signed_clamp
func.func @test_gfx11_wmma_atom_call_ssa_iu8_signed_clamp(
%a: vector<16xsi8>,
%b: vector<16xsi8>,
%c: vector<8xi32>) -> vector<8xi32> {
%atom = fly.make_mma_atom : !fly.mma_atom<!fly_rocdl.gfx11.wmma<16x16x16, (si8, si8) -> i32, signA = true, signB = true, clamp = true>>
// CHECK: llvm.bitcast {{.*}} : vector<16xsi8> to vector<4xi32>
// CHECK: llvm.bitcast {{.*}} : vector<16xsi8> to vector<4xi32>
// CHECK: rocdl.wmma.i32.16x16x16.iu8 {{.*}} {clamp = true, signA = true, signB = true}
%res = fly.mma_atom_call_ssa(%atom, %a, %b, %c) : (!fly.mma_atom<!fly_rocdl.gfx11.wmma<16x16x16, (si8, si8) -> i32, signA = true, signB = true, clamp = true>>, vector<16xsi8>, vector<16xsi8>, vector<8xi32>) -> vector<8xi32>
return %res : vector<8xi32>
}
Loading
Loading