Fused Adam Support for MXFP8 + FSDP2 integration #2780
vthumbe1503 wants to merge 8 commits into NVIDIA:main from fused_adam_for_mxfp8
Conversation
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

for more information, see https://pre-commit.ci

/te-ci L1 pytorch
Greptile Summary

This PR adds a fused Adam optimizer kernel for MXFP8 (MX block-scaling FP8) model weights, integrating it into the existing FusedAdam optimizer. Key changes:
Issues found:
Confidence Score: 3/5
Important Files Changed
Sequence Diagram

```mermaid
sequenceDiagram
    participant PY as FusedAdam.step() (Python)
    participant EXT as adam.cpp (PyTorch ext)
    participant CU as adam.cu (CUDA)
    participant APPLY as multi_tensor_apply_mxfp8
    participant KERNEL as adam_mxfp8_fused_kernel
    PY->>PY: per-param loop — classify params<br/>(Float8/MXFP8/F16/F32)
    PY->>PY: accumulate into p_mxfp8_rowwise,<br/>p_mxfp8_colwise, moments, master_param
    PY->>EXT: multi_tensor_adam_mxfp8(chunk_size,<br/>noop_flag, 8 tensor lists, …, fp8_dtype)
    EXT->>EXT: makeTransformerEngineTensorList()<br/>validate num_lists == 8
    EXT->>CU: nvte_multi_tensor_adam_mxfp8_cuda(…)
    CU->>CU: compute_bias_correction()<br/>check_tensor_list_sizes()<br/>dtype validation
    CU->>APPLY: multi_tensor_apply_mxfp8<kernel>(…)
    loop For each tensor (batched ≤ MXFP8_MAX_TENSORS=24 tensors,<br/>≤ MXFP8_MAX_BLOCKS=320 blocks per launch)
        APPLY->>APPLY: build MXFP8TensorListMetadata<br/>(block_to_tensor, block_to_tile, rows, cols)
        APPLY->>KERNEL: Kernel<<<blocks, 256>>>(tl, β1, β2, ε, lr, …)
        KERNEL->>KERNEL: Stage 4: Adam update → p/m/v (FP32)
        KERNEL->>KERNEL: Stage 5: atomicMaxFloat → row/col amax (shared mem)
        KERNEL->>KERNEL: Stage 6: write rowwise & colwise scale-inv (e8m0)
        KERNEL->>KERNEL: Stage 7: quantise p → MXFP8 rowwise + colwise data
    end
    CU-->>PY: return (master params, moments, MXFP8 data, scales updated in-place)
```
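For reference, Stage 4 in the diagram is the standard FP32 Adam update with precomputed bias corrections. A scalar sketch of the textbook formulas (this mirrors the math, not TE's exact kernel code; the function name and signature are illustrative):

```cpp
#include <cmath>

// Hedged, scalar sketch of Stage 4: the standard Adam update in FP32.
// bias_corr1 = 1 - beta1^t and bias_corr2 = 1 - beta2^t are assumed to be
// computed on the host (compute_bias_correction() in the diagram).
void adam_update(float &p, float &m, float &v, float g,
                 float beta1, float beta2, float eps, float lr,
                 float bias_corr1, float bias_corr2) {
  m = beta1 * m + (1.0f - beta1) * g;      // first-moment EMA
  v = beta2 * v + (1.0f - beta2) * g * g;  // second-moment EMA
  const float m_hat = m / bias_corr1;      // bias-corrected moments
  const float v_hat = v / bias_corr2;
  p -= lr * m_hat / (std::sqrt(v_hat) + eps);
}
```

In the fused kernel, this update runs in FP32 on the master parameters before Stages 5–7 quantise the result to MXFP8.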
Last reviewed commit: "address review comme..."
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>
Needs more perf tuning for MXFP8; will mark the PR active once the desired perf is achieved.
```cpp
void multi_tensor_apply_mxfp8(int64_t chunk_size, const transformer_engine::Tensor &noop_flag,
                              std::vector<std::vector<transformer_engine::Tensor *>> tensor_lists,
                              uint8_t fp8_dtype, cudaStream_t stream, ArgTypes... args) {
  constexpr size_t kNumTensorLists = 8;
```
Let's move this constant above `struct MXFP8TensorListMetadata` so we can use it to size the array. That way we can replace `void* addresses[8][MXFP8_MAX_TENSORS];` with `void* addresses[kNumTensorLists][MXFP8_MAX_TENSORS];`.
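A minimal sketch of the suggested reordering (the struct's other members are omitted, and the constant values are taken from the PR context):

```cpp
#include <cstddef>

// Hoist the named constant above the metadata struct so the address table
// is sized by kNumTensorLists instead of a bare 8.
constexpr size_t kNumTensorLists = 8;
constexpr size_t MXFP8_MAX_TENSORS = 24;

struct MXFP8TensorListMetadata {
  // was: void *addresses[8][MXFP8_MAX_TENSORS];
  void *addresses[kNumTensorLists][MXFP8_MAX_TENSORS];
  // ... other fields (block_to_tensor, block_to_tile, rows, cols) omitted
};
```

This keeps the struct layout identical while making the "8 tensor lists" invariant a single named constant.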
```cpp
                              uint8_t fp8_dtype, cudaStream_t stream, ArgTypes... args) {
  constexpr size_t kNumTensorLists = 8;
  NVTE_CHECK(tensor_lists.size() == kNumTensorLists,
             "Expected 8 tensor lists for MXFP8, but found ", tensor_lists.size());
```
Here the tensor-list count is hard-coded as 8 in the error message. Let's use `kNumTensorLists` there instead.
```cpp
  const ::transformer_engine::e8m0_t row_biased =
      reinterpret_cast<const ::transformer_engine::e8m0_t &>(row_raw);
  const float row_scale_inv = transformer_engine::ptx::exp2f_rcp(row_biased);
  if (dtype == static_cast<uint8_t>(transformer_engine::DType::kFloat8E4M3)) {
```
To improve performance, it may be worth adding function template parameters (e.g., IType, OType), as we do in other kernels, to avoid runtime branching.
Oh, I see. OType may vary across tensors within a multi-tensor call, so it’s a runtime attribute, right?
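One way to reconcile both points: even if OType is only known at runtime, the branch can be taken once per launch (or per tensor chunk) rather than per element, by mapping the runtime dtype tag to a template parameter. A host-side sketch of that dispatch pattern (the enum and type names are simplified stand-ins, not TE's actual definitions):

```cpp
#include <cstdint>

// Stand-ins for the FP8 output types; real kernels would use
// __nv_fp8_e4m3 / __nv_fp8_e5m2. The pattern, not the types, is the point.
struct fp8e4m3 {};
struct fp8e5m2 {};

enum class DType : uint8_t { kFloat8E4M3, kFloat8E5M2 };

// Example templated helper: FP8 max-norm reciprocals. 448 and 57344 are
// the standard max finite values of E4M3 and E5M2 respectively.
template <typename OType> float max_norm_rcp();
template <> float max_norm_rcp<fp8e4m3>() { return 1.0f / 448.0f; }
template <> float max_norm_rcp<fp8e5m2>() { return 1.0f / 57344.0f; }

// Map the runtime dtype to a compile-time type once, then invoke a
// generic callable with that type; inner-loop branching disappears.
template <typename F>
auto dispatch_fp8(DType dtype, F &&f) {
  switch (dtype) {
    case DType::kFloat8E4M3:
      return f(fp8e4m3{});
    default:
      return f(fp8e5m2{});
  }
}
```

With this shape, the Adam kernel body is instantiated per OType, and the runtime `if (dtype == ...)` moves out of the hot loop.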
```cpp
  return static_cast<FP8_T>(x);
}

__device__ __forceinline__ float fp8_max_norm_rcp(uint8_t fp8_dtype) {
```
Additionally, if we add OType, we can drop this helper and use `transformer_engine::Quantized_Limits<OType>::max_norm_rcp` directly.
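A simplified stand-in for the `Quantized_Limits<OType>` traits idea referenced here (the struct name mirrors TE's, but the definition below is illustrative): with OType as a template parameter, the per-dtype reciprocal becomes a compile-time constant instead of a runtime branch on `fp8_dtype`.

```cpp
// Illustrative FP8 type tags; real code would use the CUDA FP8 types.
struct fp8e4m3 {};
struct fp8e5m2 {};

// Traits sketch: 448 and 57344 are the standard E4M3/E5M2 max finite
// values, so their reciprocals are known at compile time per OType.
template <typename OType> struct Quantized_Limits;
template <> struct Quantized_Limits<fp8e4m3> {
  static constexpr float max_norm = 448.0f;
  static constexpr float max_norm_rcp = 1.0f / 448.0f;
};
template <> struct Quantized_Limits<fp8e5m2> {
  static constexpr float max_norm = 57344.0f;
  static constexpr float max_norm_rcp = 1.0f / 57344.0f;
};
```

Inside a kernel templated on OType, `Quantized_Limits<OType>::max_norm_rcp` then replaces the `fp8_max_norm_rcp(uint8_t)` helper entirely.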
Description
Please include a brief summary of the changes, relevant motivation and context.
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: