[PyTorch] Batch CP attention tests in single torchrun to amortize NCC…#2965
sudhakarsingh27 wants to merge 2 commits into NVIDIA:main from
Conversation
Greptile Summary
This PR replaces the one-torchrun-per-test approach for context-parallel attention tests with a session-scoped fixture that batches configs by GPU count and runs each batch in a single torchrun, eliminating 1.5–3 hours of redundant NCCL init/destroy overhead across ~650–800 tests. A small bugfix also restricts the deterministic-FA3 disable condition to training-only scenarios.
Confidence Score: 4/5
Safe to merge with one fix recommended: communication groups created inside run_dpa_with_cp can be leaked on exception, accumulating NCCL communicators across a batch. The batch dispatch architecture is sound and the all_reduce aggregation addresses the known non-rank-0 failure issue. The gap is that cp_comm_group (and sub-groups for a2a+p2p) are created via dist.new_group() early in run_dpa_with_cp but only destroyed at the very end of the function; if any exception escapes before that point, NCCL communicators are never released. In a batch of 16 configs with intermittent failures, leaked communicators accumulate within the same worker process and can exhaust NCCL's internal table limits, causing subsequent configs to fail with NCCL errors rather than surfacing the original fault. The spot to fix is tests/pytorch/attention/run_attention_with_cp.py, specifically the cp_comm_group and cp_comm_sub_groups cleanup path when run_dpa_with_cp raises mid-execution.
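A minimal sketch of one way to close that gap, assuming the test keeps its dist.new_group() calls; the helper name _cp_groups and the exact sub-group layout are illustrative, not the PR's actual code:

```python
import torch.distributed as dist
from contextlib import contextmanager


@contextmanager
def _cp_groups(ranks, cp_comm_type):
    """Create the CP communication group(s) and guarantee they are destroyed
    even if the body of run_dpa_with_cp raises part-way through."""
    groups = [dist.new_group(ranks=ranks)]
    if cp_comm_type == "a2a+p2p":
        # Illustrative only: the real test builds separate a2a and p2p sub-groups.
        groups.append(dist.new_group(ranks=ranks))
    try:
        yield groups
    finally:
        for group in groups:
            dist.destroy_process_group(group)
```

With the body of run_dpa_with_cp executed inside such a with block, a failing config releases its communicators before the next config in the batch starts.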
```python
for cfg in configs:
    for env_key in _TRANSIENT_ENV_KEYS:
        os.environ.pop(env_key, None)
    ok, err = _run_single_config(cfg)
    results.append({"ok": ok, "error": err})
    _flush_results()
    try:
        dist.barrier()
    except BaseException:  # noqa: BLE001
        results[-1]["ok"] = False
        if results[-1]["error"] is None:
            results[-1]["error"] = traceback.format_exc()
        _flush_results()
        break
    torch.cuda.empty_cache()
```
Failures on non-rank-0 workers are silently swallowed
Only rank 0 calls _flush_results(), so if rank 1 (or any non-zero rank) raises an exception inside run_dpa_with_cp() — e.g., an output-mismatch assertion in the per-rank comparison loop — but rank 0 succeeds, the result is written as ok=True and the failure is never surfaced.
This scenario is realistic: each rank checks correctness for its own partition of the sequence independently, so a CP bug that corrupts only one partition would be missed. The dist.barrier() would still succeed because rank 1's exception happens after all NCCL collectives (in the comparison phase), leaving both ranks able to reach the barrier.
A minimal fix is to aggregate the per-rank success flag before flushing: collect all ranks' ok values with dist.all_reduce (using ReduceOp.MIN or ReduceOp.BAND) and record the aggregate result on rank 0.
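A minimal sketch of that aggregation, assuming it runs right after _run_single_config in the worker loop shown above (the tensor plumbing here is illustrative):

```python
import torch
import torch.distributed as dist

# Fold every rank's success flag into one value: MIN yields 0 if any rank failed.
ok_flag = torch.tensor([1 if results[-1]["ok"] else 0], dtype=torch.int32, device="cuda")
dist.all_reduce(ok_flag, op=dist.ReduceOp.MIN)
if ok_flag.item() == 0 and results[-1]["ok"]:
    # This rank passed locally, but a peer rank failed; record the config as failed.
    results[-1]["ok"] = False
    results[-1]["error"] = results[-1]["error"] or "failure reported on a non-zero rank"
if dist.get_rank() == 0:
    _flush_results()
```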
fa189b0 to 0e9fc1f
```python
try:
    params = tuple(callspec.params[p] for p in param_names)
except KeyError:
    continue
kwargs, skip_reason = preparer(*params)
```
Unguarded preparer exception fails all CP tests
If get_available_attention_backends (or any other preparer call) raises an unexpected exception for a single item — e.g., a driver error, an unhandled edge-case in the backend-detection logic — the exception propagates out of _collect_configs and tears down the session-scoped fixture setup. Every test that requests _cp_batch_results is then reported as an ERROR rather than as a SKIP or individual FAIL. Previously, a failure in the test-function body only affected that one test. A try/except around the preparer call (logging the error and synthesising a "kind": "error" entry in by_nodeid) would preserve the original per-test isolation.
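A sketch of that guard under the same assumptions, where nodeid stands for whatever key _collect_configs already uses for by_nodeid (a hypothetical name here):

```python
import traceback

try:
    kwargs, skip_reason = preparer(*params)
except Exception:
    # One bad item should not tear down the whole session-scoped collection:
    # record it as an error for this test only and keep collecting.
    by_nodeid[nodeid] = {"kind": "error", "error": traceback.format_exc()}
    continue
```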
ba8dce2 to 4b9a4b7
[PyTorch] Batch CP attention tests in single torchrun to amortize NCCL init
Each parametrized CP test currently spawns its own torchrun process and
pays 5-15s of NCCL init/destroy. With ~650-800 collected tests this
adds up to 1.5-3 hours of pure setup overhead.
This change introduces a session-scoped fixture that:
1. Calls per-test ``_prepare_*`` helpers to get either a skip reason or
a kwargs dict for the worker.
2. Groups runnable configs by ``num_gpus`` and chunks them into batches
of CP_TEST_BATCH_SIZE (default 16).
3. Launches one torchrun per chunk; the worker initialises NCCL once
and runs all configs in the chunk inside the same world.
Per-config results are flushed to JSON after every config so a crash
mid-batch still leaves earlier results intact. Set CP_TEST_BATCH_SIZE=1
to bisect a failing batch.
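A minimal sketch of steps 2 and 3 under these assumptions; helper names such as _group_and_chunk and _launch_torchrun are illustrative, not the fixture's real identifiers:

```python
import os
from collections import defaultdict

CP_TEST_BATCH_SIZE = int(os.environ.get("CP_TEST_BATCH_SIZE", "16"))


def _group_and_chunk(configs):
    """Group runnable configs by GPU count, then split each group into batches."""
    by_gpus = defaultdict(list)
    for cfg in configs:
        by_gpus[cfg["num_gpus"]].append(cfg)
    for num_gpus, group in by_gpus.items():
        for start in range(0, len(group), CP_TEST_BATCH_SIZE):
            yield num_gpus, group[start : start + CP_TEST_BATCH_SIZE]


# One torchrun per chunk: NCCL is initialised once per chunk instead of once per test.
# for num_gpus, chunk in _group_and_chunk(runnable_configs):
#     _launch_torchrun(num_gpus, chunk)  # hypothetical launcher wrapper
```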
Also includes a small bugfix in dot_product_attention/utils.py: the
deterministic-FA3 disable condition was firing for any head_dim_qk > 128
(including inference); restrict it to is_training and large head dims.
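A hedged sketch of the corrected condition; variable names mirror the description above rather than the exact code in dot_product_attention/utils.py:

```python
def _should_disable_deterministic_fa3(is_training: bool, head_dim_qk: int) -> bool:
    # Before the fix this returned True for any head_dim_qk > 128, even in inference.
    # After the fix, deterministic FA3 is only disabled when training with large head dims.
    return is_training and head_dim_qk > 128
```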
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
139c955 to f34fda0
for more information, see https://pre-commit.ci
[PyTorch] Batch CP attention tests in single torchrun to amortize NCCL init
Each parametrized CP test currently spawns its own torchrun process and pays 5-15s of NCCL init/destroy. With ~650-800 collected tests this adds up to 1.5-3 hours of pure setup overhead.
This change introduces a session-scoped fixture that:
1. Calls per-test `_prepare_*` helpers to get either a skip reason or a kwargs dict for the worker.
2. Groups runnable configs by `num_gpus` and chunks them into batches of CP_TEST_BATCH_SIZE (default 16).
3. Launches one torchrun per chunk; the worker initialises NCCL once and runs all configs in the chunk inside the same world.

Per-config results are flushed to JSON after every config so a crash mid-batch still leaves earlier results intact. Set CP_TEST_BATCH_SIZE=1 to bisect a failing batch.
Also includes a small bugfix in dot_product_attention/utils.py: the deterministic-FA3 disable condition was firing for any head_dim_qk > 128 (including inference); restrict it to is_training and large head dims.