[Common, pyTorch] Grouped MXFP8 dequantize support #2722
ptrendx wants to merge 14 commits into NVIDIA:main
Conversation
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Greptile Summary

This PR adds grouped MXFP8 dequantization support, introducing a new Blackwell-only TMA-based CUDA kernel.

Key findings:
Confidence Score: 3/5
Important Files Changed
Sequence Diagram

```mermaid
sequenceDiagram
    participant PY as Python caller
    participant EXT as PyTorch extension<br/>(cast.cpp)
    participant CAPI as C API<br/>(cast.cu)
    participant DISP as Dispatcher<br/>(dequantize.cuh)
    participant TMA as update_tma_descriptors<br/>(CUDA kernel)
    participant KERN as group_dequantize_mxfp8_kernel<br/>(CUDA kernel)
    PY->>EXT: group_dequantize(input, otype)
    EXT->>EXT: build input GroupedTensorWrapper<br/>(rowwise/colwise data + scales +<br/>first_dims / last_dims / offsets)
    EXT->>EXT: NoneQuantizer::create_grouped_tensor()<br/>allocate output buffer
    EXT->>CAPI: nvte_group_dequantize(input, output, stream)
    CAPI->>DISP: group_dequantize_helper(input, output, stream)
    DISP->>DISP: check CC ≥ 10.0 (Blackwell)
    DISP->>DISP: mxfp8::group_dequantize(&input, output, stream)
    alt is_single_tensor (SAME_BOTH_DIMS or VARYING_FIRST_DIM)
        DISP->>KERN: launch with static TMA descriptors
    else multi-tensor (VARYING_LAST_DIM or VARYING_BOTH_DIMS)
        DISP->>TMA: update_tma_descriptors<<<num_tensors, 32>>><br/>per-tensor TMA descriptor → g_tensor_maps[]
        TMA-->>DISP: (async, same stream)
        DISP->>KERN: launch with per-tensor g_tensor_maps[]
    end
    KERN->>KERN: fence_acquire_tensormap (if multi-tensor)
    KERN->>KERN: get_current_tensor_id() — binary search on offsets_ptr
    KERN->>KERN: compute scales_base_offset per tensor
    loop ITERATIONS=8 (double-buffered, BUFFER_DIM_Y=16)
        KERN->>KERN: TMA load in_sh[buff] ← global FP8 data
        KERN->>KERN: read e8m0 scale → block_scale = exp2(biased_exp)
        KERN->>KERN: out = block_scale × float(in) → OType
        KERN->>KERN: TMA store out_sh[buff] → global output
    end
    KERN-->>EXT: (async kernel, CUDA stream)
    EXT-->>PY: return output GroupedTensor
```
Last reviewed commit: "Merge remote-trackin..."
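The `get_current_tensor_id()` step in the diagram (a binary search on `offsets_ptr`) can be sketched roughly as below. This is an illustrative host-side reimplementation, not the kernel's actual code; it assumes a CSR-style sentinel array with `num_tensors + 1` entries where the last entry holds the total element count.

```cpp
#include <algorithm>
#include <cstdint>

// Illustrative sketch: map a flat element index to the tensor that owns it,
// given CSR-style offsets [0, n0, n0+n1, ..., total]. The search spans the
// full sentinel array offsets[0 .. num_tensors].
int64_t get_current_tensor_id(const int64_t *offsets, int64_t num_tensors,
                              int64_t element_idx) {
  const int64_t *end = offsets + num_tensors + 1;
  // First offset strictly greater than element_idx marks the next tensor;
  // the owning tensor is one position before it.
  return static_cast<int64_t>(std::upper_bound(offsets, end, element_idx) -
                              offsets) - 1;
}
```

For example, with offsets `{0, 4, 6, 12}` (three tensors of 4, 2, and 6 elements), element 5 falls in tensor 1 and element 11 in tensor 2.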
```cpp
nvte_set_grouped_tensor_param(in_group_tensor,
                              NVTEGroupedTensorParam::kNVTEGroupedRowwiseData, &in_data_tensor,
                              sizeof(in_data_tensor));
} else {
```
Incorrect offsets shape — off-by-one
offsets_shape.data[0] is set to num_tensors, but the offsets array is a standard CSR-style sentinel array with num_tensors + 1 entries (the last entry stores the total element count). The allocation uses (num_tensors + 1) * sizeof(size_t) on line 132 and offsets_h is declared with num_tensors + 1 on line 408. get_current_tensor_id (borrowed from the quantize path) searches over offsets_ptr[0 .. num_tensors], so it will access one element past the declared shape.
```cpp
} else {
  offsets_shape.data[0] = num_tensors + 1;
```
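To make the sentinel layout behind this fix concrete, here is a small hedged sketch (the helper name is made up): an offsets array built from per-tensor element counts always has one more entry than there are tensors, with the final entry equal to the total element count, so declaring its shape as `num_tensors` cuts off exactly that last entry.

```cpp
#include <cstdint>
#include <numeric>
#include <vector>

// Illustrative helper: build a CSR-style offsets array from per-tensor
// element counts. The result has sizes.size() + 1 entries; the last entry
// is the total element count across all tensors.
std::vector<int64_t> build_offsets(const std::vector<int64_t> &sizes) {
  std::vector<int64_t> offsets(sizes.size() + 1, 0);
  std::partial_sum(sizes.begin(), sizes.end(), offsets.begin() + 1);
  return offsets;
}
```

For counts `{4, 2, 6}` this yields `{0, 4, 6, 12}`: four entries for three tensors, which is why the declared shape needs `num_tensors + 1`.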
```cpp
size_t *first_dims_d;
size_t *last_dims_d;
size_t *offsets_d;

cudaMalloc((void **)&in_data_d, in_data_size);
cudaMalloc((void **)&out_grouped_d, out_data_size);
cudaMalloc((void **)&in_scales_d, scales_size);
cudaMalloc((void **)&first_dims_d, first_dims_size);
cudaMalloc((void **)&last_dims_d, last_dims_size);
cudaMalloc((void **)&offsets_d, offsets_size);

cudaMemcpy(in_data_d, in_data_h.data(), in_data_size, cudaMemcpyHostToDevice);
cudaMemcpy(in_scales_d, in_scales_h.data(), scales_size, cudaMemcpyHostToDevice);
cudaMemcpy(first_dims_d, first_dims_h.data(), first_dims_size, cudaMemcpyHostToDevice);
cudaMemcpy(last_dims_d, last_dims_h.data(), last_dims_size, cudaMemcpyHostToDevice);
cudaMemcpy(offsets_d, offsets_h.data(), offsets_size, cudaMemcpyHostToDevice);
```
size_t * device pointers declared as kNVTEInt64
first_dims_d, last_dims_d, and offsets_d are allocated as size_t * (unsigned 64-bit), but passed to NVTEBasicTensor with type kNVTEInt64 (signed 64-bit). The kernel then reinterpret-casts them to int64_t *. While this works on 64-bit Linux platforms where the layouts are identical and values are non-negative, it is technically undefined behavior and will silently miscompute on any platform where sizeof(size_t) != sizeof(int64_t).
The pointers should be declared as int64_t * to match the declared tensor type, keeping the host-side std::vector<size_t> but performing an explicit cast when copying:
```cpp
int64_t *first_dims_d;
int64_t *last_dims_d;
int64_t *offsets_d;
```

And correspondingly cast the host-side vectors on copy, or change the host vectors to `std::vector<int64_t>` as well.
```cpp
TensorWrapper output_w;
output_w.set_rowwise_data(single_out_d, otype, single_shape);

nvte_dequantize(input_w.data(), output_w.data(), 0);
cudaDeviceSynchronize();
err = cudaGetLastError();
ASSERT_EQ(err, cudaSuccess) << "Single-tensor dequantize failed for tensor " << t << ": "
                            << cudaGetErrorString(err);

// Copy reference output to host
cudaMemcpy(out_ref_h.data() + data_offset, single_out_d, single_out_size,
           cudaMemcpyDeviceToHost);

cudaFree(single_in_d);
cudaFree(single_out_d);
cudaFree(single_scales_d);
}
```
Device memory leak on early assertion failures
When ASSERT_EQ(err, cudaSuccess) (line 292) fires and terminates the test early, the per-iteration device allocations (single_in_d, single_out_d, single_scales_d) are leaked. Furthermore, the outer-scope allocations (in_data_d, out_grouped_d, in_scales_d, etc.) are also never freed because ASSERT_* macros in GTest cause an early return, bypassing the cudaFree calls at the end of performTest.
Consider wrapping device allocations in RAII handles (e.g., a small CudaPtr<T> wrapper) or using EXPECT_EQ followed by explicit cleanup so memory is always released, even in failure paths.
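One way to realize the RAII suggestion is a small owning handle whose destructor always frees, so GTest's early return on `ASSERT_*` failure cannot leak the allocation. The sketch below uses `std::malloc`/`std::free` so it stands alone; in the actual test the same pattern would wrap `cudaMalloc`/`cudaFree` instead. The class name is hypothetical.

```cpp
#include <cstdlib>

// Illustrative RAII handle: the allocation is released in the destructor,
// so an early return (e.g. from a failing ASSERT_EQ) cannot leak it.
// Swap the malloc/free pair for cudaMalloc/cudaFree to manage device
// memory the same way.
template <typename T>
class ScopedPtr {
 public:
  explicit ScopedPtr(size_t bytes)
      : ptr_(static_cast<T *>(std::malloc(bytes))) {}
  ~ScopedPtr() { std::free(ptr_); }
  // Non-copyable: exactly one owner per allocation.
  ScopedPtr(const ScopedPtr &) = delete;
  ScopedPtr &operator=(const ScopedPtr &) = delete;
  T *get() const { return ptr_; }

 private:
  T *ptr_;
};
```

With such a wrapper, `single_in_d` and friends would be declared once per iteration and freed automatically on every exit path, failing or not.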
```cpp
// Create output GroupedTensor using NoneQuantizer.
NoneQuantizer q{py::none()};
auto [out_cpp, out_py] = q.create_grouped_tensor(num_tensors, logical_shape, otype, py::none(),
                                                 first_dims, logical_first_dim, logical_last_dim);

NVTE_SCOPED_GIL_RELEASE({
  nvte_group_dequantize(input_cpp.data(), out_cpp.data(), at::cuda::getCurrentCUDAStream());
});
```
last_dims and tensor_offsets not propagated to the output GroupedTensor
NoneQuantizer::create_grouped_tensor always sets kwargs["last_dims"] = py::none() and computes tensor_offsets only from first_dims (via build_grouped_tensor_offsets). This means:
- For `VARYING_LAST_DIM` inputs (where `first_dims` is `None`): the output's `tensor_offsets` will be `None`, making it impossible for downstream code to locate individual tensors in the output buffer.
- For `VARYING_BOTH_DIMS` inputs (where `first_dims` is present but `last_dims` also varies): `build_grouped_tensor_offsets` computes offsets as `first_dims[i] * logical_last_dim`, where `logical_last_dim` is the total element count, producing completely wrong per-tensor byte boundaries on the output object.
The input's last_dims and tensor_offsets tensors are extracted at lines 279 and 278 but are only used to populate input_cpp; they are never forwarded to the output. Downstream consumers that need to split or index into the output GroupedTensor will silently get wrong results.
Consider either passing last_dims and tensor_offsets through to the output, or documenting clearly that callers are responsible for preserving this metadata from the input.
```cpp
ptx::cp_async_bulk_wait_group_read<0>();
__syncthreads();
```
cp_async_bulk_wait_group_read<0>() called by all threads without is_master_thread guard
All 128 threads execute ptx::cp_async_bulk_wait_group_read<0>(), but only the master thread ever issued TMA bulk-store operations. Non-master threads have no outstanding bulk-copy groups, so wait_group<0> is effectively a no-op for them. While harmless today (the PTX instruction is well-defined with zero outstanding groups), this is inconsistent with the guarded pattern used everywhere else in the kernel (e.g., lines 315–324) and may confuse future readers.
```diff
-ptx::cp_async_bulk_wait_group_read<0>();
-__syncthreads();
+if (is_master_thread) {
+  ptx::cp_async_bulk_wait_group_read<0>();
+}
+__syncthreads();
```
Description
Support dequantization for MXFP8 grouped tensors.
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: