
Conversation

dbsanfte commented Jan 14, 2026

Summary

Fixes element operation type mismatch in GridwiseGemmDlMultipleD_km_kn_mn when FloatAcc != FloatC.

The Bug

When accumulator type differs from output type (e.g., INT8×INT8 → INT32 accumulate → FP32 output), the CDE element op is invoked with references to the wrong storage type.

The element op contract is: (E& e, const C& c, const D& d...) where:

  • E = FloatC (the final output type, e.g. float)
  • C = FloatAcc (the accumulator type, e.g. int32_t)

Current behavior (broken): The kernel builds dst_data_refs from c_thread_buf, which is StaticBuffer<FloatAcc>. This means the element op receives int32_t& for both E& and C&—violating its signature when FloatAcc != FloatC.

Why this is wrong:

  1. The element op is supposed to convert from accumulator precision to output precision: e = f(c, d...). If e and c alias the same storage, the conversion semantics are lost.
  2. Non-templated or strictly-typed element ops that expect float& for output fail at compile time.
  3. Even for templated element ops that happen to compile, the kernel later does ThreadwiseTensorSliceTransfer_v1r3<FloatAcc, FloatC, ...> on c_thread_buf—meaning it type-puns FloatAcc bits as FloatC, which is undefined behavior for non-trivially-convertible types.
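
To make point 3 concrete, here is a minimal host-side illustration (standalone, not kernel code) of why reinterpreting FloatAcc bits as FloatC is not the same as converting the value, assuming FloatAcc = int32_t and FloatC = float:

#include <cstdint>
#include <cstdio>
#include <cstring>

int main()
{
    int32_t acc = 1000; // accumulator value (FloatAcc)

    float punned;
    std::memcpy(&punned, &acc, sizeof(punned)); // bit reinterpretation: what treating FloatAcc storage as FloatC amounts to
    float converted = static_cast<float>(acc);  // value conversion: what the element op is supposed to do

    std::printf("punned = %g, converted = %g\n", punned, converted); // punned is denormal garbage (~1.4e-42), converted is 1000
    return 0;
}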

Fixed behavior: Introduce separate e_thread_buf<FloatC> for element op output, pass (E& e) from this buffer and (const C& c) from c_thread_buf, then transfer e_thread_buf to global memory.
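
A minimal host-side sketch of this pattern (illustrative only; the buffer and op names are simplified stand-ins for the kernel's c_thread_buf / e_thread_buf and CDE element op, not the kernel code itself):

#include <array>
#include <cstdint>

struct Scale
{
    // Same contract as the CDE element op: E& (float), const C& (int32_t), const D& (float)
    void operator()(float& e, const int32_t& c, const float& d) const { e = static_cast<float>(c) * d; }
};

int main()
{
    std::array<int32_t, 4> c_thread_buf{10, 20, 30, 40};    // accumulator results (FloatAcc)
    std::array<float, 4>   d_buf{0.5f, 0.25f, 0.5f, 0.25f}; // per-element scale (D tensor values)
    std::array<float, 4>   e_thread_buf{};                  // separate output buffer (FloatC)

    Scale op;
    for(std::size_t i = 0; i < c_thread_buf.size(); ++i)
    {
        // E& comes from e_thread_buf, const C& from c_thread_buf: no aliasing, no punning
        op(e_thread_buf[i], c_thread_buf[i], d_buf[i]);
    }
    // e_thread_buf (FloatC) is what would then be transferred to global memory
    return 0;
}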

Context / Affected

  • GPU: gfx906 (MI50/MI60 class)
  • ROCm: 7.1.1
  • Use case: INT8×INT8→INT32 accumulate with FP32 output scaling via one or more D tensors (DeviceGemmMultipleD_Dl)

Minimal repro (verified)

This is a compile-time repro: no runtime execution or valid pointers needed.

Key point: the element op is intentionally non-templated and requires float& for the output ref.
If the kernel incorrectly passes an int32_t& (FloatAcc) as the output ref, compilation fails.

#include <array>
#include <cstdint> // int8_t, int32_t

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_dl.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;

struct ScaleSingle_Strict
{
    __host__ __device__ void operator()(float& e, const int32_t& c, const float& d) const
    {
        e = ck::type_convert<float>(c) * d;
    }
};

using ADataType   = int8_t;
using BDataType   = int8_t;
using AccDataType = int32_t;            // FloatAcc
using DsDataType  = ck::Tuple<float>;   // One D tensor
using EDataType   = float;              // FloatC

using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;

using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Dl<
    Row,
    Col,
    ck::Tuple<Row>,
    Row,
    ADataType,
    BDataType,
    AccDataType,
    DsDataType,
    EDataType,
    PassThrough,
    PassThrough,
    ScaleSingle_Strict,
    ck::tensor_operation::device::GemmSpecialization::Default,
    // Parameter pack copied from CK example/14_gemm_quantization/gemm_dl_quantization_int8.cpp
    256,
    128,
    128,
    16,
    4,
    4,
    4,
    1,
    S<8, 2>,
    S<8, 2>,
    S<8, 1, 1, 4>,
    S<2, 1, 128, 1>,
    S<1, 2, 0, 3>,
    S<1, 2, 0, 3>,
    S<4, 1, 1, 4>,
    S<1, 2, 0, 3>,
    S<1, 1, 1, 4>,
    S<8, 1, 1, 4>,
    S<2, 1, 128, 1>,
    S<1, 2, 0, 3>,
    S<1, 2, 0, 3>,
    S<4, 1, 1, 4>,
    S<1, 2, 0, 3>,
    S<1, 1, 1, 4>,
    S<0, 1, 2, 3, 4, 5>,
    5,
    4>;

int main()
{
    DeviceGemmInstance gemm;

    int M = 128;
    int N = 128;
    int K = 128;

    std::array<const void*, 1> p_ds      = {nullptr};
    std::array<ck::index_t, 1> stride_ds = {N};

    auto arg = gemm.MakeArgument(
        /* A */ nullptr,
        /* B */ nullptr,
        /* Ds */ p_ds,
        /* E */ nullptr,
        M,
        N,
        K,
        /* StrideA */ K,
        /* StrideB */ N,
        stride_ds,
        /* StrideE */ N,
        PassThrough{},
        PassThrough{},
        ScaleSingle_Strict{});

    (void)gemm.IsSupportedArgument(arg);
    return 0;
}

How to reproduce (compile)

On gfx906 + ROCm 7.1.1:

# On upstream develop (expected: FAIL)
git clone https://github.com/ROCm/composable_kernel.git
cd composable_kernel
git checkout develop
/opt/rocm/bin/hipcc -I./include -std=c++17 -O2 --offload-arch=gfx906 /path/to/repro.cpp -o repro

# With this PR applied (expected: PASS)
git fetch origin pull/3565/head:pr3565
git checkout pr3565
/opt/rocm/bin/hipcc -I./include -std=c++17 -O2 --offload-arch=gfx906 /path/to/repro.cpp -o repro

This exploits point 2 from "Why this is wrong" above to demonstrate the bug at compile time.

Failing error line (upstream develop)

This is the first relevant error (line numbers may vary by commit):

.../include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_multiple_d.hpp:632:33: error: no matching function for call to 'unpack2'

The diagnostic also shows dst_data_refs contains int& (FloatAcc) where the element op requires float& (FloatC).

dbsanfte commented Jan 14, 2026

I think there's still something wrong with the fused scaling: the cosine similarity against the FP32 reference still doesn't match a two-kernel approach (CK INT8 GEMM + separate scaling kernel).

With CK + Fused scaling

[==========] Running 1 test from 1 test suite.
[----------] Global test environment set-up.
[----------] 1 test from ROCmQuantisedGemmPerf
[ RUN      ] ROCmQuantisedGemmPerf.Qwen7B_M128

ROCm INT8 GEMM: Qwen2.5-7B
Device: AMD Instinct MI60 / MI50 (gfx906:sramecc+:xnack-)

| Workload    | Dimensions              | Time (ms)       | Throughput              | Accuracy     |
|-------------|--------------------------|-----------------|--------------------------|--------------|
| Attn Output | M=128  N=3584   K=3584   | 1.710 (±0.001)  | 1.923 TFLOPS (pk 1.924)  | cos=0.991752 |
| FFN Gate/Up | M=128  N=18944  K=3584   | 5.765 (±0.003)  | 3.015 TFLOPS (pk 3.016)  | cos=0.991771 |
| FFN Down    | M=128  N=3584   K=18944  | 6.630 (±0.002)  | 2.622 TFLOPS (pk 2.623)  | cos=0.998492 |

With Two-Kernel approach (CK GEMM kernel + Scaling kernel)

[==========] Running 1 test from 1 test suite.
[----------] Global test environment set-up.
[----------] 1 test from ROCmQuantisedGemmPerf
[ RUN      ] ROCmQuantisedGemmPerf.Qwen7B_M128

ROCm INT8 GEMM: Qwen2.5-7B
Device: AMD Instinct MI60 / MI50 (gfx906:sramecc+:xnack-)

| Workload    | Dimensions              | Time (ms)       | Throughput              | Accuracy     |
|-------------|--------------------------|-----------------|--------------------------|--------------|
| Attn Output | M=128  N=3584   K=3584   | 1.682 (±0.000)  | 1.955 TFLOPS (pk 1.956)  | cos=0.999956 |
| FFN Gate/Up | M=128  N=18944  K=3584   | 5.659 (±0.011)  | 3.071 TFLOPS (pk 3.083)  | cos=0.999956 |
| FFN Down    | M=128  N=3584   K=18944  | 6.525 (±0.005)  | 2.664 TFLOPS (pk 2.666)  | cos=0.999947 |

Cosine similarity is much better in the latter.

dbsanfte marked this pull request as draft January 14, 2026 18:26