HipKittens MXFP8 GEMM Support #566

Open
alextmagro wants to merge 9 commits into dev from hipkittens_mxfp8

Conversation

Contributor

@alextmagro alextmagro commented Apr 28, 2026

Creates an MXFP8 GEMM with HipKittens that outperforms hipBLASlt and offers additional epilogues such as BIAS and GELU AUX.

Requires a workspace sized relative to the model, often larger than hipBLASlt's, but with significant performance improvements. Only builds for gfx950, and requires M and N to be divisible by 256.

Adds the HipKittens header library as a submodule.

Comment on lines +173 to +175
static size_t align256(size_t x) {
return (x + 255) & ~(size_t)255;
}
Collaborator

nit: clever idea. Consider consolidating with inline size_t round_up_to_nearest_multiple(const size_t N, const size_t M)?

Contributor Author

Good idea; I have consolidated them, since we're not performance-sensitive here.
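
For reference, a minimal sketch of the consolidated helper; only the signature comes from the suggestion above, and the body here is an assumed implementation:

#include <cstddef>

// Round N up to the nearest multiple of M; align256(x) is the M == 256 case.
inline size_t round_up_to_nearest_multiple(const size_t N, const size_t M) {
    return ((N + M - 1) / M) * M;
}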

Comment on lines +178 to +179
size_t k_iters = k / 128;
size_t scale_k = k / 32;
Collaborator

Where do those 128/32 come from? HipKittens kernel implementation details?

Contributor Author

The 128 (K tile) comes from the HipKittens kernel; 32 is the MXFP8 block size. I have added constexpr vars to help identify them.
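
A sketch of what that naming could look like; the constant and function names below are illustrative, not necessarily the ones committed:

#include <cstddef>

constexpr size_t kHipKittensKTile = 128;  // K tile size used by the HipKittens kernel
constexpr size_t kMXFP8BlockSize  = 32;   // MXFP8 scaling block size along K

// Self-contained counterpart of the quoted lines +178 to +179.
inline void compute_k_splits(size_t k, size_t &k_iters, size_t &scale_k) {
    k_iters = k / kHipKittensKTile;
    scale_k = k / kMXFP8BlockSize;
}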

Comment on lines +375 to +377
if (!params.force_hipblaslt && (params.m % 256 || params.n % 256 || params.k < 256)) {
GTEST_SKIP() << "HipKittens requires (M%256, N%256, K>=256)";
}
Collaborator

Hardcoding 256 -> some macro or constexpr with more readable names?

Contributor Author

I think the skip messages already document what the 256 refers to. I think we should keep the numbers as is for consistency with the hipblaslt skip condition.

workspace_size = 67108864;
}
if (use_mxfp8) {
workspace_size = compute_mxfp8_workspace_size(params.m, params.k, params.n,
Collaborator

Will this workspace for HipKittens be enough for the hipblaslt flow?

Contributor Author

Yes. Note that in compute_mxfp8_workspace_size the final thing we do is return the max of the needed size and workspace_size, so we will always have at least 64 MiB for gfx950.
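
A minimal sketch of that final clamp, with the per-buffer accounting elided (the buffer terms of the real function are not reproduced here):

#include <algorithm>
#include <cstddef>

// Illustrative only: compute_mxfp8_workspace_size sums the transposed data and
// scale buffers HipKittens needs, then clamps to the existing workspace_size
// (64 MiB on gfx950), so the result also covers the hipblaslt flow.
size_t mxfp8_workspace_bytes_sketch(size_t needed_by_hipkittens, size_t workspace_size) {
    return std::max(needed_by_hipkittens, workspace_size);
}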

Comment thread: tests/jax/utils.py
)
if use_bias:
pytest.skip("hipblaslt GEMM does not yet support MXFP8 with bias.")
hipkittens_eligible = (m % 256 == 0) and (n % 256 == 0) and (k >= 256)
Collaborator

Same hardcoded 256s...

size_t mismatch_limit = use_mxfp8 ? std::max((size_t)1, params.m * params.n / 1'000'000) : 0;
RefD.to_cpu();
compareResults("D", D, RefD.rowwise_cpu_dptr<D_Type>(), true, atol, rtol);
compareResults("D", D, RefD.rowwise_cpu_dptr<D_Type>(), true, atol, rtol, true, mismatch_limit);
Collaborator

Our output is mxfp8?

Contributor Author

The output is not mxfp8. Changing the seeds of the tests caused a few small mismatches due to the extra noise from MXFP8. This occurred for hipBLASlt, not hipKittens (although I assume there are seeds where this is the other way around).

Comment on lines +666 to +670
mxfp8_data_transpose<<<grid_tr, 1024, 0, stream>>>(
(const uint8_t *)A, a_tr, K, M);

dim3 grid_sc((M + 31) / 32, (scale_K + 31) / 32);
transpose_mxfp8_scales<<<grid_sc, 256, 0, stream>>>(
Collaborator

nit: naming inconsistent, mxfp8_data_transpose vs transpose_mxfp8_scales

Contributor Author

Done!

launch_pack_scales((const uint8_t *)scale_B, packed_sb, N, scale_K, k_iters, stream);

GemmEpilogue ep = select_epilogue(bias, aux_gelu);
dispatch_tn_gemm(ep, a_fp8_code, b_fp8_code,
Collaborator

Hmm, if you convert the other layouts (NN, NT) to TN and then launch in the TN layout, why not request TN directly from TE upstream during the cast-transpose and modify the canonicalize-gemm function?

Also, I thought that for MXFP8, transposing the quantized data and scales does not give us a rowwise -> columnwise conversion, right?

Contributor Author

I think your first question is answered by your second comment: we can't swap in rowwise for columnwise, so we can't request TN directly from TE upstream.

To your second question: we're not swapping rowwise for columnwise here, or the other way around. We are transposing the relevant columnwise or rowwise data and scales after quantization.
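
For illustration, a simplified sketch of what transposing the scales amounts to; this is not the PR's transpose_mxfp8_scales kernel, and the thread mapping here is arbitrary:

#include <hip/hip_runtime.h>
#include <cstdint>

// Illustrative only: the E8M0 scale matrix is a plain byte matrix, so the
// transpose is an ordinary 2D byte transpose of already-quantized values;
// nothing is requantized in the process.
__global__ void transpose_scales_sketch(const uint8_t *in, uint8_t *out,
                                        int rows, int cols) {
    int c = blockIdx.x * blockDim.x + threadIdx.x;
    int r = blockIdx.y * blockDim.y + threadIdx.y;
    if (r < rows && c < cols) {
        out[c * rows + r] = in[r * cols + c];
    }
}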

Comment on lines +107 to +121
key = (device, ub, grouped_gemm)
ws = _workspace_cache.get(key)
if ws is None:
ws = torch.empty(get_cublas_workspace_size_bytes(), dtype=torch.uint8, device=device)
_workspace_cache[key] = ws
return ws


def check_mxfp8_workspace(device: int, needed: int) -> None:
"""Grow the workspace to required size"""
key = (device, False, False)
ws = _workspace_cache.get(key)
if ws is not None and ws.shape[0] >= needed:
return
_workspace_cache[key] = torch.empty(needed, dtype=torch.uint8, device=device)
Collaborator

I have concerns about the proposed workspace cache system:
1) In non-MoE runs, it will allocate the largest size kitten_gemm needs, replacing previously allocated smaller buffers and relying on PyTorch garbage collection to deallocate them; the biggest single buffer will then stay in the process from the second iteration on.
2) For MoE runs, sizes are dynamic, so the cache contents can probably still change after the warm-up runs.

If we can force TE upstream to always provide the TN layout, could we remove this dynamic workspace entirely?

Contributor Author

I understand your concern, but I think we are OK for current models.

1) This is correct: we only keep the largest workspace, relying on PyTorch GC to delete the old one. This only affects iteration 1.
2) Since the workspace is shared by all GEMMs in the model, I think this is unlikely. For example, with DeepSeek 671B at BS=2, the largest non-MoE workspace needed is for the dense layers' FFN, where the wgrad GEMM needs 200 MB compared to a theoretical maximum MoE GEMM size of 72 MB, so this wouldn't occur. For a full MoE model like Qwen 235B, we still don't run into this issue, as the largest non-MoE GEMM would use 96 MB vs. a 44 MB worst case for MoE.

It is possible that a model exists, or could exist, where the MoE GEMM is the largest, but convergence theory would imply that we hit the maximum allocation threshold fairly quickly with a many-layer model, and it almost certainly wouldn't affect the performance of a full training run.



@functools.lru_cache(maxsize=None)
def _hipkittens_workspace_bytes(m: int, n: int, k: int, layout: str) -> int:
Collaborator

Also, you could make this a common API in C, then let both the JAX and PyTorch extensions call it to determine the workspace buffer size.

Contributor Author

Good idea. I have added it to our kittens directory, with a pybind hook for JAX and pytorch.
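
A sketch of what such a shared entry point plus pybind hook could look like; the module name, function name, and size formula below are illustrative, not the PR's actual API:

#include <pybind11/pybind11.h>
#include <algorithm>
#include <cstddef>

// Illustrative only: one C++ function owns the workspace-size computation and is
// exposed to both the JAX and PyTorch extensions through the same binding.
size_t hipkittens_mxfp8_workspace_bytes(size_t m, size_t n, size_t k) {
    const size_t default_ws = 64ull << 20;        // 64 MiB gfx950 default
    const size_t needed = m * k + n * (k / 32);   // placeholder formula
    return std::max(needed, default_ws);
}

PYBIND11_MODULE(hipkittens_workspace, mod) {
    mod.def("mxfp8_workspace_bytes", &hipkittens_mxfp8_workspace_bytes,
            "Workspace bytes needed by the HipKittens MXFP8 GEMM");
}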

@alextmagro alextmagro requested a review from wangye805 May 5, 2026 20:26
kittens::zero(cA); kittens::zero(cB); kittens::zero(cC); kittens::zero(cD);

const int NUM_XCDS = 8;
const int WGM = 8;
Contributor

Just wondering: was WGM = 8 tuned for this kernel?

Comment on lines +306 to +341
if constexpr (HAS_BIAS) {
int m_base_lo = block_m + warp_m * REG_M;
int m_base_hi = block_m + (WARPS_ROW + warp_m) * REG_M;
int lane = kittens::laneid();
int row_off = cA.base_tile_stride * (lane / cA.base_tile_cols);

#pragma unroll
for (int i = 0; i < cA.height; i++) {
#pragma unroll
for (int j = 0; j < cA.width; j++) {
#pragma unroll
for (int k = 0; k < cA.base_tile_num_strides; k++) {
#pragma unroll
for (int l = 0; l < cA.base_tile_stride / 2; l++) {
int idx = l + k * cA.base_tile_stride / 2;
int m_lo_x = m_base_lo + i * 16 + row_off + l * 2;
int m_lo_y = m_lo_x + 1;
int m_hi_x = m_base_hi + i * 16 + row_off + l * 2;
int m_hi_y = m_hi_x + 1;
float b_lo_x = read_bias(bias, bias_dtype, m_lo_x);
float b_lo_y = read_bias(bias, bias_dtype, m_lo_y);
float b_hi_x = read_bias(bias, bias_dtype, m_hi_x);
float b_hi_y = read_bias(bias, bias_dtype, m_hi_y);
cA.tiles[i][j].data[idx].x += b_lo_x;
cA.tiles[i][j].data[idx].y += b_lo_y;
cB.tiles[i][j].data[idx].x += b_lo_x;
cB.tiles[i][j].data[idx].y += b_lo_y;
cC.tiles[i][j].data[idx].x += b_hi_x;
cC.tiles[i][j].data[idx].y += b_hi_y;
cD.tiles[i][j].data[idx].x += b_hi_x;
cD.tiles[i][j].data[idx].y += b_hi_y;
}
}
}
}
}
Contributor

Minor comment: would it make sense to factor the bias-add block into an inline helper? It might make the fused epilogue logic easier to scan.
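
For illustration, a minimal sketch of such a helper; the PR's register-tile types and read_bias are not reproduced, so the accumulator type here is a stand-in:

// Illustrative only: one inline device helper applies an x/y bias pair to an
// accumulator, so each inner-loop iteration in the epilogue becomes four short
// calls (cA/cB take the low-row bias, cC/cD the high-row bias).
template <typename Pair>
__device__ __forceinline__ void add_bias_pair(Pair &acc, float bx, float by) {
    acc.x += bx;
    acc.y += by;
}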

Comment on lines +811 to +821
} else if (!transa && !transb) {
return mxfp8_gemm_nn(A, B, C, scale_A, scale_B, M, N, K,
a_fp8, b_fp8, bias, bias_dc,
aux_gelu, out_dc, aux_dc,
workspace, workspace_size, stream);
} else if (!transa && transb) {
return mxfp8_gemm_nt(A, B, C, scale_A, scale_B, M, N, K,
a_fp8, b_fp8, bias, bias_dc,
aux_gelu, out_dc, aux_dc,
workspace, workspace_size, stream);
}
Contributor

Was the transpose-to-TN strategy benchmarked against native NN/NT handling, assuming native NN/NT is available? I’m wondering about the overhead from the data and scale transposes for realistic problem sizes.

Contributor

Also, do we expect to ever need support for the TT case?
