HipKittens MXFP8 GEMM Support #566
Conversation
```cpp
static size_t align256(size_t x) {
  return (x + 255) & ~(size_t)255;
}
```
nit: clever idea. Consider consolidating with `TransformerEngine/tests/cpp/test_common.h`, line 379 in 86438dc.
Good idea, I have consolidated since we're not performance sensitive here.
```cpp
size_t k_iters = k / 128;
size_t scale_k = k / 32;
```
Where do those 128/32 come from? HipKittens kernel implementation details?

The 128 comes from the HipKittens kernel; 32 is the MXFP8 block size. I have added constexpr vars to help identify that.
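For illustration, the named constants could look something like this (the names here are placeholders, not necessarily the ones added in the PR):

```cpp
// Illustrative names only: 128 is the K step the HipKittens kernel consumes per
// main-loop iteration; 32 is the MXFP8 scaling block size along K.
constexpr size_t kHipKittensKStep = 128;
constexpr size_t kMXFP8BlockSize = 32;

size_t k_iters = k / kHipKittensKStep;   // kernel main-loop iterations along K
size_t scale_k = k / kMXFP8BlockSize;    // number of per-block scales along K
```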
```cpp
if (!params.force_hipblaslt && (params.m % 256 || params.n % 256 || params.k < 256)) {
  GTEST_SKIP() << "HipKittens requires (M%256, N%256, K>=256)";
}
```
Hardcoding 256: maybe use a macro or constexpr with a more readable name?

I think the skip messages already document what the 256 refers to, and we should keep the numbers as-is for consistency with the hipBLASLt skip condition.
```cpp
  workspace_size = 67108864;
}
if (use_mxfp8) {
  workspace_size = compute_mxfp8_workspace_size(params.m, params.k, params.n,
```
Will this workspace for HipKittens be enough for the hipBLASLt flow?

Yes. Note that the final thing compute_mxfp8_workspace_size does is return the max of the needed size and workspace_size, so we will always have at least 64 MiB for gfx950.
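As a rough sketch of that sizing logic (the per-buffer accounting below is an assumption; only the final std::max clamp is what the reply above describes):

```cpp
#include <algorithm>
#include <cstddef>

// Sketch only: sum the scratch the HipKittens path may need (assumed here to be the
// transposed FP8 operands plus repacked scales, each 256-byte aligned), then never
// return less than the caller's default workspace size (64 MiB on gfx950).
static size_t align256(size_t x) { return (x + 255) & ~(size_t)255; }

size_t compute_mxfp8_workspace_size_sketch(size_t m, size_t k, size_t n,
                                           size_t default_workspace_size) {
  constexpr size_t kMXFP8BlockSize = 32;
  size_t needed = align256(m * k) + align256(n * k)                 // transposed data
                + align256(m * (k / kMXFP8BlockSize))
                + align256(n * (k / kMXFP8BlockSize));              // transposed scales
  return std::max(needed, default_workspace_size);
}
```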
```python
)
if use_bias:
    pytest.skip("hipblaslt GEMM does not yet support MXFP8 with bias.")
hipkittens_eligible = (m % 256 == 0) and (n % 256 == 0) and (k >= 256)
```
```diff
+ size_t mismatch_limit = use_mxfp8 ? std::max((size_t)1, params.m * params.n / 1'000'000) : 0;
  RefD.to_cpu();
- compareResults("D", D, RefD.rowwise_cpu_dptr<D_Type>(), true, atol, rtol);
+ compareResults("D", D, RefD.rowwise_cpu_dptr<D_Type>(), true, atol, rtol, true, mismatch_limit);
```
The output is not MXFP8. Changing the seeds of the tests caused a few small mismatches due to the extra noise from MXFP8. This occurred for hipBLASLt, not HipKittens (although I assume there are seeds where it is the other way around).
```cpp
mxfp8_data_transpose<<<grid_tr, 1024, 0, stream>>>(
    (const uint8_t *)A, a_tr, K, M);

dim3 grid_sc((M + 31) / 32, (scale_K + 31) / 32);
transpose_mxfp8_scales<<<grid_sc, 256, 0, stream>>>(
```
nit: naming is inconsistent, mxfp8_data_transpose vs transpose_mxfp8_scales.
```cpp
launch_pack_scales((const uint8_t *)scale_B, packed_sb, N, scale_K, k_iters, stream);

GemmEpilogue ep = select_epilogue(bias, aux_gelu);
dispatch_tn_gemm(ep, a_fp8_code, b_fp8_code,
```
Hmm, if you convert the other layouts (NN, NT) to TN and then launch as TN, why not request TN directly from TE upstream during the cast-transpose and modify the canonicalize-GEMM function?
Also, I thought that for MXFP8, transposing the quantized data and scales does not give us a rowwise -> columnwise conversion, right?

I think your first question is answered by your second comment: we can't swap rowwise in for colwise, so we can't request TN directly from TE upstream.
To your second question, we're not swapping rowwise in for colwise here, or the other way around. We are transposing the relevant colwise or rowwise data and scales after quantization.
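As a host-side illustration of that last point (this is not the PR's HIP kernels, just a plain C++ sketch): converting one operand to the layout the TN kernel expects means transposing both the quantized FP8 bytes and their per-32-element scales; the scales produced by quantization are kept and merely relaid out.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// CPU sketch of transposing an M x K MXFP8 operand: both the FP8 data and the
// M x (K/32) scale bytes are transposed so a TN kernel can consume them.
void transpose_mxfp8_operand_cpu(const std::vector<uint8_t> &data,    // M x K fp8 bytes
                                 const std::vector<uint8_t> &scales,  // M x (K/32) scale bytes
                                 std::vector<uint8_t> &data_tr,       // K x M
                                 std::vector<uint8_t> &scales_tr,     // (K/32) x M
                                 size_t M, size_t K) {
  const size_t scale_K = K / 32;  // one scale per 32 elements along K
  data_tr.assign(K * M, 0);
  scales_tr.assign(scale_K * M, 0);
  for (size_t m = 0; m < M; ++m)
    for (size_t kk = 0; kk < K; ++kk)
      data_tr[kk * M + m] = data[m * K + kk];
  for (size_t m = 0; m < M; ++m)
    for (size_t s = 0; s < scale_K; ++s)
      scales_tr[s * M + m] = scales[m * scale_K + s];
}
```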
```python
    key = (device, ub, grouped_gemm)
    ws = _workspace_cache.get(key)
    if ws is None:
        ws = torch.empty(get_cublas_workspace_size_bytes(), dtype=torch.uint8, device=device)
        _workspace_cache[key] = ws
    return ws


def check_mxfp8_workspace(device: int, needed: int) -> None:
    """Grow the workspace to required size"""
    key = (device, False, False)
    ws = _workspace_cache.get(key)
    if ws is not None and ws.shape[0] >= needed:
        return
    _workspace_cache[key] = torch.empty(needed, dtype=torch.uint8, device=device)
```
I have concerns about the proposed workspace cache system:
1) In non-MoE runs, it will try to allocate the largest size the kittens GEMM needs, replacing previously allocated smaller buffers and relying on PyTorch garbage collection to free them. The biggest single buffer will then stay in the process from the second iteration onward.
2) For MoE runs, sizes are dynamic, so the cache can presumably still change after the warm-up runs.
If we can force TE upstream to always provide the TN layout, could we remove this dynamic workspace entirely?
I understand your concern, but I think we are OK for current models.
1) This is correct: we only keep the largest workspace, relying on PyTorch GC to delete the old one. This only affects iteration 1.
2) Since the workspace is shared by all GEMMs in the model, I think this is unlikely. For example, with DeepSeek 671B at BS=2, the largest non-MoE workspace needed is for the dense layers' FFN, where the wgrad GEMM needs 200 MB compared to a theoretical maximum MoE GEMM size of 72 MB, so this wouldn't occur. For a full MoE model like Qwen 235B, we still don't hit this issue, as the largest non-MoE GEMM would use 96 MB vs. a 44 MB worst case for MoE.
It is possible that a model exists, or could exist, where the MoE GEMM is the largest, but convergence theory would imply that we hit the maximum allocation threshold fairly quickly with a many-layer model, and it almost certainly wouldn't affect the performance of a full training run.
```python
@functools.lru_cache(maxsize=None)
def _hipkittens_workspace_bytes(m: int, n: int, k: int, layout: str) -> int:
```
Also, you could make this a common API in C, then let both the JAX and PyTorch extensions call it to determine the workspace buffer size.
Good idea. I have added it to our kittens directory, with a pybind hook for JAX and pytorch.
```cpp
kittens::zero(cA); kittens::zero(cB); kittens::zero(cC); kittens::zero(cD);

const int NUM_XCDS = 8;
const int WGM = 8;
```
Just wondering, was WGM = 8 tuned for this kernel?
```cpp
if constexpr (HAS_BIAS) {
  int m_base_lo = block_m + warp_m * REG_M;
  int m_base_hi = block_m + (WARPS_ROW + warp_m) * REG_M;
  int lane = kittens::laneid();
  int row_off = cA.base_tile_stride * (lane / cA.base_tile_cols);

  #pragma unroll
  for (int i = 0; i < cA.height; i++) {
    #pragma unroll
    for (int j = 0; j < cA.width; j++) {
      #pragma unroll
      for (int k = 0; k < cA.base_tile_num_strides; k++) {
        #pragma unroll
        for (int l = 0; l < cA.base_tile_stride / 2; l++) {
          int idx = l + k * cA.base_tile_stride / 2;
          int m_lo_x = m_base_lo + i * 16 + row_off + l * 2;
          int m_lo_y = m_lo_x + 1;
          int m_hi_x = m_base_hi + i * 16 + row_off + l * 2;
          int m_hi_y = m_hi_x + 1;
          float b_lo_x = read_bias(bias, bias_dtype, m_lo_x);
          float b_lo_y = read_bias(bias, bias_dtype, m_lo_y);
          float b_hi_x = read_bias(bias, bias_dtype, m_hi_x);
          float b_hi_y = read_bias(bias, bias_dtype, m_hi_y);
          cA.tiles[i][j].data[idx].x += b_lo_x;
          cA.tiles[i][j].data[idx].y += b_lo_y;
          cB.tiles[i][j].data[idx].x += b_lo_x;
          cB.tiles[i][j].data[idx].y += b_lo_y;
          cC.tiles[i][j].data[idx].x += b_hi_x;
          cC.tiles[i][j].data[idx].y += b_hi_y;
          cD.tiles[i][j].data[idx].x += b_hi_x;
          cD.tiles[i][j].data[idx].y += b_hi_y;
        }
      }
    }
  }
}
```
Minor comment: would it make sense to factor the bias-add block into an inline helper? It might make the fused epilogue logic easier to scan.
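One possible shape for such a helper, as a sketch (the helper name is made up; read_bias and the float2-valued accumulator elements are assumed from the snippet above):

```cpp
#include <hip/hip_runtime.h>

// Adds the bias values for output rows m and m+1 to one packed accumulator element.
// read_bias(bias, bias_dtype, row) is the device helper used in the epilogue above.
__device__ __forceinline__ void add_bias_pair(float2 &acc, const void *bias,
                                              int bias_dtype, int m) {
  acc.x += read_bias(bias, bias_dtype, m);
  acc.y += read_bias(bias, bias_dtype, m + 1);
}
```

The inner loop body would then reduce to four add_bias_pair calls, one per accumulator tile (cA/cB with the low row base, cC/cD with the high row base).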
```cpp
} else if (!transa && !transb) {
  return mxfp8_gemm_nn(A, B, C, scale_A, scale_B, M, N, K,
                       a_fp8, b_fp8, bias, bias_dc,
                       aux_gelu, out_dc, aux_dc,
                       workspace, workspace_size, stream);
} else if (!transa && transb) {
  return mxfp8_gemm_nt(A, B, C, scale_A, scale_B, M, N, K,
                       a_fp8, b_fp8, bias, bias_dc,
                       aux_gelu, out_dc, aux_dc,
                       workspace, workspace_size, stream);
}
```
Was the transpose-to-TN strategy benchmarked against native NN/NT handling, assuming native NN/NT is available? I’m wondering about the overhead from the data and scale transposes for realistic problem sizes.
Also, do we expect to ever need support for the TT case?
Creates an MXFP8 GEMM with HipKittens that outperforms hipBLASLt and offers additional epilogues such as BIAS and GELU AUX.
Requires a workspace sized relative to the model, often larger than hipBLASLt's, but with significant performance improvements. Only builds for gfx950, and requires M and N to be multiples of 256.
Adds the HipKittens header library as a submodule.
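For reference, a minimal sketch of the eligibility condition implied by those constraints (the function name is illustrative); shapes that fail it fall back to hipBLASLt:

```cpp
#include <cstddef>

// Mirrors the test skip condition: M and N must be multiples of 256 and K >= 256
// for the HipKittens MXFP8 path to be used.
bool hipkittens_mxfp8_eligible(size_t m, size_t n, size_t k) {
  return (m % 256 == 0) && (n % 256 == 0) && (k >= 256);
}
```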