
[Common, pyTorch] Grouped MXFP8 dequantize support #2722

Open
ptrendx wants to merge 14 commits into NVIDIA:main from ptrendx:pr_grouped_dequantize

Conversation

@ptrendx
Member

@ptrendx ptrendx commented Mar 2, 2026

Description

Support dequantization for MXFP8 grouped tensors.

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Grouped dequantization kernel for MXFP8
  • Exposed the functionality in PyTorch

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
@ptrendx ptrendx requested a review from Oleg-Goncharov March 2, 2026 19:13
pre-commit-ci bot and others added 3 commits March 2, 2026 19:19
@ptrendx ptrendx linked an issue Mar 2, 2026 that may be closed by this pull request
ptrendx added 3 commits March 3, 2026 13:46
@ptrendx ptrendx marked this pull request as ready for review March 10, 2026 18:00
@greptile-apps
Contributor

greptile-apps bot commented Mar 10, 2026

Greptile Summary

This PR adds grouped MXFP8 dequantization support, introducing a new Blackwell-only TMA-based CUDA kernel (group_dequantize_mxfp8_kernel), a C API entry point (nvte_group_dequantize), and a PyTorch extension (group_dequantize). The kernel correctly handles all four ShapeRepresentation modes (SAME_BOTH_DIMS, VARYING_FIRST_DIM, VARYING_LAST_DIM, VARYING_BOTH_DIMS) using double-buffered TMA loads and per-tensor scale-offset arithmetic that is mathematically consistent with the quantize path.

Key findings:

  • Scale offset arithmetic verified: The padded_rows / SCALE_DIM_X formula in the VARYING_LAST_DIM scale base-offset computation is provably equivalent to DIVUP_TO_MULTIPLE(DIVUP(M, SCALE_DIM_X), alignment) for all valid row counts, so the dequantize and quantize scale layouts are consistent.
  • Output metadata incomplete for VARYING_LAST_DIM / VARYING_BOTH_DIMS (cast.cpp): NoneQuantizer::create_grouped_tensor always sets last_dims = None and computes tensor_offsets from first_dims × logical_last_dim, which is wrong for VARYING_BOTH_DIMS (the total element count is not the per-tensor column count). Downstream code that needs to iterate individual tensors in the output cannot do so reliably.
  • Type mismatch in C++ test (test_dequantize_mxfp8_grouped.cu): size_t * device pointers are declared as kNVTEInt64 in NVTEBasicTensor structs; the kernel reinterprets them as int64_t *.
  • Device memory leaks on test failures: Early ASSERT_* returns bypass cudaFree calls.
  • cp_async_bulk_wait_group_read<0>() called by all threads without the is_master_thread guard, inconsistent with the guarded pattern used throughout the rest of the kernel (harmless but confusing).

Confidence Score: 3/5

  • The core CUDA kernel logic is sound but the PyTorch extension produces an incomplete output GroupedTensor for VARYING_LAST_DIM/VARYING_BOTH_DIMS inputs; merge after addressing the metadata propagation issue.
  • The kernel algorithm and scale arithmetic are correct and well-tested. The main concern is the PyTorch-layer output GroupedTensor missing last_dims and having incorrect tensor_offsets for VARYING_LAST_DIM and VARYING_BOTH_DIMS cases — this is a real functional gap for API consumers. The test type mismatch and memory leaks are secondary. Score of 3 reflects a new feature that works correctly at the kernel level but has a meaningful metadata gap in the Python API layer.
  • transformer_engine/pytorch/csrc/extensions/cast.cpp (missing last_dims/tensor_offsets propagation to output) and tests/cpp/operator/test_dequantize_mxfp8_grouped.cu (size_t/int64_t type mismatch and memory leaks on failure).

Important Files Changed

Filename Overview
transformer_engine/common/cast/mxfp8/group_dequantize_mxfp8.cuh New CUDA kernel implementing grouped MXFP8 dequantization using Blackwell TMA for double-buffered global↔shared transfers. Scale offset arithmetic for VARYING_LAST_DIM is correct (equivalence of padded_rows/SCALE_DIM_X and DIVUP_TO_MULTIPLE(rows/SCALE_DIM_X, alignment) is mathematically verified), but cp_async_bulk_wait_group_read<0>() is unnecessarily called by all 128 threads outside the master-thread guard.
transformer_engine/common/cast/dispatch/dequantize.cuh Clean dispatcher adding group_dequantize_helper, which routes MXFP8 grouped tensors to the new kernel; correctly gates on CC ≥ 10.0 and rejects all other scaling modes with an NVTE_ERROR.
transformer_engine/common/cast/cast.cu New nvte_group_dequantize C API entry point follows the same pattern as nvte_group_quantize; straightforward and correct.
transformer_engine/pytorch/csrc/extensions/cast.cpp New group_dequantize PyTorch extension correctly builds the input GroupedTensorWrapper and calls nvte_group_dequantize, but the output GroupedTensor created by NoneQuantizer::create_grouped_tensor is missing last_dims and has incorrect/missing tensor_offsets for VARYING_LAST_DIM and VARYING_BOTH_DIMS inputs, which can break downstream per-tensor iteration.
tests/cpp/operator/test_dequantize_mxfp8_grouped.cu Thorough bitwise-correctness test comparing grouped vs. per-tensor single dequantize, but has three issues: (1) size_t * device pointers declared as kNVTEInt64 — a type mismatch; (2) device memory leaks on early ASSERT_EQ failures inside the per-tensor loop; (3) the already-flagged off-by-one in offsets_shape.data[0].

Sequence Diagram

sequenceDiagram
    participant PY as Python caller
    participant EXT as PyTorch extension<br/>(cast.cpp)
    participant CAPI as C API<br/>(cast.cu)
    participant DISP as Dispatcher<br/>(dequantize.cuh)
    participant TMA as update_tma_descriptors<br/>(CUDA kernel)
    participant KERN as group_dequantize_mxfp8_kernel<br/>(CUDA kernel)

    PY->>EXT: group_dequantize(input, otype)
    EXT->>EXT: build input GroupedTensorWrapper<br/>(rowwise/colwise data + scales +<br/>first_dims / last_dims / offsets)
    EXT->>EXT: NoneQuantizer::create_grouped_tensor()<br/>allocate output buffer
    EXT->>CAPI: nvte_group_dequantize(input, output, stream)
    CAPI->>DISP: group_dequantize_helper(input, output, stream)
    DISP->>DISP: check CC ≥ 10.0 (Blackwell)
    DISP->>DISP: mxfp8::group_dequantize(&input, output, stream)

    alt is_single_tensor (SAME_BOTH_DIMS or VARYING_FIRST_DIM)
        DISP->>KERN: launch with static TMA descriptors
    else multi-tensor (VARYING_LAST_DIM or VARYING_BOTH_DIMS)
        DISP->>TMA: update_tma_descriptors<<<num_tensors, 32>>><br/>per-tensor TMA descriptor → g_tensor_maps[]
        TMA-->>DISP: (async, same stream)
        DISP->>KERN: launch with per-tensor g_tensor_maps[]
    end

    KERN->>KERN: fence_acquire_tensormap (if multi-tensor)
    KERN->>KERN: get_current_tensor_id() — binary search on offsets_ptr
    KERN->>KERN: compute scales_base_offset per tensor
    loop ITERATIONS=8 (double-buffered, BUFFER_DIM_Y=16)
        KERN->>KERN: TMA load in_sh[buff] ← global FP8 data
        KERN->>KERN: read e8m0 scale → block_scale = exp2(biased_exp)
        KERN->>KERN: out = block_scale × float(in) → OType
        KERN->>KERN: TMA store out_sh[buff] → global output
    end
    KERN-->>EXT: (async kernel, CUDA stream)
    EXT-->>PY: return output GroupedTensor

Last reviewed commit: "Merge remote-trackin..."

