[WIP] TDM porting #558
Draft
wangye805 wants to merge 48 commits into npi_gfx1250 from
Conversation
wangye805 (Collaborator, Author) commented on Apr 22, 2026
As I mentioned, also port TDM to NV upstream's flow, which was previously guarded out because we don't have TMA.
wangye805 (Collaborator, Author) left a comment on Apr 23, 2026
Please forget about the previous ROCm-specific cast-transpose kernel logic. Here I want you to closely follow NV upstream's behavior and do a TDM port as their TMA equivalent.
wangye805 commented on Apr 23, 2026
…ory comments

- Remove input_act_stride/output_stride as kernel params in gated kernels; compute them inside the kernel from cols and the IS_DGATED template param
- Add comments explaining why TDM does not need in_transaction_size (uses s_wait_tensorcnt counting ops, not mbarrier counting bytes)

Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
wangye805 commented on Apr 23, 2026
…ests
- Fix `) {` placement to minimize diff in gated kernel signatures
- Fix MXFP8 gated kernel: remove unnecessary pre-loop wait, make
in-loop wait conditional to preserve double-buffering prefetch
- Add comments explaining TDM does not need mbarrier destroy
- Add NVTE_USE_NV_UPSTREAM_FLOW=1 ctest run in ci/core.sh to exercise
TDM kernel paths for MXFP8 quantize, gated, and dequantize
Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
When a double-buffered prefetch tile origin falls past the tensor boundary (non-tile-aligned rows/cols), tensor_h - tile_row and tensor_w - tile_col would underflow as uint32_t to ~4 billion, causing the TDM hardware to attempt a DMA of billions of rows and trigger a GPU page fault. Clamp the remaining extent to 0 when tile_row >= tensor_h or tile_col >= tensor_w. Unlike NV TMA (which encodes full tensor shape in a host-side CUtensorMap and clamps automatically), TDM computes the remaining extent per-call, so the caller must guard against out-of-bounds origins. Co-Authored-By: Claude Sonnet 4 <noreply@anthropic.com>
Introduce HIPTensorMap/HIPTensorMapOut structs in tdm.cuh as the AMD analog of CUtensorMap. Callers in cast_kernels.cuh and cast_gated_kernels.cuh now construct one descriptor per tensor at kernel entry and pass it to TDM helper calls instead of repeating 6+ raw scalars at every call site.

Revert TDM usage in rocm_cast_kernels.cuh, rocm_cast_gated_kernels.cuh, and rocm_dequantize_kernels.cuh back to the original HIP vectorized copy_2d_to_shared / bulk_tensor_2d_shared_to_global path. The rocm_* kernels are the legacy non-TDM path; TDM is used only in the NV-upstream ported kernels (cast_kernels.cuh / cast_gated_kernels.cuh) behind NVTE_USE_NV_UPSTREAM_FLOW.

Co-Authored-By: Claude Sonnet 4 <noreply@anthropic.com>
Remove the tdm.cuh include, TDM_SHMEM_ALIGNMENT usage, and any whitespace changes introduced in the previous commit, so rocm_cast_kernels.cuh, rocm_cast_gated_kernels.cuh, and rocm_dequantize_kernels.cuh are byte-for-byte identical to 5e8d61e (Ilya's branch point). Co-Authored-By: Claude Sonnet 4 <noreply@anthropic.com>
…ain kernel functions

On AMD, each mxfp8 quantize/dequantize/gated function previously dispatched between TDM and ROCm kernels via an inline env-var check. This refactor separates the two flows cleanly:

- cast_gated_kernels.cuh / rocm_cast_gated_kernels.cuh: rocm_cast_mxfp8_gated() hosts the ROCm HIP gated kernel dispatch. cast_mxfp8_gated() is now TDM-only on AMD. quantize_gated() dispatches via the NVTE_USE_NV_UPSTREAM_FLOW env var.
- cast_kernels.cuh / rocm_cast_kernels.cuh: rocm_mxfp8_quantize() hosts the ROCm HIP cast kernel dispatch. mxfp8_quantize() is now TDM-only on AMD. fp8_quantize_rocm() dispatches via the NVTE_USE_NV_UPSTREAM_FLOW env var.
- dequantize_kernels.cuh / rocm_dequantize_kernels.cuh: rocm_mxfp8_dequantize() hosts the ROCm HIP dequantize dispatch. mxfp8_dequantize() is now TDM-only on AMD. dequantize_helper() dispatches via the NVTE_USE_NV_UPSTREAM_FLOW env var.

The NV upstream path (non-AMD) is unchanged throughout.

Co-Authored-By: Claude Sonnet 4 <noreply@anthropic.com>
wangye805 commented on Apr 26, 2026
- Rename namespace nv_flow -> tma_flow (more accurate: both TMA on NV
and TDM on AMD use this path)
- Rename env-var NVTE_USE_NV_UPSTREAM_FLOW -> NVTE_USE_TDM_FLOW with
inverted default: 0 = ROCm flow (default), 1 = TDM flow
- Apply same env-var dispatch to fp8 gated path (was missing)
- Remove dead AMD-specific guards around ScalingType, BUFF_DIM,
blocks, THREADS_PER_CHUNK, grid, block_size in cast_mxfp8_gated
- Remove AMD-specific {} wrapper and duplicate shmem computation block;
TMA_SHMEM_ALIGNMENT == TDM_SHMEM_ALIGNMENT == 128 so NV upstream
formula works on both platforms
Co-Authored-By: Claude Sonnet 4 <noreply@anthropic.com>
Move swizzled_group_idx/swizzled_idx/shmem_offset_rowwise back to just before out_act.store_to(), matching NV upstream cast_gated_kernels.cuh lines 831-834, to minimize diff. Co-Authored-By: Claude Sonnet 4 <noreply@anthropic.com>
wangye805 commented on Apr 26, 2026
Three differences found vs NV upstream (line 968):
1. out_gate_mem: AMD TDM kernel always needs a gate shmem buffer
regardless of IS_DGATED (kernel signature always includes gate
output pointers), so restore AMD-specific:
out_gate_mem = buff_size_aligned_out (always)
vs NV:
out_gate_mem = IS_DGATED ? buff_size_aligned_out : 0
2. in_mem: split into in_act_mem + in_gate_mem intermediate vars
to match NV upstream style exactly.
3. AMD TDM dispatch: restore TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH
dispatch (was accidentally dropped when removing the {} wrapper),
guarded under #ifdef __HIP_PLATFORM_AMD__. NV uses switch(scaling_type).
Co-Authored-By: Claude Sonnet 4 <noreply@anthropic.com>
Replace TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH with the same switch(scaling_type) structure as NV upstream. The TDM kernel shares the same ROWWISE_SCALING/COLWISE_SCALING/THREADS_PER_CHUNK template params as the NV kernel — SCALE_DIM_Y/X/IS_ALIGNED were ROCm-flow params that don't apply here. Co-Authored-By: Claude Sonnet 4 <noreply@anthropic.com>
…ated_kernel

next_buff, next_stage_offset_Y, global_offset_Y, global_offset_X, next_buff_offset are identical in both the TMA and TDM branches — declare them once above the #ifndef __HIP_PLATFORM_AMD__ guard.

Co-Authored-By: Claude Sonnet 4 <noreply@anthropic.com>
The shmem size calculation is identical for TDM and TMA paths (TDM_SHMEM_ALIGNMENT == TMA_SHMEM_ALIGNMENT == 128), so declare it once above the #ifdef __HIP_PLATFORM_AMD__ guard. Only the pointer setup and kernel launch remain platform-specific. Co-Authored-By: Claude Sonnet 4 <noreply@anthropic.com>
Both AMD and NV call mxfp8_kernel::cast_mxfp8_gated_kernel — the TDM kernel at line 435 is also inside namespace mxfp8_kernel. The two switch blocks were identical except for the namespace qualifier, so remove the #ifdef and keep one unified switch block. Co-Authored-By: Claude Sonnet 4 <noreply@anthropic.com>
…tion

- common.h: add TMA_SHMEM_ALIGNMENT as an alias for TDM_SHMEM_ALIGNMENT in the AMD block so cast_gated_kernels.cuh launcher code compiles without ifdefs
- rocm_cast_gated_kernels.cuh: define sigmoidf as device inline since the HIP runtime does not provide it (mirrors the CUDA definition in cast_gated_kernels.cuh)

Co-Authored-By: Claude Sonnet 4 <noreply@anthropic.com>
…e_arch_ge_100 call

fp8_quantize_arch_ge_100 is guarded by #ifndef __HIP_PLATFORM_AMD__ (NV TMA only). The AMD branch should delegate entirely to fp8_quantize_rocm, which internally dispatches to mxfp8_quantize (TDM path) or rocm_mxfp8_quantize based on NVTE_USE_TDM_FLOW. Also rename NVTE_USE_NV_UPSTREAM_FLOW to NVTE_USE_TDM_FLOW in rocm_cast_kernels.cuh to match the unified env var.
The AMD section of mxfp8_quantize only sets up raw pointers and never launches a kernel, so calling it directly from quantize_helper left the scale buffer zero-initialized. fp8_quantize_rocm already has the correct TDM/plain-ROCm dispatch logic; route AMD through it instead. Fixes 1110 FusedCastMXFP8TestSuite failures on gfx950 (NVTE_USE_TDM_FLOW=0).
The rowwise scale tensor is allocated with stride padded to scale_tensor_alignment_X_rowwise (4), but rocm_mxfp8_dequantize was computing scales_stride = DIVUP(cols, 32) (unpadded). From row 1 onward the kernel reads the wrong scale, producing inf/garbage output. Fix: use DIVUP_TO_MULTIPLE(..., scale_tensor_alignment_X_rowwise), matching the allocation in the test harness and the NV dequantize path. Fixes 6 DequantizeMXFP8TestSuite failures (65x96, block_size=(1,32)) on gfx950.
The NVTE_USE_TDM_FLOW=1 branches in rocm_cast_kernels.cuh, cast_gated_kernels.cuh, and dequantize_kernels.cuh called TDM/TMA kernel paths (mxfp8_quantize, cast_mxfp8_gated, mxfp8_dequantize) that are no-ops on non-gfx1250 AMD — their device code is wrapped in #if defined(__gfx1250__), so nothing executes, leaving scales at zero. Wrap the TDM flow selection in #if defined(__HIP_PLATFORM_AMD__) && defined(__gfx1250__), falling back to the plain ROCm kernels (rocm_mxfp8_quantize, rocm_cast_mxfp8_gated, rocm_mxfp8_dequantize) on all other AMD architectures. With this fix, all 2748 tests pass with NVTE_USE_TDM_FLOW=1 on gfx950.
The AMD section of mxfp8_quantize set up raw pointers but never launched the kernel — all TDM quantize calls were silent no-ops. Add the kernel launch switch (ROWWISE/COLWISE/BIDIMENSIONAL) mirroring the NV path, using raw pointers and TDM shared-memory sizing. Also fix host-side TDM dispatch guards: replace device-only __gfx1250__ with CMake-injected NVTE_ARCH_HAS_TDM (visible to host compilation) plus a runtime cuda::sm_arch_name() check, matching the ARCH_HAS_STOCHASTIC_ROUNDING pattern from PR #472. This ensures gfx942/950-only builds compile cleanly and multi-arch builds running on non-gfx1250 hardware fall back to the ROCm path even when NVTE_USE_TDM_FLOW=1. Add debug fprintf/printf traces across all dispatch and kernel entry points to confirm which code path executes at runtime. Co-Authored-By: Claude Sonnet 4 <noreply@anthropic.com>
The switch-case for cast_mxfp8_2D_kernel is identical on AMD (TDM) and NV (TMA) — only the first four args differ (raw pointers vs CUtensorMap). Move the shared dshmem sizing and switch-case after the #ifdef block so there is a single launch path. The #ifdef now only covers platform-specific setup: raw pointer casts on AMD, create_2D_tensor_map descriptors on NV. TMA_SHMEM_ALIGNMENT is aliased to TDM_SHMEM_ALIGNMENT (both 128) so the shmem calculation is correct on both platforms without a separate formula. Co-Authored-By: Claude Sonnet 4 <noreply@anthropic.com>
…or TDM path

- Move `using namespace mxfp8_kernel` outside `#ifndef __HIP_PLATFORM_AMD__` so tiling constants (CHUNK_DIM_Y/X, SCALE_DIM_X, BUFFS_NUM) are in scope on AMD
- Guard all three `cudaFuncSetAttribute` calls with `#ifndef __HIP_PLATFORM_AMD__` since HIP cannot take the address of a templated kernel function the same way; the dynamic shmem size is still correctly passed via <<<grid, block, dshmem, stream>>>
- Add `__device__ __forceinline__` overloads of `__habs` and `__hmax` for `hip_bfloat16` (TE's bf16 alias) because ROCm only defines them for `__hip_bfloat16`, a distinct type on this ROCm version

Co-Authored-By: Claude Sonnet 4 <noreply@anthropic.com>
Co-Authored-By: Claude Sonnet 4 <noreply@anthropic.com>
…atements

Co-Authored-By: Claude Sonnet 4 <noreply@anthropic.com>
Force-pushed from 6433819 to 573f6ea
Newer ROCm clang (required for gfx1250) classifies __COUNTER__ as a C2y extension and emits -Wc2y-extensions. Combined with benchmark's own -pedantic-errors -Werror flags this causes a build failure. Also suppress -Wunused-const-variable for benchmark.h compiled as a standalone TU. Also fetch googletest v1.14.0 via FetchContent since test_common.hip (included via benchmark_utils.h) depends on gtest/gtest.h.
Force-pushed from 9bce604 to 9f340a1
wangye805 commented on Apr 30, 2026
Comment on lines +560 to +567
```cpp
mxfp8_quantize<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP>(input, act_input, noop, output,
dbias, workspace, stream);
} else {
fprintf(stderr, "[DBG fp8_quantize_rocm] gfx1250 ROCm branch -> rocm_mxfp8_quantize\n");
rocm_mxfp8_quantize<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP>(input, act_input, noop, output,
rocm_mxfp8_quantize<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP>(input, act_input, noop, output,
dbias, workspace, stream);
}
#else
fprintf(stderr, "[DBG fp8_quantize_rocm] non-gfx1250 AMD -> rocm_mxfp8_quantize\n");
rocm_mxfp8_quantize<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP>(input, act_input, noop, output,
rocm_mxfp8_quantize<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP>(input, act_input, noop, output,
```
wangye805 (Collaborator, Author): Those indents are not correct.

wangye805 (Collaborator, Author): Fixed in 1afc5b2 — restored proper indentation.
Co-Authored-By: Claude Sonnet 4 <noreply@anthropic.com>
…verlap wait_tensorcnt_0() after a TDM store was draining the prefetched next-iteration load as well (TENSORcnt is a unified counter for loads and stores), eliminating any compute/transfer overlap. Use wait_tensorcnt_1() to drain only the store while keeping the in-flight prefetch alive, matching the intent of the double-buffer pipeline. The last iteration has no prefetch so still drains to 0. cast_gated_kernels.cuh already uses the correct pattern (stores are drained at the top of the next iteration rather than immediately), so no change needed. Co-Authored-By: Claude Sonnet 4 <noreply@anthropic.com>
Add a template wait_tensorcnt<N>() to tdm.cuh so callers express the
in-flight count in terms of their own PREFETCH_BUFFERS_NUM constant
rather than a hardcoded literal. Existing wait_tensorcnt_{0..4}() helpers
are kept as thin wrappers for call sites that don't need parameterization.
Also add PREFETCH_BUFFERS_NUM = 1 to dequantize_kernels.cuh (with a
static_assert against BUFFERS_NUM) to match the pattern in cast_kernels.cuh,
so the store-wait count is tied to a named constant in both files.
Co-Authored-By: Claude Sonnet 4 <noreply@anthropic.com>
Replace hardcoded if constexpr (IS_DGATED) { wait_tensorcnt_3() } else
{ wait_tensorcnt_2() } with a named constexpr and the template helper in
both the FP8 and MX kernel loops. The count (3 for IS_DGATED, 2 otherwise)
is now self-documenting and will stay correct if the prefetch load count
ever changes.
Co-Authored-By: Claude Sonnet 4 <noreply@anthropic.com>
TDM instructions are wave-level (EXEC mask ignored by hardware), so issuing from thread 0 alone is sufficient. Rename the guard to is_tdm_lane() to reflect that it selects a single lane, not an entire wave. Co-Authored-By: Claude Sonnet 4 <noreply@anthropic.com>
Description
Do not review; I'm just trying to play with Claude for the iterations. Tracking https://github.com/ROCm/frameworks-internal/issues/16226
Fixes # (issue)
Type of change
Changes
Initial TDM porting
Checklist: