122 changes: 104 additions & 18 deletions tests/pytorch/attention/run_attention_with_cp.py
@@ -4,6 +4,9 @@

import os
import sys
import copy
import json
import traceback
import logging
from contextlib import nullcontext
import torch
@@ -209,10 +212,10 @@ def run_dpa_with_cp(
    os.environ["NVTE_FUSED_ATTN"] = "0"
    if kernel_backend == "FlashAttention":
        os.environ["NVTE_FLASH_ATTN"] = "1"
        config = model_configs_flash_attn[model]
        config = copy.deepcopy(model_configs_flash_attn[model])
    if kernel_backend == "FusedAttention":
        os.environ["NVTE_FUSED_ATTN"] = "1"
        config = model_configs_fused_attn[model]
        config = copy.deepcopy(model_configs_fused_attn[model])
    assert config.attn_mask_type in [
        "causal",
        "no_mask",
@@ -223,18 +226,11 @@
        else:
            config.attn_mask_type = "padding"

    # set up distributed group
    rank = int(os.getenv("RANK", "0"))
    world_size = int(os.getenv("WORLD_SIZE", "1"))
    if dist.is_initialized():
        world_size = dist.get_world_size()
        rank = dist.get_rank()
    else:
        device_count = torch.cuda.device_count()
        device = rank % device_count
        torch.cuda.set_device(device)
    # Process group is managed by main(); one init/destroy per torchrun, not per config.
    assert dist.is_initialized(), "dist.init_process_group must be called before run_dpa_with_cp"
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    logging.info(f"[Rank {rank}] Setup: world_size {world_size}")
    dist.init_process_group(backend="nccl", world_size=world_size, rank=rank)

    # set up communication group for CP
    cp_comm_ranks = range(world_size)
@@ -630,7 +626,6 @@ def run_dpa_with_cp(
            == 0
        )
    else:
        # Forward-only: reshape only out/out_ for comparison
        out = out.index_select(0, seq_idx_q).contiguous()
        out_ = out_

@@ -762,14 +757,105 @@ def run_dpa_with_cp(
        )
        logging.info(f"[Rank {rank}] CP vs no-CP: {names[i]} matches")

    # destroy distribution group
    dist.destroy_process_group()
    # Destroy per-config communication groups so they don't leak into the next
    # config in batch mode. The global process group is torn down by main().
    dist.destroy_process_group(cp_comm_group)
    if cp_comm_type == "a2a+p2p":
        for sg in cp_comm_sub_groups:
            dist.destroy_process_group(sg)


# Env vars set by run_dpa_with_cp; cleared between batch configs to prevent leakage.
_TRANSIENT_ENV_KEYS = (
"NVTE_FP8_DPA_BWD",
"NVTE_DPA_FP8CS_O_in_F16",
"NVTE_FLASH_ATTN",
"NVTE_FUSED_ATTN",
"NVTE_ALLOW_NONDETERMINISTIC_ALGO",
)


def _init_distributed():
"""Init NCCL process group + CUDA device once per torchrun invocation."""
rank = int(os.getenv("RANK", "0"))
world_size = int(os.getenv("WORLD_SIZE", "1"))
device_count = torch.cuda.device_count()
# Prefer LOCAL_RANK when available (set by torchrun / torch.distributed.launch);
# fall back to RANK % device_count for single-node runs.
local_rank = int(os.getenv("LOCAL_RANK", str(rank % device_count)))
torch.cuda.set_device(local_rank)
dist.init_process_group(backend="nccl", world_size=world_size, rank=rank)
return rank, world_size


def _run_single_config(kwargs):
"""Run one config, return ``(ok, error_message)``.

Re-seeds RNG before each config so results are deterministic and
order-independent within a batch.
"""
torch.manual_seed(1234)
torch.cuda.manual_seed(1234)
try:
run_dpa_with_cp(**kwargs)
return True, None
except BaseException: # noqa: BLE001 - capture any failure for per-config reporting
return False, traceback.format_exc()


def main(**kwargs):
    run_dpa_with_cp(**kwargs)
    """Entry point: single-config (``key=val ...``) or batch (``batch_config_json=<path>``)."""
    batch_path = kwargs.pop("batch_config_json", None)
    rank, _ = _init_distributed()
    try:
        if batch_path is None:
            run_dpa_with_cp(**kwargs)
        else:
            with open(batch_path, "r") as f:
                configs = json.load(f)
            assert isinstance(
                configs, list
            ), f"batch_config_json must be a JSON list, got {type(configs)}"
            results_path = batch_path + ".results.json"
            results = []

            def _flush_results():
                if rank != 0:
                    return
                # Atomic write: tmp + rename so the reader never sees partial JSON.
                tmp_path = results_path + ".tmp"
                with open(tmp_path, "w") as f:
                    json.dump(results, f)
                os.replace(tmp_path, results_path)

            for cfg in configs:
                for env_key in _TRANSIENT_ENV_KEYS:
                    os.environ.pop(env_key, None)
                ok, err = _run_single_config(cfg)
                # Aggregate ok across ranks so a non-rank-0 failure (e.g. a
                # per-partition compare assertion that fires only on rank > 0)
                # is not silently swallowed when only rank 0 writes the result.
                ok_tensor = torch.tensor(1 if ok else 0, dtype=torch.int32, device="cuda")
                dist.all_reduce(ok_tensor, op=dist.ReduceOp.MIN)
                ok_aggregate = bool(ok_tensor.item())
                if not ok_aggregate and ok and err is None:
                    err = "Failed on a non-zero rank (see subprocess stderr for traceback)"
                results.append({"ok": ok_aggregate, "error": err})
                _flush_results()
                try:
                    dist.barrier()
                except BaseException:  # noqa: BLE001
                    results[-1]["ok"] = False
                    if results[-1]["error"] is None:
                        results[-1]["error"] = traceback.format_exc()
                    _flush_results()
                    break
                torch.cuda.empty_cache()
Comment on lines +831 to +853

P1: Failures on non-rank-0 workers are silently swallowed

Only rank 0 calls _flush_results(), so if rank 1 (or any non-zero rank) raises an exception inside run_dpa_with_cp() — e.g., an output-mismatch assertion in the per-rank comparison loop — but rank 0 succeeds, the result is written as ok=True and the failure is never surfaced.

This scenario is realistic: each rank checks correctness for its own partition of the sequence independently, so a CP bug that corrupts only one partition would be missed. The dist.barrier() would still succeed because rank 1's exception happens after all NCCL collectives (in the comparison phase), leaving both ranks able to reach the barrier.

A minimal fix is to aggregate the per-rank success flag before flushing: collect all ranks' ok values with dist.all_reduce (using ReduceOp.MIN or ReduceOp.BAND) and record the aggregate result on rank 0.
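
A minimal sketch of that suggestion (illustrative only; the helper name aggregate_ok is not part of the patch), assuming the NCCL process group is already initialized and each rank holds its local success flag:

import torch
import torch.distributed as dist

def aggregate_ok(ok: bool) -> bool:
    # Encode the local success flag as an int32 tensor on this rank's CUDA device
    # so NCCL can reduce it across ranks.
    flag = torch.tensor(1 if ok else 0, dtype=torch.int32, device="cuda")
    # MIN across ranks: the result is 1 only if every rank reported success.
    dist.all_reduce(flag, op=dist.ReduceOp.MIN)
    return bool(flag.item())

Rank 0 would then record aggregate_ok(ok) instead of its local ok before writing the results file, which is what the all_reduce added in main() above does.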

    finally:
        if dist.is_initialized():
            dist.destroy_process_group()


if __name__ == "__main__":
kwargs = dict(arg.split("=") for arg in sys.argv[2:])
kwargs = dict(arg.split("=", 1) for arg in sys.argv[2:])
main(**kwargs)
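
For reference, a hypothetical batch file for the new batch_config_json path. Only the field names kernel_backend, model, and cp_comm_type and the .results.json naming are taken from the diff above; the concrete values and the cp_configs.json filename are illustrative assumptions, not configs shipped with the test suite.

import json

# Each entry is one kwargs dict forwarded to run_dpa_with_cp() by main().
configs = [
    {"kernel_backend": "FusedAttention", "model": "cp_1_0", "cp_comm_type": "p2p"},
    {"kernel_backend": "FlashAttention", "model": "cp_1_0", "cp_comm_type": "a2a+p2p"},
]

with open("cp_configs.json", "w") as f:
    json.dump(configs, f)

# Passing batch_config_json=cp_configs.json then makes rank 0 write one
# {"ok": ..., "error": ...} entry per config to cp_configs.json.results.json.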