[None][feat] LTX-2 Ulysses cross-attention for v2a with audio padding #14044
Draft
luyiyun1021 wants to merge 2 commits into
Conversation
Adds a Ulysses sequence-parallel wrapper for the video-to-audio (v2a)
cross-attention in LTX-2. v2a is the asymmetric direction where the
wrapper saves communication: Q is audio (small, T_a~126) and K/V are
video (large, T_v=15360 for 768x1280). Replacing the per-block video
K/V all-gather with a Q+(K|V fused)+output three-collective pattern
cuts per-rank communication volume by ~4x.
The opposite a2v direction (Q=video, K/V=audio) is left on the existing
all-gather path because the wrapper would move two video-sized tensors
(Q a2a + output a2a) in exchange for replacing a cheap audio-K/V
all-gather — net 30x MORE bytes per rank at U=4.
# Communication analysis (per rank, per block, B=1, D_a=2048, bf16)
For LTX-2 at 768x1280 / 121f / U=4: T_v = 15360, T_a = 126.
AllGather path (current main for both cross-attns):
bytes = 2 (K and V) * (U-1)/U * B * T_kv * D_a * 2 (bytes/elem)
a2v (K/V = audio, T_kv = T_a = 126): ~0.78 MB per rank/block
v2a (K/V = video, T_kv = T_v = 15360): ~94 MB per rank/block
UlyssesCrossAttention path (3 a2a: Q, fused K|V, output):
bytes = (U-1)/U^2 * B * (T_q + 2*T_kv + T_q) * D_a * 2 (bytes/elem)
a2v (T_q = T_v, T_kv = T_a): ~24 MB per rank/block (30x MORE)
v2a (T_q = T_a, T_kv = T_v): ~24 MB per rank/block (~4x LESS)
Break-even: Uly < AG iff (T_q + T_kv) < U * T_kv, equivalently
T_q/T_kv < U - 1. With T_v/T_a ~ 122, the wrapper is only profitable
when the small modality is on the Q side. Hence v2a only.
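The numbers above can be reproduced with a few lines (a standalone sketch;
the constants simply mirror the analysis):

```python
# Per-rank receive bytes per block, mirroring the formulas above.
U, B, D_a, BYTES = 4, 1, 2048, 2          # ulysses size, batch, model dim, bf16
T_v, T_a = 15360, 126                     # video / audio token counts

def allgather_bytes(t_kv):
    # K and V each all-gathered: 2 tensors, each (U-1)/U of the full K/V.
    return 2 * (U - 1) / U * B * t_kv * D_a * BYTES

def ulysses_bytes(t_q, t_kv):
    # Q a2a + fused K|V a2a + output a2a; each a2a moves (U-1)/U^2 of the tensor.
    return (U - 1) / U**2 * B * (t_q + 2 * t_kv + t_q) * D_a * BYTES

print(allgather_bytes(T_a) / 1e6)         # a2v AG:  ~0.8 MB
print(allgather_bytes(T_v) / 1e6)         # v2a AG:  ~94.4 MB
print(ulysses_bytes(T_v, T_a) / 1e6)      # a2v Uly: ~23.8 MB (~30x more)
print(ulysses_bytes(T_a, T_v) / 1e6)      # v2a Uly: ~23.8 MB (~4x less)
```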
# Design
- `UlyssesCrossAttention` (parallel.py, +107 lines): three-collective
wrapper for S_q != S_kv (sketched after this list). Q all_to_all_4d,
fused K|V 5D all_to_all, output all_to_all_4d. Stacks K|V on a new dim
of size 2 so both tensors travel in a single collective. world_size==1
fast path bypasses all three.
- `audio_pad_for_ulysses` flag (ParallelConfig, default True): when set,
LTXModel pads the audio modality at forward entry so T_a is divisible
by ulysses_size, attaches a [B, T_a_padded] bool validity mask to
TransformerArgs, and strips the padded tail on exit. Required because
T_a is derived from num_frames/fps and is usually NOT divisible
(e.g. 121 video frames at 24fps gives T_a=126; 126%4=2). Pad scheme:
zero-pad latent, repeat-last for positions / per-token timesteps,
repeat-last for the (cos,sin) PE tuple via a seq-dim auto-detect helper.
- `LTX2Attention.use_ulysses_cross`: builds the inner backend at
(H/U, H_kv/U) heads and wraps with UlyssesCrossAttention. Gated on
ulysses_size > 1 (skipped under Attention2D mode which is mutually
exclusive with Ulysses per VisualGenMapping). set_ulysses_active
toggles between the wrapper and the base plain backend at runtime
(used by configure_audio_ulysses and Stage 2 fall-back).
- VanillaAttention gains an optional `key_padding_mask` kwarg (expanded
to [B,1,1,S_kv] as SDPA attn_mask, non-causal only). The
`audio_to_video_attn` call site passes the mask so attention zeros
out audio pad slots. `video_to_audio_attn` does not need a mask
(video K/V is unpadded; pad audio Q is stripped on output).
- `_shard_transformer_args` passes the full-seq audio_padding_mask
through unchanged; the mask is identical across Ulysses ranks.
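A minimal sketch of the three-collective pattern, using a generic a2a
helper; the PR's actual all_to_all_4d / 5D helpers in parallel.py may
differ in signature:

```python
import torch
import torch.distributed as dist

def _all_to_all(x, scatter_dim, gather_dim, group):
    """Generic a2a: shard x along scatter_dim, regather along gather_dim."""
    u = dist.get_world_size(group)
    parts = [p.contiguous() for p in x.chunk(u, dim=scatter_dim)]
    out = [torch.empty_like(p) for p in parts]
    dist.all_to_all(out, parts, group=group)
    return torch.cat(out, dim=gather_dim)

def ulysses_cross_attention(q, k, v, attn_fn, group):
    """q: [B, S_q/U, H, D], k/v: [B, S_kv/U, H_kv, D], seq-sharded per rank."""
    if dist.get_world_size(group) == 1:
        return attn_fn(q, k, v)                      # fast path: no collectives

    # 1) Q a2a: seq-sharded -> head-sharded, [B, S_q, H/U, D]
    q = _all_to_all(q, scatter_dim=2, gather_dim=1, group=group)

    # 2) Fused K|V: stack on a new dim of size 2 so both tensors travel
    #    in one 5D collective -> [2, B, S_kv, H_kv/U, D]
    kv = torch.stack([k, v], dim=0)
    kv = _all_to_all(kv, scatter_dim=3, gather_dim=2, group=group)
    k, v = kv.unbind(dim=0)

    out = attn_fn(q, k, v)                           # full seq, H/U heads

    # 3) Output a2a: head-sharded -> seq-sharded, [B, S_q/U, H, D]
    return _all_to_all(out, scatter_dim=1, gather_dim=2, group=group)
```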
# Tests (CPU + gloo, no GPU required)
- test_ulysses_cross_attention.py: wrapper init, forward shape at
world_size in {2, 4}, world_size==1 bit-exact fast path, and parity
vs full SDPA on shared seed inputs.
- test_ulysses_cross_rope_equivariance.py: verifies the
rope-then-a2a == a2a-then-full-rope identity that the wrapper relies
on (since rope is applied locally on sharded inputs in
LTX2Attention.forward and then a2a'd, the per-position rotation must
be preserved by the redistribution).
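A single-process toy of that identity (a sketch, not the repo's test: a
simplified rotary function stands in for the model's rope, and chunk/concat
stands in for the a2a, which likewise preserves each token's global
position):

```python
import torch

def rope(x, pos):
    """Toy rotary embedding: rotate feature pairs by position-dependent angles."""
    half = x.shape[-1] // 2
    freqs = pos[:, None] * (10000.0 ** (-torch.arange(half) / half))
    cos, sin = freqs.cos(), freqs.sin()
    x1, x2 = x[..., :half], x[..., half:]
    return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)

S, D, U = 8, 4, 2
x, pos = torch.randn(S, D), torch.arange(S, dtype=torch.float)

# rope applied locally on each seq shard (with that shard's global
# positions), then redistributed; concat models the a2a's seq gather.
local = [rope(s, p) for s, p in zip(x.chunk(U), pos.chunk(U))]
assert torch.allclose(torch.cat(local), rope(x, pos))  # identity holds
```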
# Performance (8x B200, dit_cfg_size=2 x dit_ulysses_size=4,
768x1280, 121 frames, num_inference_steps=40,
3 timed runs after warmup)
Wall clock e2e:
Baseline (no wrapper): 9.497s
v2a wrapper: 9.091s (-0.406s, -4.3%)
NCCL kernel time, per rank, per step (1 denoise step capture via
TLLM_PROFILE_VISUAL_GEN_START_STOP=5-5 + nsys --capture-range=
cudaProfilerApi --cuda-graph-trace=node):
Baseline: ~37 ms NCCL (1640 SendRecv + 1653 AllGather summed
across 8 ranks; per rank ~205 + ~207 kernels)
v2a Uly: ~27 ms NCCL average (528 SendRecv + 196 AllGather per
rank; rank-0 outlier 37.7 ms, rest 24-29 ms)
Per-rank per-step savings: ~10 ms.
40 steps * 10 ms = 400 ms ~ measured 406 ms e2e.
Audio padding cost is negligible at typical T_a (126%4=2 padded to 128;
1.6% padded tokens, masked out in attention).
Audio padding (and with it, usually the wrapper) can be disabled at
runtime via `parallel.audio_pad_for_ulysses: false`: `_audio_is_sharded`
then becomes True only when T_a%U==0, and the model falls back to the
existing all-gather path otherwise.
Signed-off-by: Yiyun Lu <55233584+luyiyun1021@users.noreply.github.com>
…change) Addresses post-review feedback on the v2a Ulysses wrapper. Pure refactoring; no functional changes.
1. PE pad moved into prepare_text_cache (done once per generate() instead of per forward step).
2. _pad_pe takes seq_dim explicitly (was fragile auto-detect by shape match); raises on out-of-range seq_dim.
3. Forward audio pad guard reduced from 4 conjuncts to 2.
4. Removed defensive ``if model_config is not None else True`` fallback.
5. PE pad pre-assign + overwrite block removed from forward (subsumed by #1).
6. configure_audio_ulysses two branches flattened to a single formula.
7. v2a fallback restructured to flat if/elif/else (was nested).
Plus a stale comment fix: ``Ulysses wrap (a2v in LTX-2)`` → ``(v2a in LTX-2)``.
Diff: ~44 insertions, ~57 deletions in transformer_ltx2.py only.
Signed-off-by: Yiyun Lu <55233584+luyiyun1021@users.noreply.github.com>
Summary
Adds a Ulysses sequence-parallel wrapper for LTX-2's v2a (video→audio) cross-attention. Replaces the per-block video-K/V all-gather with a 3-collective wrapper (Q a2a + fused K|V 5D a2a + output a2a), cutting per-rank communication ~4× on this direction.
The opposite direction (a2v) is intentionally left on the existing all-gather path — the wrapper would cost ~30× more bytes there (see analysis below).
Why only v2a (and not a2v) needs Ulysses
LTX-2 cross-attn is extremely asymmetric: video tokens vastly outnumber audio tokens.
For our reference config (768×1280, 121 frames, ulysses_size U=4):
Per-rank receive bytes per block (B=1, D_a=2048, bf16=2 B/elem):
AllGather path (current main, both cross-attns): 2 · (U−1)/U · B · T_kv · D_a · 2
  a2v (K/V = audio, T_kv = 126): ~0.78 MB per rank/block
  v2a (K/V = video, T_kv = 15360): ~94 MB per rank/block
Ulysses 3-collective path (Q + fused K|V + output a2a): (U−1)/U² · B · (2·T_q + 2·T_kv) · D_a · 2
  a2v (T_q = T_v, T_kv = T_a): ~24 MB per rank/block (≈30× more)
  v2a (T_q = T_a, T_kv = T_v): ~24 MB per rank/block (≈4× less)
Break-even: Uly < AG iff T_q + T_kv < U · T_kv, i.e. T_q / T_kv < U − 1.
→ v2a satisfies this (T_q/T_kv ≈ 1/122); a2v does not (≈122 ≫ 3). Wrapper applied to v2a only.
Design
UlyssesCrossAttention (parallel.py, +107 lines)
3-collective wrapper for asymmetric cross-attn (S_q ≠ S_kv): Q all_to_all_4d → fused K|V 5D all_to_all (K and V stacked on a new dim of size 2, so both travel in one collective) → output all_to_all_4d.
world_size==1 fast path skips all 3 collectives. support_fused_qkv()=False (S_q ≠ S_kv).
Audio padding for divisibility
T_a = round(num_frames/fps × 25) is generally not divisible by U. For typical LTX-2 num_frames = 8k+1 (e.g. 113, 121, 129), T_a % 4 == 2 always.
Gated by parallel.audio_pad_for_ulysses (default True): LTXModel zero-pads the audio latent at forward entry so T_a is divisible by U, repeats the last element for positions / per-token timesteps and the (cos,sin) PE tuple, attaches a [B, T_a_padded] bool validity mask to TransformerArgs, and strips the padded tail on exit (see the sketch below).
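A sketch of the pad/mask bookkeeping under these assumptions (hypothetical helper name; the real logic lives at LTXModel's forward entry/exit):

```python
import torch
import torch.nn.functional as F

def pad_audio_for_ulysses(audio, ulysses_size):
    """audio: [B, T_a, D] -> zero-padded [B, T_a_padded, D] plus bool mask."""
    b, t_a, _ = audio.shape
    pad = (-t_a) % ulysses_size                  # e.g. T_a=126 -> pad 2 at U=4
    audio = F.pad(audio, (0, 0, 0, pad))         # zero-pad the latent tail
    mask = torch.zeros(b, t_a + pad, dtype=torch.bool, device=audio.device)
    mask[:, :t_a] = True                         # [B, T_a_padded] validity mask
    return audio, mask, t_a                      # strip with out[:, :t_a] on exit

# In a2v attention the mask becomes an SDPA attn_mask over the padded K/V:
#   attn_mask = mask[:, None, None, :]          # [B, 1, 1, S_kv]
#   F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
```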
Mask threading
VanillaAttention gains an optional key_padding_mask kwarg, expanded to a [B, 1, 1, S_kv] SDPA attn_mask (non-causal only). The audio_to_video_attn call site passes it so attention zeros out audio pad slots; video_to_audio_attn needs no mask (video K/V is unpadded, and pad audio Q is stripped on output). _shard_transformer_args passes the full-seq audio_padding_mask through unchanged, since it is identical across Ulysses ranks.
Runtime toggle
LTX2Attention.use_ulysses_cross=True (set on video_to_audio_attn constructor) instantiates an inner backend at (H/U, H_kv/U) heads and wraps it. set_ulysses_active toggles wrapper vs base full-H plain backend — used by configure_audio_ulysses and Stage 2's set_ulysses_enabled(False).
Mutual exclusion with Attention2D: gate uses ulysses_size > 1 (not seq_parallel_size > 1). VisualGenMapping enforces Attention2D ↔ Ulysses are mutually exclusive, so the wrapper naturally skips under Attention2D mode.
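Roughly, the toggle swaps which backend serves the call (an illustrative sketch, not the PR's exact class):

```python
import torch.nn as nn

class LTX2AttentionSketch(nn.Module):
    """Illustrative shape of the runtime toggle described above."""
    def __init__(self, plain_backend, ulysses_backend):
        super().__init__()
        self.plain_backend = plain_backend          # full-H plain backend
        self.ulysses_backend = ulysses_backend      # (H/U, H_kv/U) inner backend
        self.active = ulysses_backend is not None   # wrapper on by default

    def set_ulysses_active(self, active: bool) -> None:
        # configure_audio_ulysses / Stage 2 fall-back flip this at runtime.
        self.active = active and self.ulysses_backend is not None

    def forward(self, q, k, v):
        backend = self.ulysses_backend if self.active else self.plain_backend
        return backend(q, k, v)
```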
Test plan
CPU + gloo (no GPU required):
- test_ulysses_cross_attention.py: wrapper init, forward shape at world_size ∈ {2, 4}, world_size==1 bit-exact fast path, and parity vs full SDPA on shared-seed inputs.
- test_ulysses_cross_rope_equivariance.py: the rope-then-a2a == a2a-then-full-rope identity the wrapper relies on.
E2E (8× B200, dit_cfg_size=2 × dit_ulysses_size=4, 768×1280, 121 frames): full generation with the wrapper on vs off (numbers in Performance below).
Performance (8× B200, CFG=2 × Uly=4, num_inference_steps=40, 3 timed runs after warmup)
Wall-clock e2e: baseline (no wrapper) 9.497 s → v2a wrapper 9.091 s (−0.406 s, −4.3%).
NCCL kernels per rank per denoise step (TLLM_PROFILE_VISUAL_GEN_START_STOP=5-5 + nsys --capture-range=cudaProfilerApi --cuda-graph-trace=node):
Baseline: ~37 ms (1640 SendRecv + 1653 AllGather summed across 8 ranks; ~205 + ~207 kernels per rank)
v2a Uly: ~27 ms average (528 SendRecv + 196 AllGather per rank; rank-0 outlier 37.7 ms, rest 24–29 ms)
(Baseline per-rank = total ÷ 8 ranks; v2a Uly = per-rank average across 8 ranks. Per-rank breakdown for baseline not preserved.)
40 steps × ~10 ms saved per rank per step ≈ 400 ms ≈ measured e2e Δ = 406 ms ✓
The wrapper trades 15.5 ms of AllGather (removed, video K/V no longer all-gathered) for 3.9 ms of additional SendRecv (added Q + K|V + output a2a). Net −11.6 ms NCCL per rank per step maps nearly 1:1 onto wall-clock under CUDA Graph capture.
Open follow-ups
🤖 Generated with Claude Code