nvte_set_grouped_tensor_param(in_group_tensor,
                              NVTEGroupedTensorParam::kNVTEGroupedRowwiseData,
                              &in_data_tensor, sizeof(in_data_tensor));
} else {

Incorrect offsets shape — off-by-one

offsets_shape.data[0] is set to num_tensors, but the offsets array is a standard CSR-style sentinel array with num_tensors + 1 entries (the last entry stores the total element count). The allocation uses (num_tensors + 1) * sizeof(size_t) on line 132 and offsets_h is declared with num_tensors + 1 on line 408. get_current_tensor_id (borrowed from the quantize path) searches over offsets_ptr[0 .. num_tensors], so it will access one element past the declared shape.

Suggested change
} else {
offsets_shape.data[0] = num_tensors + 1;

Comment on lines +119 to +134
size_t *first_dims_d;
size_t *last_dims_d;
size_t *offsets_d;

cudaMalloc((void **)&in_data_d, in_data_size);
cudaMalloc((void **)&out_grouped_d, out_data_size);
cudaMalloc((void **)&in_scales_d, scales_size);
cudaMalloc((void **)&first_dims_d, first_dims_size);
cudaMalloc((void **)&last_dims_d, last_dims_size);
cudaMalloc((void **)&offsets_d, offsets_size);

cudaMemcpy(in_data_d, in_data_h.data(), in_data_size, cudaMemcpyHostToDevice);
cudaMemcpy(in_scales_d, in_scales_h.data(), scales_size, cudaMemcpyHostToDevice);
cudaMemcpy(first_dims_d, first_dims_h.data(), first_dims_size, cudaMemcpyHostToDevice);
cudaMemcpy(last_dims_d, last_dims_h.data(), last_dims_size, cudaMemcpyHostToDevice);
cudaMemcpy(offsets_d, offsets_h.data(), offsets_size, cudaMemcpyHostToDevice);

P1 size_t * device pointers declared as kNVTEInt64

first_dims_d, last_dims_d, and offsets_d are allocated as size_t * (unsigned 64-bit), but passed to NVTEBasicTensor with type kNVTEInt64 (signed 64-bit). The kernel then reinterpret-casts them to int64_t *. While this works on 64-bit Linux platforms where the layouts are identical and values are non-negative, it is technically undefined behavior and will silently miscompute on any platform where sizeof(size_t) != sizeof(int64_t).

The pointers should be declared as int64_t * to match the declared tensor type, keeping the host-side std::vector<size_t> but performing an explicit cast when copying:

int64_t *first_dims_d;
int64_t *last_dims_d;
int64_t *offsets_d;

And correspondingly cast the host-side vectors on copy, or change the host vectors to std::vector<int64_t> as well.

Comment on lines +286 to +302
TensorWrapper output_w;
output_w.set_rowwise_data(single_out_d, otype, single_shape);

nvte_dequantize(input_w.data(), output_w.data(), 0);
cudaDeviceSynchronize();
err = cudaGetLastError();
ASSERT_EQ(err, cudaSuccess) << "Single-tensor dequantize failed for tensor " << t << ": "
<< cudaGetErrorString(err);

// Copy reference output to host
cudaMemcpy(out_ref_h.data() + data_offset, single_out_d, single_out_size,
cudaMemcpyDeviceToHost);

cudaFree(single_in_d);
cudaFree(single_out_d);
cudaFree(single_scales_d);
}

P2 Device memory leak on early assertion failures

When ASSERT_EQ(err, cudaSuccess) (line 292) fires and terminates the test early, the per-iteration device allocations (single_in_d, single_out_d, single_scales_d) are leaked. Furthermore, the outer-scope allocations (in_data_d, out_grouped_d, in_scales_d, etc.) are also never freed because ASSERT_* macros in GTest cause an early return, bypassing the cudaFree calls at the end of performTest.

Consider wrapping device allocations in RAII handles (e.g., a small CudaPtr<T> wrapper) or using EXPECT_EQ followed by explicit cleanup so memory is always released, even in failure paths.

Comment on lines +321 to +328
// Create output GroupedTensor using NoneQuantizer.
NoneQuantizer q{py::none()};
auto [out_cpp, out_py] = q.create_grouped_tensor(num_tensors, logical_shape, otype, py::none(),
first_dims, logical_first_dim, logical_last_dim);

NVTE_SCOPED_GIL_RELEASE({
nvte_group_dequantize(input_cpp.data(), out_cpp.data(), at::cuda::getCurrentCUDAStream());
});

P1 last_dims and tensor_offsets not propagated to the output GroupedTensor

NoneQuantizer::create_grouped_tensor always sets kwargs["last_dims"] = py::none() and computes tensor_offsets only from first_dims (via build_grouped_tensor_offsets). This means:

  • For VARYING_LAST_DIM inputs (where first_dims is None): the output's tensor_offsets will be None, making it impossible for downstream code to locate individual tensors in the output buffer.
  • For VARYING_BOTH_DIMS inputs (where first_dims is present but last_dims also varies): build_grouped_tensor_offsets computes offsets as first_dims[i] * logical_last_dim, where logical_last_dim is the total element count — producing completely wrong per-tensor byte boundaries on the output object.

The input's last_dims and tensor_offsets tensors are extracted at lines 279 and 278 but are only used to populate input_cpp; they are never forwarded to the output. Downstream consumers that need to split or index into the output GroupedTensor will silently get wrong results.

Consider either passing last_dims and tensor_offsets through to the output, or documenting clearly that callers are responsible for preserving this metadata from the input.

Comment on lines +326 to +327
ptx::cp_async_bulk_wait_group_read<0>();
__syncthreads();

P2 cp_async_bulk_wait_group_read<0>() called by all threads without is_master_thread guard

All 128 threads execute ptx::cp_async_bulk_wait_group_read<0>(), but only the master thread ever issued TMA bulk-store operations. Non-master threads have no outstanding bulk-copy groups, so wait_group<0> is effectively a no-op for them. While harmless today (the PTX instruction is well-defined with zero outstanding groups), this is inconsistent with the guarded pattern used everywhere else in the kernel (e.g., lines 315–324) and may confuse future readers.

Suggested change
ptx::cp_async_bulk_wait_group_read<0>();
__syncthreads();
if (is_master_thread) {
ptx::cp_async_bulk_wait_group_read<0>();
}
__syncthreads();




Development

Successfully merging this pull request may close these issues.

Dequantization support for the grouped tensor - MXFP8
