
[PyTorch] Batch CP attention tests in single torchrun to amortize NCC… #2965

Open

sudhakarsingh27 wants to merge 2 commits into NVIDIA:main from sudhakarsingh27:sudhakars/cp_test_batching_pr

Conversation

@sudhakarsingh27
Collaborator

…L init

Each parametrized CP test currently spawns its own torchrun process and pays 5-15s of NCCL init/destroy. With ~650-800 collected tests this adds up to 1.5-3 hours of pure setup overhead.

This change introduces a session-scoped fixture that:

  1. Calls per-test _prepare_* helpers to get either a skip reason or a kwargs dict for the worker.
  2. Groups runnable configs by num_gpus and chunks them into batches of CP_TEST_BATCH_SIZE (default 16).
  3. Launches one torchrun per chunk; the worker initialises NCCL once and runs all configs in the chunk inside the same world.

Per-config results are flushed to JSON after every config so a crash mid-batch still leaves earlier results intact. Set CP_TEST_BATCH_SIZE=1 to bisect a failing batch.
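A minimal sketch of that dispatch flow, under the assumption that the fixture boils down to roughly the following (the launcher helper, worker flags, and results-file handshake are illustrative, not the actual fixture internals):

```python
import json
import os
import subprocess
from collections import defaultdict

CP_TEST_BATCH_SIZE = int(os.getenv("CP_TEST_BATCH_SIZE", "16"))

def _launch_torchrun(num_gpus, chunk, results_path):
    # Illustrative launcher: one torchrun serves a whole chunk of configs,
    # so NCCL init/destroy is paid once per chunk instead of once per test.
    cmd = [
        "torchrun", f"--nproc_per_node={num_gpus}",
        "run_attention_with_cp.py",          # worker script from this PR
        f"--configs={json.dumps(chunk)}",    # hypothetical flag
        f"--results={results_path}",         # hypothetical flag
    ]
    subprocess.run(cmd, check=False)

def dispatch(configs, results_path="cp_results.json"):
    by_gpus = defaultdict(list)
    for cfg in configs:                      # group runnable configs by num_gpus
        by_gpus[cfg["num_gpus"]].append(cfg)
    results = []
    for num_gpus, group in by_gpus.items():
        for i in range(0, len(group), CP_TEST_BATCH_SIZE):   # chunk into batches
            chunk = group[i : i + CP_TEST_BATCH_SIZE]
            _launch_torchrun(num_gpus, chunk, results_path)  # one NCCL world per chunk
            with open(results_path) as f:    # worker flushes per-config JSON as it goes
                results.extend(json.load(f))
    return results
```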

Also includes a small bugfix in dot_product_attention/utils.py: the deterministic-FA3 disable condition was firing for any head_dim_qk > 128 (including inference); restrict it to is_training and large head dims.
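A minimal sketch of the corrected predicate, assuming the is_training and head_dim_qk names from the description (the real check in dot_product_attention/utils.py sits inside the backend-selection logic and may look different):

```python
def deterministic_fa3_supported(is_training: bool, head_dim_qk: int) -> bool:
    # Illustrative helper, not the actual utils.py code: deterministic FA3 is
    # now ruled out only when training with head_dim_qk > 128; inference with
    # large head dims no longer trips the disable path.
    return not (is_training and head_dim_qk > 128)
```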

Description

Batch context-parallel attention tests so that multiple parametrized configs share a single torchrun/NCCL world, amortizing init/destroy overhead; also fix the deterministic-FA3 disable condition so it applies only during training.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • tests/pytorch/attention/test_attention_with_cp.py: session-scoped fixture that prepares per-test configs, groups them by num_gpus, and launches one torchrun per batch of CP_TEST_BATCH_SIZE configs.
  • tests/pytorch/attention/run_attention_with_cp.py: worker runs all configs of a batch in one NCCL world and flushes per-config JSON results after each config.
  • dot_product_attention/utils.py: restrict the deterministic-FA3 disable condition to training with head_dim_qk > 128.

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@greptile-apps
Contributor

greptile-apps Bot commented May 6, 2026

Greptile Summary

This PR replaces the one-torchrun-per-test approach for context-parallel attention tests with a session-scoped fixture that batches configs by GPU count and runs each batch in a single torchrun, eliminating 1.5–3 hours of redundant NCCL init/destroy overhead across ~650–800 tests. A small bugfix also restricts the deterministic-FA3 disable condition to training-only scenarios.

  • Batched dispatch: A _cp_batch_results session fixture dry-runs each test body in collect mode to record kwargs and skip decisions, groups runnable configs into chunks of CP_TEST_BATCH_SIZE (default 16), launches one torchrun per chunk, and flushes per-config JSON results after each run so mid-batch crashes don't lose earlier results.
  • Worker-side changes: run_dpa_with_cp now asserts the process group is already initialized (managed by main()), cleans up only per-config communication groups (not the global group), clears transient env vars between configs, and uses dist.all_reduce(MIN) to surface failures on non-rank-0 workers.
  • Env-var arg parsing fix: sys.argv splitting now uses split("=", 1) to correctly handle values that contain = characters (see the example below).
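A generic illustration of that parsing change (not the exact worker code):

```python
arg = "SOME_ENV_VAR=a=b"            # value itself contains '='

key, value = arg.split("=", 1)      # ("SOME_ENV_VAR", "a=b"): value kept intact
parts = arg.split("=")              # ["SOME_ENV_VAR", "a", "b"]: value split apart
```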

Confidence Score: 4/5

Safe to merge with one fix recommended: communication groups created inside run_dpa_with_cp can be leaked on exception, accumulating NCCL communicators across a batch.

The batch dispatch architecture is sound and the all_reduce aggregation addresses the known non-rank-0 failure issue. The gap is that cp_comm_group (and sub-groups for a2a+p2p) are created via dist.new_group() early in run_dpa_with_cp but only destroyed at the very end of the function — if any exception escapes before that point, NCCL communicators are never released. In a batch of 16 configs with intermittent failures, leaked communicators accumulate within the same worker process and can exhaust NCCL's internal table limits, causing subsequent configs to fail with NCCL errors rather than surfacing the original fault.

tests/pytorch/attention/run_attention_with_cp.py — specifically the cp_comm_group and cp_comm_sub_groups cleanup path when run_dpa_with_cp raises mid-execution.

Important Files Changed

  • tests/pytorch/attention/run_attention_with_cp.py: Batch execution mode added: NCCL init/destroy moved to main(), per-config env cleanup, all_reduce for cross-rank failure aggregation. Communication groups created inside run_dpa_with_cp are leaked on exception, posing a resource-exhaustion risk in batch runs.
  • tests/pytorch/attention/test_attention_with_cp.py: Session-scoped fixture replaces per-test torchrun launches with batched dispatch; dry-run collect mode and the _run_or_fetch pattern are sound. Minor fragility: _dry_run_item relies on fixed positional argument ordering in test function signatures.

Comments Outside Diff (1)

  1. tests/pytorch/attention/run_attention_with_cp.py, line 238-762 (link)

    P1 NCCL communicator leak on mid-config exception

    cp_comm_group (and cp_comm_sub_groups for a2a+p2p) are created at line 238 via dist.new_group() but destroyed only at the very end of run_dpa_with_cp. If the function raises anywhere in between — an assertion failure, a CUDA OOM, a comparison mismatch — _run_single_config catches the exception and returns, and neither dist.destroy_process_group(cp_comm_group) nor the sub-group destroys are ever called. NCCL communicators accumulate silently across configs in the same batch. With batches of 16+ and any flaky configs, this can exhaust NCCL's internal communicator-table limit and corrupt subsequent configs with "NCCL error: invalid usage" rather than surfacing the original failure.

    Wrapping the body after group creation in a try/finally would guarantee cleanup.
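For example, a hedged sketch of that structure (the signature and group setup are illustrative; only the try/finally shape is the point):

```python
import torch.distributed as dist

def run_dpa_with_cp(cp_ranks, use_sub_groups=False):
    cp_comm_group = dist.new_group(ranks=cp_ranks)
    cp_comm_sub_groups = []
    if use_sub_groups:  # e.g. the a2a+p2p case
        half = len(cp_ranks) // 2
        cp_comm_sub_groups = [
            dist.new_group(ranks=cp_ranks[:half]),
            dist.new_group(ranks=cp_ranks[half:]),
        ]
    try:
        ...  # run the config: attention fwd/bwd, per-rank comparisons, etc.
    finally:
        # Runs even when the body raises mid-config, so a failing config in a
        # batch cannot leak NCCL communicators into the following configs.
        for group in cp_comm_sub_groups:
            dist.destroy_process_group(group)
        dist.destroy_process_group(cp_comm_group)
```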

Reviews (4): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..."

Comment on lines +831 to +845
```python
for cfg in configs:
    for env_key in _TRANSIENT_ENV_KEYS:
        os.environ.pop(env_key, None)
    ok, err = _run_single_config(cfg)
    results.append({"ok": ok, "error": err})
    _flush_results()
    try:
        dist.barrier()
    except BaseException:  # noqa: BLE001
        results[-1]["ok"] = False
        if results[-1]["error"] is None:
            results[-1]["error"] = traceback.format_exc()
        _flush_results()
        break
    torch.cuda.empty_cache()
```

P1 Failures on non-rank-0 workers are silently swallowed

Only rank 0 calls _flush_results(), so if rank 1 (or any non-zero rank) raises an exception inside run_dpa_with_cp() — e.g., an output-mismatch assertion in the per-rank comparison loop — but rank 0 succeeds, the result is written as ok=True and the failure is never surfaced.

This scenario is realistic: each rank checks correctness for its own partition of the sequence independently, so a CP bug that corrupts only one partition would be missed. The dist.barrier() would still succeed because rank 1's exception happens after all NCCL collectives (in the comparison phase), leaving both ranks able to reach the barrier.

A minimal fix is to aggregate the per-rank success flag before flushing: collect all ranks' ok values with dist.all_reduce (using ReduceOp.MIN or ReduceOp.BAND) and record the aggregate result on rank 0.
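A minimal sketch of that aggregation, reusing the ok flag from the loop above (the helper name is illustrative, and it assumes each rank's current CUDA device is already set):

```python
import torch
import torch.distributed as dist

def aggregate_ok(ok: bool) -> bool:
    # Every rank contributes 1 (pass) or 0 (fail); MIN over ranks is 0 if any
    # rank failed, so rank 0 records the batch-wide outcome before flushing.
    flag = torch.tensor([1 if ok else 0], dtype=torch.int32, device="cuda")
    dist.all_reduce(flag, op=dist.ReduceOp.MIN)
    return bool(flag.item())
```

Rank 0 would then call ok = aggregate_ok(ok) before appending to results and calling _flush_results().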

@sudhakarsingh27 sudhakarsingh27 force-pushed the sudhakars/cp_test_batching_pr branch from fa189b0 to 0e9fc1f Compare May 6, 2026 23:01
Comment on lines +554 to +558
```python
try:
    params = tuple(callspec.params[p] for p in param_names)
except KeyError:
    continue
kwargs, skip_reason = preparer(*params)
```

P1 Unguarded preparer exception fails all CP tests

If get_available_attention_backends (or any other preparer call) raises an unexpected exception for a single item — e.g., a driver error, an unhandled edge-case in the backend-detection logic — the exception propagates out of _collect_configs and tears down the session-scoped fixture setup. Every test that requests _cp_batch_results is then reported as an ERROR rather than as a SKIP or individual FAIL. Previously, a failure in the test-function body only affected that one test. A try/except around the preparer call (logging the error and synthesising a "kind": "error" entry in by_nodeid) would preserve the original per-test isolation.
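A hedged sketch of that guard, with the surrounding collection loop simplified (prepared_items is a stand-in for however _collect_configs iterates the pytest items):

```python
import traceback

by_nodeid = {}       # nodeid -> {"kind": ..., ...}, consumed by the fixture
prepared_items = []  # stand-in: (pytest item, (preparer callable, params tuple)) pairs

for item, (preparer, params) in prepared_items:
    try:
        kwargs, skip_reason = preparer(*params)
    except Exception:
        # One misbehaving preparer now errors only its own test instead of
        # tearing down the session-scoped fixture for every CP test.
        by_nodeid[item.nodeid] = {"kind": "error", "error": traceback.format_exc()}
        continue
    if skip_reason is not None:
        by_nodeid[item.nodeid] = {"kind": "skip", "reason": skip_reason}
    else:
        by_nodeid[item.nodeid] = {"kind": "run", "kwargs": kwargs}
```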

@sudhakarsingh27 sudhakarsingh27 force-pushed the sudhakars/cp_test_batching_pr branch 2 times, most recently from ba8dce2 to 4b9a4b7 Compare May 6, 2026 23:34
…L init

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
@sudhakarsingh27 sudhakarsingh27 force-pushed the sudhakars/cp_test_batching_pr branch from 139c955 to f34fda0 Compare May 6, 2026 23:37