Use precision-matched quantized references in INT4 matmul tests #19323
digantdesai wants to merge 1 commit into `main`
Conversation
Replace eager float32 references with precision-matched quantized references that align with each kernel's internal dequant precision:

- dequant_w4_to_bf16: bitwise exact vs pure-Python dequant (was atol=0.01)
- int4_matmul: cuBLAS bf16 GEMM reference (both truncate to bf16)
- int4_matvec: f32 matmul reference (both keep dequant in f32, atol=1e-3 vs prior atol=1.0)

Co-authored-by: Claude <noreply@anthropic.com>
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/19323

Note: Links to docs will display an error until the docs builds have been completed.

❌ 2 New Failures, 1 Cancelled Job, 1 Pending, 6 Unrelated Failures as of commit 70ade20 with merge base 8ae05c2:

- NEW FAILURES: the following jobs have failed.
- CANCELLED JOB: the following job was cancelled; please retry.
- FLAKY: the following job failed, likely due to flakiness present on trunk.
- BROKEN TRUNK: the following jobs failed but were already failing on the merge base. 👉 Rebase onto the `viable/strict` branch to avoid these failures.

This comment was automatically generated by Dr. CI and updates every 15 minutes.
Pull request overview
This PR updates CUDA INT4 matmul/dequant test references to be precision-matched with the kernels’ internal dequantization precision (bf16 vs f32), aiming to reduce flakiness and eliminate reliance on SNR-based assertions. It also adds additional API-contract assertions and expands test coverage with tail-K and directed byte-pattern fixtures.
Changes:

- Add input/shape/dtype assertions to `int4_matvec` and `dequant_w4_to_bf16`.
- Rewrite INT4 correctness tests to use a Python dequant reference and `allclose` vs precision-matched references (bf16 or f32 dequant as appropriate).
- Add directed fixtures (endpoint nibbles, group-boundary scales), tail-K cases, non-contiguity cases, and contract tests.
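The Python dequant reference described above boils down to unpacking two signed 4-bit values per byte and multiplying each element by its group's scale. A minimal pure-Python sketch; helper names like `pack_int4_pair` and `dequant_row` are illustrative, not the PR's actual helpers, and the real tests operate on torch tensors:

```python
def sign_extend_4bit(nibble):
    """Interpret a 4-bit value as two's complement, giving a value in [-8, 7]."""
    return nibble - 16 if nibble >= 8 else nibble

def pack_int4_pair(lo, hi):
    """Pack two signed 4-bit values into one byte, low nibble first."""
    return ((hi & 0xF) << 4) | (lo & 0xF)

def dequant_row(packed, scales, group_size):
    """Dequantize one packed row: out[k] = int4[k] * scales[k // group_size]."""
    out = []
    for byte in packed:
        for nibble in (byte & 0xF, (byte >> 4) & 0xF):
            k = len(out)
            out.append(sign_extend_4bit(nibble) * scales[k // group_size])
    return out

# K=4, group_size=2: two packed bytes, one scale per group of two elements.
packed = [pack_int4_pair(-8, 7), pack_int4_pair(1, -1)]
print(dequant_row(packed, [0.5, 2.0], group_size=2))  # [-4.0, 3.5, 2.0, -2.0]
```

Note that the nibble order (low nibble = even index) is an assumption here; the kernel's actual packing convention is defined in `int4_matmul.py`.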
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| backends/cuda/triton/kernels/int4_matmul.py | Adds stronger input validation for matvec and dequant wrappers (dtype/shape checks, K constraint). |
| backends/cuda/tests/test_int4_matmul.py | Replaces SNR-based references with precision-matched references; adds more targeted fixtures and contract tests. |
```python
K = x.shape[-1]
N = w_packed.shape[0]
assert w_packed.shape == (N, K // 2), f"w_packed shape {w_packed.shape} != ({N}, {K // 2})"
assert w_scale.shape == (N, K // group_size), f"w_scale shape {w_scale.shape} != ({N}, {K // group_size})"
```

```python
assert w_packed.shape == (N, K // 2), f"w_packed shape {w_packed.shape} != ({N}, {K // 2})"
assert w_scale.shape == (N, K // group_size), f"w_scale shape {w_scale.shape} != ({N}, {K // group_size})"
assert K % 32 == 0, f"K={K} must be a multiple of 32 for vectorized loads"
```
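The assertions above encode a single shape contract. Restated as a hypothetical standalone checker (the name `validate_int4_inputs` is illustrative; the real checks live inside the kernel wrappers in `int4_matmul.py`):

```python
def validate_int4_inputs(x_shape, w_packed_shape, w_scale_shape, group_size):
    """Check the INT4 weight layout invariants: two 4-bit values per packed
    byte, one scale per group, and K a multiple of 32 for vectorized loads."""
    K = x_shape[-1]
    N = w_packed_shape[0]
    assert K % 32 == 0, f"K={K} must be a multiple of 32 for vectorized loads"
    assert w_packed_shape == (N, K // 2), f"w_packed shape {w_packed_shape} != ({N}, {K // 2})"
    assert w_scale_shape == (N, K // group_size), f"w_scale shape {w_scale_shape} != ({N}, {K // group_size})"
    return N, K

# x: (4, 64), 16 output rows, group_size=32 -> 2 scale groups per row.
print(validate_int4_inputs((4, 64), (16, 32), (16, 2), group_size=32))  # (16, 64)
```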
```python
assert w_packed.dtype == torch.int8
assert w_scale.dtype == torch.bfloat16
N, K_half = w_packed.shape
K = K_half * 2
```
```
Replaces the former cross-kernel cosine test with stricter parity
.
```
```python
w_int4 = torch.stack([lo, hi], dim=-1).reshape(N, K)
scale_expanded = scale.float().repeat_interleave(group_size, dim=1)
w_f32 = w_int4.float() * scale_expanded
```
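For the bf16-matched cases, the reference must round to bfloat16 exactly the way the kernel's internal dequant does; the tests do this with torch casts. A dependency-free illustration of f32-to-bf16 conversion, assuming round-to-nearest-even and finite inputs:

```python
import struct

def to_bf16(x):
    """Round a Python float to bfloat16 precision: keep the top 16 bits of the
    float32 encoding, rounding to nearest even. Finite inputs assumed."""
    (bits,) = struct.unpack("<I", struct.pack("<f", x))
    bits = (bits + 0x7FFF + ((bits >> 16) & 1)) & 0xFFFFFFFF
    (y,) = struct.unpack("<f", struct.pack("<I", bits & 0xFFFF0000))
    return y

print(to_bf16(1.0))        # 1.0 (exactly representable)
print(to_bf16(1.0078125))  # 1.0078125 (= 1 + 2**-7, still representable)
print(to_bf16(1.001))      # 1.0 (mantissa bits beyond bf16's 7 are dropped)
```

Because the reference and the kernel round identically, the comparison can use tight tolerances (or bitwise equality) instead of SNR thresholds.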
@claude review this PR.
Claude finished @digantdesai's task in 1m 38s.

Review of PR #19323
Overall this is a well-motivated rewrite. Replacing eager float32 references with precision-matched quantized references is the right approach: it eliminates flaky SNR-based assertions by testing what the kernel actually computes rather than comparing against a numerically different golden. The test structure is clean and the coverage is solid.

**Issues**

1. **CI failure:** This test compares … Options: …
2. **Missing `group_size` divisibility check:** Adding

```python
assert K % group_size == 0, f"K={K} must be divisible by group_size={group_size}"
```

This is a correctness invariant the kernel relies on; making it explicit catches bad inputs early.

3. **Masked loads:** The matvec kernel uses masked loads (…). Wait, looking again, …

4. **Docstring artifact:** The docstring ends with

```
Replaces the former cross-kernel cosine test with stricter parity
.
"""
```

The stray `.` line should be removed.

**Suggestions (non-blocking)**

5.
```python
w_grouped = w_int4.float().reshape(N, K // group_size, group_size)
w_f32 = (w_grouped * scale.float().unsqueeze(-1)).reshape(N, K)
```

This is just a test helper so it's non-blocking, but it would reduce peak memory during the reference computation.

6. Consider …

**Verdict**

The approach is sound and the test coverage is comprehensive. The main actionable item is fixing the CI failure.
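The grouped-reshape suggestion and the `repeat_interleave`-style expansion it replaces are numerically identical; here is a torch-free sketch of the equivalence (helper names are illustrative):

```python
def dequant_expanded(w_int4, scales, group_size):
    """Expand scales to length K first, then multiply elementwise
    (mirrors the repeat_interleave formulation)."""
    expanded = [s for s in scales for _ in range(group_size)]
    return [w * s for w, s in zip(w_int4, expanded)]

def dequant_grouped(w_int4, scales, group_size):
    """Multiply group by group without materializing the expanded scales."""
    out = []
    for g, s in enumerate(scales):
        out.extend(w * s for w in w_int4[g * group_size:(g + 1) * group_size])
    return out

row = [-8, 7, 1, -1, 3, -5]
scales = [0.5, 2.0, 0.25]
assert dequant_expanded(row, scales, 2) == dequant_grouped(row, scales, 2)
print(dequant_grouped(row, scales, 2))  # [-4.0, 3.5, 2.0, -2.0, 0.75, -1.25]
```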
Replace eager float32 references with precision-matched quantized references that align with each kernel's internal dequant precision, reducing flakiness and removing the need for SNR-based tests.