
Addition of MLSS initial kernels for GFX1201 #4831

Open
Zhaeong wants to merge 25 commits into
ROCm:develop from Zhaeong:amd/dev/ozhang/mhamlss

Conversation

@Zhaeong
Contributor

@Zhaeong Zhaeong commented Apr 29, 2026

Motivation

AMDMLSS integration for MHA and convolution acceleration. This PR integrates two fusion types.

Technical Details

  1. Multi-Head Attention Fusion
  • Matches attention subgraphs tagged as "attention" within "group" ops
  • Fuses into a single gpu::mlss_mha op
  • Fixed shape support: {1, 8, 4096, 40} (batch=1, 8 heads, seq_len=4096, head_dim=40)
  • Uses pre-compiled GFX1201 kernel: gfx1201_mha_64x64x48_64x48x64 with FP16 packed QKV format
  2. Convolution Fusion (3 matchers)
  • find_mlss_conv — basic convolution (FP32 + FP16)

  • find_mlss_conv_bias — conv + broadcast(bias) pattern

  • find_mlss_conv_bias_relu — relu(conv + bias) pattern, all fused into one kernel

  • MIGRAPHX_MLSS_USE_SPECIFIC_OPS env var selects which MLSS ops to enable (e.g., "mha,conv")

  • MIGRAPHX_HAS_MLSS_HEADERS compile flag gates the feature

    Supported shapes target ResNet-50 and VGG-19 workloads (3x3 convolutions at various spatial resolutions).
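
To illustrate the env-var gating described above, here is a minimal sketch of parsing a comma-separated op list such as "mha,conv". Only the MIGRAPHX_MLSS_USE_SPECIFIC_OPS variable name comes from the PR; the helper names are hypothetical, and the static cache mirrors the first-call caching discussed later in the review:

```cpp
#include <cstdlib>
#include <sstream>
#include <string>
#include <unordered_set>

// Hypothetical helper: split "mha,conv" into a set of enabled op names.
static std::unordered_set<std::string> parse_mlss_ops(const std::string& value)
{
    std::unordered_set<std::string> ops;
    std::stringstream ss(value);
    std::string item;
    while(std::getline(ss, item, ','))
    {
        if(!item.empty())
            ops.insert(item);
    }
    return ops;
}

// Hypothetical helper: is a given MLSS op enabled via the env var?
// The parsed set is cached on first call, so the variable must be set
// before the first query (as the PR's test notes for string_value_of()).
static bool mlss_op_enabled(const std::string& op)
{
    const char* raw = std::getenv("MIGRAPHX_MLSS_USE_SPECIFIC_OPS");
    if(raw == nullptr)
        return false;
    static const auto ops = parse_mlss_ops(raw); // cached on first call
    return ops.count(op) > 0;
}
```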

@Zhaeong Zhaeong marked this pull request as ready for review April 29, 2026 21:21
@Zhaeong Zhaeong requested a review from causten as a code owner April 29, 2026 21:21
@Zhaeong Zhaeong requested a review from eddieliao April 29, 2026 21:24

// namespace mlss::shaders::gqa::ck::wmma::fp16::gfx1201::rel
// {
constexpr ShaderType<51976> multi_head_attention_void_single_pointer_packed_qkv_128_64x64x48_64x48x64_forward_with_strides_fp16_gfx1201 = {
Contributor

Instead of adding the binary inside the header file, we should probably include the library through our build. This would be really difficult to maintain if it ever changes.

Probably something along the lines of this inside CMakeLists:

find_library(AMDMLSS_LIB amdmlss PATHS "${AMDMLSS_ROOT}/lib" REQUIRED)
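
A fuller sketch of that build wiring might look like the following. The target name migraphx_gpu, the option name, and the amdmlss.h header are illustrative assumptions, not from the PR:

```cmake
# Sketch only: option, variable, and target names are illustrative.
option(MIGRAPHX_ENABLE_MLSS "Enable AMDMLSS kernels" OFF)
if(MIGRAPHX_ENABLE_MLSS)
    find_path(AMDMLSS_INCLUDE_DIR amdmlss.h PATHS "${AMDMLSS_ROOT}/include")
    find_library(AMDMLSS_LIB amdmlss PATHS "${AMDMLSS_ROOT}/lib" REQUIRED)
    target_include_directories(migraphx_gpu PRIVATE ${AMDMLSS_INCLUDE_DIR})
    target_link_libraries(migraphx_gpu PRIVATE ${AMDMLSS_LIB})
    target_compile_definitions(migraphx_gpu PRIVATE MIGRAPHX_HAS_MLSS_HEADERS)
endif()
```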

Contributor Author

This was what was originally done in this PR:
#4303
But the library isn't ready yet. The current plan is to include the amdmlss library in the UAI tool and include its headers directory.

Comment thread src/targets/gpu/include/migraphx/gpu/fuse_mlss.hpp Outdated
Comment thread src/targets/gpu/include/migraphx/gpu/code_object_op.hpp Outdated
Comment thread src/targets/gpu/include/migraphx/gpu/fuse_mlss.hpp Outdated
Comment thread src/targets/gpu/include/migraphx/gpu/fuse_mlss.hpp Outdated
Comment thread src/targets/gpu/include/migraphx/gpu/fuse_mlss.hpp Outdated
Comment thread src/targets/gpu/include/migraphx/gpu/fuse_mlss.hpp Outdated
Comment thread src/targets/gpu/code_object_op.cpp
Comment thread src/targets/gpu/fuse_mlss.cpp Outdated
@eddieliao
Contributor

On second thought, I don't believe this should go under JIT at all. Correct me if I'm wrong, but all of the MLSS kernels are pre-compiled and thus we should treat them like we do hipBLASLt, MIOpen, etc. The standalone header idea seems messy to me, but if we want to keep it I would at least make it a separate static library we can build. That way any work within MIGraphX can call to this "library" which can be replaced whenever the MLSS library is ready.

I think the current implementation would also cause MXRs to get large in size since there would need to be a copy of the kernel per MHA instruction instead of just having one instance of it.

@codecov

codecov Bot commented Apr 30, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.

Additional details and impacted files
@@             Coverage Diff             @@
##           develop    #4831      +/-   ##
===========================================
+ Coverage    92.46%   92.85%   +0.39%     
===========================================
  Files          583      584       +1     
  Lines        29564    30147     +583     
===========================================
+ Hits         27336    27992     +656     
+ Misses        2228     2155      -73     

see 36 files with indirect coverage changes


Zhaeong and others added 8 commits April 30, 2026 00:02
Co-authored-by: Eddie Liao <54926923+eddieliao@users.noreply.github.com>
Contributor

Copilot AI left a comment


Pull request overview

This PR introduces an initial MLSS integration on the GPU target, adding a new fuse_mlss pass that can replace tagged attention group instructions with an MLSS MHA custom op on gfx1201, backed by an embedded precompiled kernel.

Changes:

  • Add a GPU fuse_mlss pass gated by MIGRAPHX_MLSS_USE_SPECIFIC_OPS and gfx1201.
  • Add a new gpu::mlss_mha operation that dispatches the embedded MLSS MHA kernel.
  • Add a GPU test intended to validate MLSS MHA fusion behavior.

Reviewed changes

Copilot reviewed 7 out of 8 changed files in this pull request and generated 10 comments.

Summary per file:
  • test/gpu/fuse_mlss.cpp: Adds a fusion test for MLSS MHA on gfx1201 (currently has correctness issues).
  • src/targets/gpu/target.cpp: Wires fuse_mlss into the GPU pass pipeline.
  • src/targets/gpu/fuse_mlss.cpp: Implements the MLSS fusion pass and env-var gating.
  • src/targets/gpu/include/migraphx/gpu/fuse_mlss.hpp: Declares the MLSS fusion pass API and mlss_enabled().
  • src/targets/gpu/mlss_mha_op.cpp: Implements the MLSS MHA op kernel dispatch.
  • src/targets/gpu/include/migraphx/gpu/mlss_mha_op.hpp: Declares the MLSS MHA op and reflected attributes.
  • src/targets/gpu/CMakeLists.txt: Adds new MLSS sources to the GPU library build.

auto inputs = ins->inputs();
if(inputs.size() != 3)
return;

Copilot AI May 1, 2026

The selected MLSS kernel name indicates a single-pointer packed QKV layout, and mlss_mha_op::compute() only uses args[0] as the base pointer. However this fusion only checks inputs.size()==3 and the Q shape, and then passes {Q,K,V} directly. If Q/K/V are not views into the same packed-QKV buffer (with Q at offset 0), the fused program will read incorrect memory and produce wrong results. Please add structural checks (e.g., Q/K/V must be slices of the same packed tensor with expected offsets/strides) or insert an explicit packing step; otherwise do not fuse.

Suggested change
// The selected MLSS kernel consumes a single base pointer to packed QKV.
// This pass currently sees three graph inputs and does not verify that
// they are views into the same packed buffer with the required offsets
// and strides, so this fusion is not safe here.
return;

Copilot uses AI. Check for mistakes.
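
The structural check this comment asks for could be sketched, independent of MIGraphX types, as a test on view offsets. The view_desc struct and field names are illustrative, not the MIGraphX instruction API:

```cpp
#include <cstddef>

// Illustrative description of a tensor view: which allocation it aliases,
// where it starts, and how many elements it covers.
struct view_desc
{
    int buffer_id;      // identifies the underlying allocation
    std::size_t offset; // element offset of the view into that allocation
    std::size_t elems;  // number of elements in the view
};

// True only when Q, K, V are views into one buffer, laid out back to back
// with Q at offset 0: the layout a single-pointer packed-QKV kernel assumes.
static bool is_packed_qkv(const view_desc& q, const view_desc& k, const view_desc& v)
{
    if(q.buffer_id != k.buffer_id || k.buffer_id != v.buffer_id)
        return false;
    if(q.offset != 0)
        return false;
    return k.offset == q.elems && v.offset == q.elems + k.elems;
}
```

In a real pass this would inspect instruction inputs and their strides; the point is that the fusion should refuse to fire when the check fails, rather than assume the layout.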
Comment thread src/targets/gpu/fuse_mlss.cpp Outdated
Comment on lines +34 to +39
shape mlss_mha_op::compute_shape(std::vector<shape> inputs) const
{
// inputs: [Q/K/V packed, K, V, scale_literal, output_buffer]
// output_buffer is the last input — its shape must match the stored output shape
return output;
}
Copilot AI May 1, 2026

compute_shape() ignores its inputs and returns the stored output without validating the number/types of inputs or that the provided output buffer shape matches. For safety (and to match patterns like gpu::code_object), this should validate expected input count/layout and ensure the last input (output buffer) is compatible with output, otherwise the kernel can silently misbehave.

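
A validating variant along the lines this comment suggests might look like the following toy sketch. Here toy_shape stands in for migraphx::shape, and the function name is illustrative; the real implementation would use migraphx's check_shapes-style helpers:

```cpp
#include <cstddef>
#include <stdexcept>
#include <vector>

// Minimal stand-in for a shape: a type tag plus an element count.
struct toy_shape
{
    int type;          // numeric type tag (e.g. 1 == fp16)
    std::size_t elems; // total element count
    bool operator==(const toy_shape& rhs) const
    {
        return type == rhs.type && elems == rhs.elems;
    }
};

// compute_shape that validates instead of blindly returning `output`:
// expect [Q, K, V, scale, output_buffer] and require the buffer to match.
toy_shape mlss_mha_compute_shape(const std::vector<toy_shape>& inputs,
                                 const toy_shape& output)
{
    if(inputs.size() != 5)
        throw std::runtime_error(
            "mlss_mha: expected 5 inputs (Q, K, V, scale, output buffer)");
    if(!(inputs.back() == output))
        throw std::runtime_error(
            "mlss_mha: output buffer shape must match stored output shape");
    return output;
}
```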
Comment thread src/targets/gpu/mlss_mha_op.cpp Outdated
Comment thread test/gpu/fuse_mlss.cpp
Comment on lines +115 to +120
auto n = ins->name();
if(n == "mlss_mha")
found_mlss_mha = true;
if(n == "group")
found_group = true;
}
Copilot AI May 1, 2026

GPU ops in MIGraphX are typically named with the "gpu::" prefix (e.g., "gpu::dynamic_code_object_op"). mlss_mha_op::name() returns "gpu::mlss_mha", so this test's lookup for "mlss_mha" will not find the fused instruction.

Comment thread test/gpu/fuse_mlss.cpp
Comment on lines +132 to +139
// Inputs: Q, K, V, scale_literal
EXPECT(ins->inputs().size() == 4);

// 4th input must be a half scalar literal (the scale)
auto scale_in = ins->inputs()[3];
EXPECT(scale_in->name() == "@literal");
EXPECT(scale_in->get_shape().type() == migraphx::shape::half_type);

Copilot AI May 1, 2026

The fused mlss_mha_op is built with an explicit output buffer input (allocate) in fuse_mlss.cpp, so the instruction will have 5 inputs: Q, K, V, scale literal, and output buffer. This test currently expects 4 inputs and doesn't account for the output buffer input index.

Suggested change
// Inputs: Q, K, V, scale_literal
EXPECT(ins->inputs().size() == 4);
// 4th input must be a half scalar literal (the scale)
auto scale_in = ins->inputs()[3];
EXPECT(scale_in->name() == "@literal");
EXPECT(scale_in->get_shape().type() == migraphx::shape::half_type);
// Inputs: Q, K, V, scale_literal, output_buffer
EXPECT(ins->inputs().size() == 5);
// 4th input must be a half scalar literal (the scale)
auto scale_in = ins->inputs()[3];
EXPECT(scale_in->name() == "@literal");
EXPECT(scale_in->get_shape().type() == migraphx::shape::half_type);
// 5th input must be the explicit output buffer
auto output_buffer = ins->inputs()[4];
EXPECT(output_buffer->get_shape() == qkv_shape);

Comment on lines +77 to +86
// Find the half-precision scale literal inside the submodule
instruction_ref scale_literal_ins = attn_mod->end();
for(auto sub_ins : iterator_for(*attn_mod))
{
if(sub_ins->name() == "@literal")
{
scale_literal_ins = sub_ins;
break;
}
}
Copilot AI May 1, 2026

This pass currently selects the first @literal in the attention submodule as the scale. That is brittle if the submodule contains other literals (biases, constants, etc.) and can cause fusing with an incorrect scale. Consider matching a specific literal shape/type (e.g., half scalar) and/or verifying how it is used in the submodule before using it as the MHA scale.

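
The stricter literal selection suggested here could be sketched with toy types. toy_ins is illustrative, not the MIGraphX instruction API; the idea is to require a half scalar and to refuse to fuse when the choice is ambiguous:

```cpp
#include <cstddef>
#include <optional>
#include <string>
#include <vector>

// Toy instruction record; the real pass iterates instruction_refs.
struct toy_ins
{
    std::string name; // "@literal", "mul", ...
    bool is_half;     // half_type?
    bool is_scalar;   // single element?
};

// Pick the scale literal by requiring a half scalar, rather than taking
// the first @literal found; return nullopt when none or more than one
// candidate exists, so the caller can skip the fusion.
std::optional<std::size_t> find_scale_literal(const std::vector<toy_ins>& mod)
{
    std::optional<std::size_t> found;
    for(std::size_t i = 0; i < mod.size(); ++i)
    {
        const toy_ins& ins = mod[i];
        if(ins.name != "@literal" || !ins.is_half || !ins.is_scalar)
            continue;
        if(found.has_value())
            return std::nullopt; // two candidates: ambiguous, do not fuse
        found = i;
    }
    return found;
}
```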
Comment thread test/gpu/fuse_mlss.cpp
Comment on lines +104 to +105
const float scale_val = 1.0f / std::sqrt(40.0f);

Copilot AI May 1, 2026

std::sqrt is used below, but <cmath> is not included in this file, which can cause a compile error on some toolchains. Include <cmath> (or avoid std::sqrt) to make the test self-contained.

Comment thread test/gpu/fuse_mlss.cpp
Comment on lines +50 to +64
static void run_pass(migraphx::program& p)
{
migraphx::run_passes(p, {migraphx::gpu::fuse_mlss{&get_context()}, migraphx::dead_code_elimination{}});
}

// Set MIGRAPHX_MLSS_USE_SPECIFIC_OPS=mha at static-init time, before any test runs.
// This must happen before the first call to string_value_of(), which caches its result.
const int mlss_env_init = ([] {
#ifdef _WIN32
_putenv_s("MIGRAPHX_MLSS_USE_SPECIFIC_OPS", "mha");
#else
setenv("MIGRAPHX_MLSS_USE_SPECIFIC_OPS", "mha", /*overwrite=*/1); // NOLINT(cert-env33-c)
#endif
}(), 0);

Copilot AI May 1, 2026

This static-init env var setter is an unused namespace-scope const (internal linkage), which can trigger -Wunused-const-variable, and it relies on static initialization order to happen before the first cached string_value_of() call. Prefer setting the env var in main() before test::run() (or mark it [[maybe_unused]]/MIGRAPHX_DEBUG_USED) to avoid ordering and warning issues.

Suggested change
static void run_pass(migraphx::program& p)
{
migraphx::run_passes(p, {migraphx::gpu::fuse_mlss{&get_context()}, migraphx::dead_code_elimination{}});
}
// Set MIGRAPHX_MLSS_USE_SPECIFIC_OPS=mha at static-init time, before any test runs.
// This must happen before the first call to string_value_of(), which caches its result.
const int mlss_env_init = ([] {
#ifdef _WIN32
_putenv_s("MIGRAPHX_MLSS_USE_SPECIFIC_OPS", "mha");
#else
setenv("MIGRAPHX_MLSS_USE_SPECIFIC_OPS", "mha", /*overwrite=*/1); // NOLINT(cert-env33-c)
#endif
}(), 0);
static void set_mlss_use_specific_ops()
{
#ifdef _WIN32
_putenv_s("MIGRAPHX_MLSS_USE_SPECIFIC_OPS", "mha");
#else
setenv("MIGRAPHX_MLSS_USE_SPECIFIC_OPS", "mha", 1); // NOLINT(cert-env33-c)
#endif
}
static void run_pass(migraphx::program& p)
{
set_mlss_use_specific_ops();
migraphx::run_passes(p, {migraphx::gpu::fuse_mlss{&get_context()},
migraphx::dead_code_elimination{}});
}

Comment thread test/gpu/fuse_mlss.cpp
Comment on lines +160 to +164
for(auto ins : migraphx::iterator_for(*mm))
{
if(ins->name() == "mlss_mha")
found_mlss_mha = true;
}
Copilot AI May 1, 2026

Same issue as above: the fused op name is "gpu::mlss_mha" (not "mlss_mha"), so this check will not detect fusion correctly.

@Zhaeong Zhaeong changed the title Addition of MLSS MHA for GFX1201 as standalone header Addition of MLSS initial kernels for GFX1201 May 11, 2026