## Motivation
The Triton entropy loss kernel supports tensor-parallel training natively via its `group` parameter: each rank holds `vocab / tp_size` logits, all-reduces two scalars per row (max and sum-exp), then computes loss and grad locally. This avoids materializing the full logits on any rank.
The PyTorch alternative requires an all-gather of the full logit tensor first — O(tokens × vocab) communication. At realistic scale this is prohibitive:
- Llama 3.1 405B: vocab=128K, tokens=8K, TP=8 → 16 GB all-gather per step
The Triton path is not just faster; it is the only feasible approach at large vocab × TP. The current single-GPU benchmark does not demonstrate this. This issue tracks building a multi-GPU benchmark that makes it concrete.
## Variants to benchmark
Three qualitatively different approaches:
- `triton_tp` — existing `triton_entropy_loss_forward_backward(..., group=group)`. Shards vocab across ranks, all-reduces two scalars per row. O(tokens) communication.
- `pytorch_tp_manual` — same algorithm in PyTorch without Triton: `local_max = logits.max(-1)` → `all_reduce(MAX)` → `local_sum = (logits - max).exp().sum(-1)` → `all_reduce(SUM)` → loss. Tests whether Triton fusion still wins when both paths use the same O(tokens) communication pattern (see the sketch after this list).
- `pytorch_gather` — all-gather logits to full vocab on each rank, then `F.cross_entropy`. O(tokens × vocab) communication. Included as a reference to show where the naive approach becomes infeasible; expected to OOM at large vocab × TP.
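As a reference for the `pytorch_tp_manual` variant, here is a minimal sketch of the sharded cross-entropy in plain PyTorch. It assumes each rank holds a contiguous `[tokens, vocab // tp_size]` shard starting at a known `vocab_start` offset, and it needs one extra SUM all-reduce (still O(tokens)) to recover the target logit from whichever rank owns it — a step the two-scalar summary above glosses over. Names are illustrative, not part of the existing benchmark code.

```python
import torch
import torch.distributed as dist

def pytorch_tp_manual_loss(local_logits: torch.Tensor,
                           targets: torch.Tensor,
                           group: dist.ProcessGroup,
                           vocab_start: int) -> torch.Tensor:
    """Cross-entropy over vocab-sharded logits with O(tokens) communication.

    local_logits: [tokens, vocab // tp_size] shard held by this rank.
    targets:      [tokens] global vocab indices.
    vocab_start:  first global vocab index owned by this rank.
    """
    # 1. Row max: local max, then MAX all-reduce -> global max per token.
    row_max = local_logits.max(dim=-1).values
    dist.all_reduce(row_max, op=dist.ReduceOp.MAX, group=group)

    # 2. Row sum-exp: local sum of exp(logit - global_max), then SUM all-reduce.
    sum_exp = (local_logits - row_max[:, None]).exp().sum(dim=-1)
    dist.all_reduce(sum_exp, op=dist.ReduceOp.SUM, group=group)

    # 3. Target logit: only the rank owning the target column contributes,
    #    so a SUM all-reduce reconstructs it on every rank.
    vocab_shard = local_logits.shape[-1]
    local_idx = targets - vocab_start
    in_shard = (local_idx >= 0) & (local_idx < vocab_shard)
    target_logit = torch.where(
        in_shard,
        local_logits.gather(-1, local_idx.clamp(0, vocab_shard - 1)[:, None]).squeeze(-1),
        torch.zeros_like(row_max),
    )
    dist.all_reduce(target_logit, op=dist.ReduceOp.SUM, group=group)

    # loss = log(sum_exp) + max - target_logit, averaged over tokens
    return (sum_exp.log() + row_max - target_logit).mean()
```

The `pytorch_gather` variant, by contrast, is essentially `dist.all_gather` of the shards, `torch.cat` along the vocab dimension, and a single `F.cross_entropy` call.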
## Shapes
Fix `tokens=4096`. Sweep `(vocab, tp_size)` pairs:
| vocab | tp_size | shard / rank |
|---|---|---|
| 32768 | 2 | 16384 |
| 32768 | 4 | 8192 |
| 65536 | 4 | 16384 |
| 131072 | 4 | 32768 |
| 131072 | 8 | 16384 |
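If useful, the sweep can be encoded directly as the `(vocab, tp_size)` pairs from the table (a sketch; the variable names are arbitrary):

```python
TOKENS = 4096  # fixed for the whole sweep

# (vocab, tp_size) pairs from the table above; shard per rank is vocab // tp_size
SWEEP = [
    (32768, 2),
    (32768, 4),
    (65536, 4),
    (131072, 4),
    (131072, 8),
]

for vocab, tp_size in SWEEP:
    assert vocab % tp_size == 0
    print(f"vocab={vocab:>6}  tp_size={tp_size}  shard/rank={vocab // tp_size}")
```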
## Infrastructure changes
### Multi-process runner
The current `runner.py` is single-process. Two options:

Option A — new `tools/benchmark/run_tp.py` entry point using `torch.multiprocessing.spawn(worker, nprocs=tp_size)`. Each worker calls `dist.init_process_group`, creates a TP process group, and runs the benchmark variants with `group=group`; rank 0 collects and prints results.

Option B — extend `__main__.py` to detect a `--tp N` flag and re-launch itself via `torchrun --nproc_per_node=N`.

Option A is simpler to implement; Option B integrates more cleanly with the existing CLI.
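A minimal sketch of Option A follows. The worker structure mirrors the description above; `run_variants` is a hypothetical helper standing in for the calls into `bench_entropy_loss.py`, and the spawn/rendezvous details are assumptions rather than existing code.

```python
# tools/benchmark/run_tp.py -- sketch of Option A (names and structure are assumptions)
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def worker(rank: int, tp_size: int) -> None:
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29501")
    dist.init_process_group("nccl", rank=rank, world_size=tp_size)
    torch.cuda.set_device(rank)

    # With world_size == tp_size the default group already is the TP group;
    # new_group is shown for the general case.
    group = dist.new_group(ranks=list(range(tp_size)))

    # run_variants would invoke each benchmark variant with group=group and
    # return {variant_name: latency_ms} -- hypothetical helper.
    results = run_variants(group=group)

    gathered = [None] * tp_size
    dist.all_gather_object(gathered, results, group=group)
    if rank == 0:
        print(gathered)  # rank 0 collects and prints

    dist.destroy_process_group()


if __name__ == "__main__":
    tp_size = int(os.environ.get("TP_SIZE", "4"))
    mp.spawn(worker, args=(tp_size,), nprocs=tp_size)
```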
### Timing
- `dist.barrier()` + `torch.cuda.synchronize()` before and after each timed region so all ranks agree on wall time.
- Report max latency across ranks (rank 0 collects via `dist.all_reduce(MAX)`).
- Communication time is included automatically since the all-reduce is inside the kernel call.
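A sketch of the timing wrapper implied by the bullets above (the helper name and iteration count are arbitrary):

```python
import time

import torch
import torch.distributed as dist


def timed_region(fn, group, iters: int = 10) -> float:
    """Return the max wall time in ms across ranks for `iters` calls of fn()."""
    fn()  # warm-up so compilation/allocator effects stay out of the timing

    dist.barrier(group=group)
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    torch.cuda.synchronize()
    end = time.perf_counter()
    dist.barrier(group=group)

    # Max across ranks so every rank reports the same (worst-case) latency.
    elapsed = torch.tensor((end - start) * 1000.0 / iters, device="cuda")
    dist.all_reduce(elapsed, op=dist.ReduceOp.MAX, group=group)
    return elapsed.item()
```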
### OOM guard for `pytorch_gather`
Wrap in `try/except` and report `OOM` in the table instead of a time. This makes the table show exactly where the naive approach becomes infeasible.
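A minimal sketch of the guard, assuming each variant is exposed as a zero-argument callable that returns latency in milliseconds:

```python
import torch


def run_with_oom_guard(fn) -> str:
    """Run one benchmark variant; return a latency string or 'OOM'."""
    try:
        latency_ms = fn()
        return f"{latency_ms:.2f} ms"
    except torch.cuda.OutOfMemoryError:
        # Free what we can so later configurations in the sweep still run.
        torch.cuda.empty_cache()
        return "OOM"
```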
## Expected outcome
At vocab=131072, TP=8:
| variant | result |
|---|---|
| `triton_tp` | fast (~O(tokens) communication) |
| `pytorch_tp_manual` | slightly slower (no kernel fusion, same communication) |
| `pytorch_gather` | OOM or ~10–20× slower (16 GB all-gather dominates) |
## Files

- New: `tools/benchmark/run_tp.py` (or extend `__main__.py`)
- Modified: `tools/benchmark/bench_entropy_loss.py` — add TP variants alongside existing single-GPU variants