Fix GridwiseGemmDlMultipleD element op for FloatAcc!=FloatC #3565
+23
−6
Summary
Fixes element operation type mismatch in `GridwiseGemmDlMultipleD_km_kn_mn` when `FloatAcc != FloatC`.
The Bug
When accumulator type differs from output type (e.g., INT8×INT8 → INT32 accumulate → FP32 output), the CDE element op is invoked with references to the wrong storage type.
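A minimal standalone sketch of the intended data flow, independent of composable_kernel: the accumulator values stay in their own `int32_t` storage and the element op writes its result into separate `float` storage. The names `ScaleAddOp` and `apply_element_op` are hypothetical and exist only for illustration; they are not part of the library or of this PR.

```cpp
#include <array>
#include <cstddef>
#include <cstdint>

// Hypothetical element op: output is float, accumulator input is int32_t,
// and one extra D input is float.
struct ScaleAddOp {
    float scale;

    void operator()(float& e, const std::int32_t& c, const float& d) const {
        // Explicit accumulator-to-output conversion happens here.
        e = scale * static_cast<float>(c) + d;
    }
};

// Applies the op element by element. The output array is distinct storage
// from the accumulator array, so e and c never alias.
template <std::size_t N, typename Op>
std::array<float, N> apply_element_op(const std::array<std::int32_t, N>& acc,
                                      const std::array<float, N>& d,
                                      const Op& op) {
    std::array<float, N> e{};
    for (std::size_t i = 0; i < N; ++i) {
        op(e[i], acc[i], d[i]);
    }
    return e;
}
```

Because the output lives in its own `float` array, the op's `float&` parameter binds to real `float` storage and the conversion from the accumulator type is performed exactly once, inside the op.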
The element op contract is `(E& e, const C& c, const D& d...)`, where:
- `E` = `FloatC` (the final output type, e.g. `float`)
- `C` = `FloatAcc` (the accumulator type, e.g. `int32_t`)

Current behavior (broken): The kernel builds `dst_data_refs` from `c_thread_buf`, which is `StaticBuffer<FloatAcc>`. This means the element op receives `int32_t&` for both `E&` and `C&`, violating its signature when `FloatAcc != FloatC`.

Why this is wrong:
1. The element op computes `e = f(c, d...)`. If `e` and `c` alias the same storage, the conversion semantics are lost.
2. Element ops that require `float&` for output fail at compile time.
3. The subsequent `ThreadwiseTensorSliceTransfer_v1r3<FloatAcc, FloatC, ...>` operates on `c_thread_buf`, meaning it type-puns `FloatAcc` bits as `FloatC`, which is undefined behavior for non-trivially-convertible types.

Fixed behavior: Introduce a separate `e_thread_buf<FloatC>` for element op output, pass `(E& e)` from this buffer and `(const C& c)` from `c_thread_buf`, then transfer `e_thread_buf` to global memory.
Context / Affected
DeviceGemmMultipleD_Dl)
Minimal repro (verified)
This is a compile-time repro: no runtime execution or valid pointers needed.
Key point: the element op is intentionally non-templated and requires `float&` for the output ref. If the kernel incorrectly passes an `int32_t&` (`FloatAcc`) as the output ref, compilation fails.
How to reproduce (compile)
On gfx906 + ROCm 7.1.1:
This exploits point 2 from "Why this is wrong" in the bug description to demonstrate the bug at compile time.
Failing error line (upstream develop)
This is the first relevant error (line numbers may vary by commit):
The diagnostic also shows that `dst_data_refs` contains `int&` (`FloatAcc`) where the element op requires `float&` (`FloatC`).
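The same class of mismatch can be checked in isolation, without composable_kernel headers or a GPU toolchain. The sketch below is not the PR's repro; `StrictFloatOutOp` and `accepts_refs_v` are hypothetical names used only to show why a non-templated op taking `float&` rejects an `int32_t&` output reference: a non-const lvalue reference to `float` cannot bind to an `int32_t` object.

```cpp
#include <cstdint>
#include <type_traits>

// Hypothetical non-templated element op that strictly requires float& output.
struct StrictFloatOutOp {
    void operator()(float& e, const std::int32_t& c, const float& d) const {
        e = static_cast<float>(c) + d;
    }
};

// True when Op can be called with (E&, const C&, const D&).
template <typename Op, typename E, typename C, typename D>
inline constexpr bool accepts_refs_v =
    std::is_invocable_v<const Op&, E&, const C&, const D&>;

// Correct wiring: the output reference is float& (FloatC in the PR's terms).
static_assert(accepts_refs_v<StrictFloatOutOp, float, std::int32_t, float>,
              "op must accept a float& output reference");

// Broken wiring: the output reference is int32_t& (FloatAcc), which cannot
// bind to the op's float& parameter, so the call is ill-formed.
static_assert(!accepts_refs_v<StrictFloatOutOp, std::int32_t, std::int32_t, float>,
              "op must reject an int32_t& output reference");
```

A check of this shape turns the mismatch into a targeted `static_assert` message instead of a long template instantiation trace.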