
[None][feat] LTX-2 Ulysses cross-attention for v2a with audio padding #14044

Draft

luyiyun1021 wants to merge 2 commits into NVIDIA:main from luyiyun1021:feat/ltx2-ulysses-v2a-cross-attn

Conversation

luyiyun1021 (Collaborator) commented May 12, 2026

Summary

Adds a Ulysses sequence-parallel wrapper for LTX-2's v2a (video→audio) cross-attention. Replaces the per-block video-K/V all-gather with a 3-collective wrapper (Q a2a + fused K|V 5D a2a + output a2a), cutting per-rank communication volume ~4× in this direction.

The opposite direction (a2v) is intentionally left on the existing all-gather path — the wrapper would cost ~30× more bytes there (see analysis below).

Why only v2a (and not a2v) needs Ulysses

LTX-2 cross-attn is extremely asymmetric: video tokens vastly outnumber audio tokens.

For our reference config (768×1280, 121 frames, ulysses_size U=4):

  • T_v = 16 × 24 × 40 = 15360 (video latent tokens)
  • T_a = round(121/24 × 25) = 126 (audio latent tokens, after VAE 4× downsample)
  • T_v / T_a ≈ 122×

Per-rank receive bytes per block (B=1, D_a=2048, bf16=2 B/elem):

AllGather path (current main, both cross-attns):

  • R_AG = 2 · (U-1)/U · B · T_kv · D_a · 2
  • a2v (K/V = audio, T_kv=126): 0.78 MB per rank/block
  • v2a (K/V = video, T_kv=15360): 94 MB per rank/block

Ulysses 3-collective path (Q + fused K|V + output a2a):

  • R_Uly = (U-1)/U² · B · (2·T_q + 2·T_kv) · D_a · 2
  • a2v (Q=video, K/V=audio): 23.8 MB per rank/block (~30× more than AG ❌)
  • v2a (Q=audio, K/V=video): 23.8 MB per rank/block (~4× less than AG ✅)

Break-even: Uly < AG iff T_q + T_kv < U · T_kv, i.e. T_q / T_kv < U − 1.

  • v2a: T_q/T_kv = 0.0082 ≪ 3 → wrapper wins
  • a2v: T_q/T_kv = 122 ≫ 3 → wrapper loses (would need U > 123)

→ Wrapper applied to v2a only.
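
For concreteness, the byte counts above can be reproduced by plugging the reference config into the two formulas (a minimal sketch; only values from the analysis above are used):

```python
U, B, D_a, BYTES = 4, 1, 2048, 2       # ulysses_size, batch, hidden dim, bf16 bytes/elem
T_v, T_a = 15360, 126                  # video / audio latent tokens

def r_allgather(T_kv):
    # leading 2: both K and V are all-gathered; each rank receives (U-1)/U of T_kv
    return 2 * (U - 1) / U * B * T_kv * D_a * BYTES

def r_ulysses(T_q, T_kv):
    # Q a2a (T_q) + fused K|V a2a (2*T_kv) + output a2a (T_q), each moving (U-1)/U^2
    return (U - 1) / U**2 * B * (2 * T_q + 2 * T_kv) * D_a * BYTES

print(r_allgather(T_a) / 1e6)          # a2v AG:  ~0.8 MB per rank/block
print(r_allgather(T_v) / 1e6)          # v2a AG:  ~94 MB per rank/block
print(r_ulysses(T_v, T_a) / 1e6)       # a2v Uly: ~23.8 MB -> ~30x more than AG
print(r_ulysses(T_a, T_v) / 1e6)       # v2a Uly: ~23.8 MB -> ~4x less than AG
```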

Design

UlyssesCrossAttention (parallel.py, +107 lines)

3-collective wrapper for asymmetric cross-attn (S_q ≠ S_kv):

  • Q all_to_all_4d: [B, S_q/U, H, D] → [B, S_q, H/U, D]
  • Fused K|V 5D a2a: stack(K, V, dim=2) → all_to_all_5d → [B, S_kv, 2, H_kv/U, D] (1 collective for both)
  • Inner backend (VANILLA SDPA, H_kv/U heads) runs on full-sequence, head-sharded tensors
  • Output all_to_all_4d: [B, S_q, H/U, D] → [B, S_q/U, H, D]

world_size==1 fast path skips all 3 collectives. support_fused_qkv()=False (S_q≠S_kv).
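
A minimal PyTorch sketch of this pattern (the `_all_to_all` helper, class name, and `inner` callable convention are illustrative, not the PR's exact code; assumes the seq and head dims are divisible by the group size):

```python
import torch
import torch.distributed as dist

def _all_to_all(x: torch.Tensor, scatter_dim: int, gather_dim: int, group=None):
    """Reshard across the Ulysses group: split x along scatter_dim,
    exchange chunks, concatenate along gather_dim (works for 4D and 5D)."""
    world = dist.get_world_size(group)
    inputs = [t.contiguous() for t in x.chunk(world, dim=scatter_dim)]
    outputs = [torch.empty_like(inputs[0]) for _ in inputs]
    dist.all_to_all(outputs, inputs, group=group)
    return torch.cat(outputs, dim=gather_dim)

class UlyssesCrossAttentionSketch(torch.nn.Module):
    def __init__(self, inner, group=None):
        super().__init__()
        self.inner = inner    # head-sharded SDPA backend, [B, S, H, D] layout
        self.group = group

    def forward(self, q, k, v):
        # q: [B, S_q/U, H, D]; k, v: [B, S_kv/U, H_kv, D]
        if dist.get_world_size(self.group) == 1:
            return self.inner(q, k, v)                   # fast path: 0 collectives
        q = _all_to_all(q, scatter_dim=2, gather_dim=1, group=self.group)
        kv = torch.stack((k, v), dim=2)                  # [B, S_kv/U, 2, H_kv, D]
        kv = _all_to_all(kv, scatter_dim=3, gather_dim=1, group=self.group)
        k, v = kv.unbind(dim=2)                          # [B, S_kv, H_kv/U, D] each
        out = self.inner(q, k, v)                        # full-seq, H/U heads
        return _all_to_all(out, scatter_dim=1, gather_dim=2, group=self.group)
```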

Audio padding for divisibility

T_a = round(num_frames/fps × 25) is generally not divisible by U. For typical LTX-2 num_frames = 8k+1 (e.g. 113, 121, 129), T_a % 4 == 2 always.

Gated by parallel.audio_pad_for_ulysses (default True):

  • True: LTXModel.forward pads audio on entry (zero-pad latent, repeat-last positions/timesteps/PE), attaches [B, T_a_padded] bool mask to TransformerArgs, strips padded tail on exit. Pad ratio ~1.6% at T_a=126, U=4 → 128.
  • False: original behavior — wrapper engages iff T_a % U == 0, else falls back to all-gather.
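
A minimal sketch of the latent-padding step under the True setting (hypothetical helper name; the real code additionally repeat-pads positions, per-token timesteps, and the PE tuple as described above):

```python
import torch
import torch.nn.functional as F

def pad_audio_for_ulysses(latent: torch.Tensor, ulysses_size: int):
    # latent: [B, T_a, D]. Zero-pad the seq dim up to a multiple of U and
    # build the [B, T_a_padded] bool validity mask (True = real token).
    B, T_a, _ = latent.shape
    pad = (-T_a) % ulysses_size                 # T_a=126, U=4 -> pad=2 (to 128)
    mask = torch.ones(B, T_a + pad, dtype=torch.bool, device=latent.device)
    if pad:
        latent = F.pad(latent, (0, 0, 0, pad))  # zeros in the padded tail
        mask[:, T_a:] = False                   # mark pad slots invalid
    return latent, mask
```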

Mask threading

  • VanillaAttention.forward accepts optional key_padding_mask: [B, S_kv] bool (True=valid), expanded to [B, 1, 1, S_kv] as SDPA attn_mask (non-causal only).
  • audio_attn1 (self-attn) and audio_to_video_attn (a2v) consume audio.audio_padding_mask.
  • video_to_audio_attn (v2a) does not consume the mask: video K/V is unpadded; pad audio Q is stripped at LTXModel.forward exit.
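
The mask expansion is a plain broadcast over heads and query positions; e.g. for a2v with padded audio on the K/V side (illustrative sizes; SDPA treats a bool attn_mask as "True = may attend"):

```python
import torch
import torch.nn.functional as F

B, H, S_q, S_kv, D = 1, 4, 15360, 128, 64      # a2v: Q = video, K/V = padded audio
q, k, v = (torch.randn(B, H, S, D) for S in (S_q, S_kv, S_kv))
key_padding_mask = torch.ones(B, S_kv, dtype=torch.bool)
key_padding_mask[:, 126:] = False               # last 2 slots are audio padding

attn_mask = key_padding_mask[:, None, None, :]  # [B, 1, 1, S_kv], broadcasts over H, S_q
out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
```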

Runtime toggle

LTX2Attention.use_ulysses_cross=True (set on video_to_audio_attn constructor) instantiates an inner backend at (H/U, H_kv/U) heads and wraps it. set_ulysses_active toggles wrapper vs base full-H plain backend — used by configure_audio_ulysses and Stage 2's set_ulysses_enabled(False).

Mutual exclusion with Attention2D: the gate uses ulysses_size > 1 (not seq_parallel_size > 1). VisualGenMapping enforces that Attention2D and Ulysses are mutually exclusive, so the wrapper is naturally skipped under Attention2D mode.
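
Functionally the toggle amounts to swapping which backend the module dispatches to; a rough sketch (class and attribute names illustrative):

```python
import torch

class AttentionToggleSketch(torch.nn.Module):
    def __init__(self, base_attn, ulysses_attn=None):
        super().__init__()
        self.base_attn = base_attn           # full-H plain backend
        self.ulysses_attn = ulysses_attn     # head-sharded inner + 3-collective wrapper
        self._active = ulysses_attn is not None

    def set_ulysses_active(self, active: bool) -> None:
        self._active = active and self.ulysses_attn is not None

    def forward(self, q, k, v):
        attn = self.ulysses_attn if self._active else self.base_attn
        return attn(q, k, v)
```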

Test plan

CPU + gloo (no GPU required):

  • tests/.../test_ulysses_cross_attention.py — init + forward shape ws∈{1,2,4}, ws==1 bit-exact fast path, parity vs full SDPA ws∈{2,4}
  • tests/.../test_ulysses_cross_rope_equivariance.py — verifies rope-then-a2a == a2a-then-full-rope on Q and K sides ws∈{2,4}
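
The equivariance the second test relies on can be illustrated single-process: applying RoPE to a sequence shard with its global positions and then concatenating equals applying full RoPE after concatenation (generic rotary helper below, not LTX-2's exact implementation):

```python
import torch

def rope(x: torch.Tensor, pos: torch.Tensor) -> torch.Tensor:
    # x: [S, D], pos: [S]; standard interleaved rotary embedding
    D = x.shape[-1]
    inv = 1.0 / (10000 ** (torch.arange(0, D, 2, dtype=torch.float32) / D))
    ang = pos[:, None] * inv[None, :]           # [S, D/2]
    cos, sin = ang.cos(), ang.sin()
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = torch.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out

S, D, U = 16, 8, 4
x, pos = torch.randn(S, D), torch.arange(S, dtype=torch.float32)
sharded = torch.cat([rope(xs, ps) for xs, ps in zip(x.chunk(U), pos.chunk(U))])
assert torch.allclose(sharded, rope(x, pos))    # rope-then-gather == gather-then-rope
```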

E2E (8× B200, dit_cfg_size=2 × dit_ulysses_size=4, 768×1280, 121 frames):

  • Baseline (audio_pad_for_ulysses=false → wrapper inactive, all-gather path)
  • Opt (audio_pad_for_ulysses=true → v2a wrapper engaged)
  • Both produce visually identical video
  • Bit-level numerical parity under real weights — not yet added

Performance (8× B200, CFG=2 × Uly=4, num_inference_steps=40, 3 timed runs after warmup)

Wall-clock e2e

| Config | run 0 | run 1 | run 2 | avg | Δ vs baseline |
| --- | --- | --- | --- | --- | --- |
| Baseline (AG, no wrapper) | 9.500s | 9.486s | 9.504s | 9.497s | |
| v2a Uly wrapper | 9.142s | 9.059s | 9.071s | 9.091s | −0.406s (−4.3%) |

NCCL kernels per rank per denoise step (TLLM_PROFILE_VISUAL_GEN_START_STOP=5-5 + nsys --capture-range=cudaProfilerApi --cuda-graph-trace=node)

| Config | AllGather (count × avg) | SendRecv (count × avg) | Total |
| --- | --- | --- | --- |
| Baseline (no wrapper) | 207 × 89 µs = 18.5 ms | 205 × 91 µs = 18.7 ms | 37.2 ms |
| v2a Uly | 196 × 15 µs = 3.0 ms | 528 × 43 µs = 22.6 ms | 25.6 ms |
| Δ | −15.5 ms (−84%) | +3.9 ms (+21%) | −11.6 ms (−31%) |

(Baseline per-rank figures are totals ÷ 8 ranks; v2a Uly figures are per-rank averages across all 8 ranks. The per-rank breakdown for the baseline was not preserved.)

40 steps × ~10 ms saved per rank per step ≈ 400 ms, matching the measured e2e Δ of 406 ms.

The wrapper trades 15.5 ms of AllGather (removed, video K/V no longer all-gathered) for 3.9 ms of additional SendRecv (added Q + K|V + output a2a). Net −11.6 ms NCCL per rank per step maps nearly 1:1 onto wall-clock under CUDA Graph capture.

Open follow-ups

  • a2v wrapper variant is implemented on the branch (toggleable via use_ulysses_cross on audio_to_video_attn) but left disabled by default per the byte analysis. Could be exposed as a YAML option for ablations.
  • Bit-level numerical parity test under real LTX-2 weights (Stage 1 e2e) not yet added.
  • Linter-driven style tweaks to UlyssesAttention._forward_fused / _forward_unfused are bundled in this commit (post-a2a seq_len kwarg threading); unrelated to the v2a wrapper logic but kept here to avoid a trivial separate commit.

🤖 Generated with Claude Code

Adds a Ulysses sequence-parallel wrapper for the video-to-audio (v2a)
cross-attention in LTX-2. v2a is the asymmetric direction where the
wrapper saves communication: Q is audio (small, T_a~126) and K/V are
video (large, T_v=15360 for 768x1280). Replacing the per-block video
K/V all-gather with a Q+(K|V fused)+output three-collective pattern
cuts per-rank communication volume by ~4x.

The opposite a2v direction (Q=video, K/V=audio) is left on the existing
all-gather path because the wrapper would move two video-sized tensors
(Q a2a + output a2a) in exchange for replacing a cheap audio-K/V
all-gather — net 30x MORE bytes per rank at U=4.

# Communication analysis (per rank, per block, B=1, D_a=2048, bf16)

For LTX-2 at 768x1280 / 121f / U=4: T_v = 15360, T_a = 126.

AllGather path (current main for both cross-attns):
    2 * (U-1)/U * B * T_kv * D_a * 2

  a2v (K/V = audio, T_kv = T_a = 126):     ~0.78 MB per rank/block
  v2a (K/V = video, T_kv = T_v = 15360):    94    MB per rank/block

UlyssesCrossAttention path (3 a2a):
    (U-1)/U^2 * B * (T_q + 2*T_kv + T_q) * D_a * 2

  a2v (T_q = T_v, T_kv = T_a):              24    MB per rank/block  (30x MORE)
  v2a (T_q = T_a, T_kv = T_v):              24    MB per rank/block  (~4x LESS)

Break-even: Uly < AG iff (T_q + T_kv) < U * T_kv, equivalently
T_q/T_kv < U - 1. With T_v/T_a ~ 122, the wrapper is only profitable
when the small modality is on the Q side. Hence v2a only.

# Design

- `UlyssesCrossAttention` (parallel.py, +107 lines): three-collective
  wrapper for S_q != S_kv. Q all_to_all_4d, fused K|V 5D all_to_all,
  output all_to_all_4d. Stacks K|V on a new dim of size 2 so both
  tensors travel in a single collective. world_size==1 fast path
  bypasses all three.

- `audio_pad_for_ulysses` flag (ParallelConfig, default True): when set,
  LTXModel pads the audio modality at forward entry so T_a is divisible
  by ulysses_size, attaches a [B, T_a_padded] bool validity mask to
  TransformerArgs, and strips the padded tail on exit. Required because
  T_a is derived from num_frames/fps and is usually NOT divisible
  (e.g. 121 video frames at 24fps gives T_a=126; 126%4=2). Pad scheme:
  zero-pad latent, repeat-last for positions / per-token timesteps,
  repeat-last for the (cos,sin) PE tuple via a seq-dim auto-detect helper.

- `LTX2Attention.use_ulysses_cross`: builds the inner backend at
  (H/U, H_kv/U) heads and wraps with UlyssesCrossAttention. Gated on
  ulysses_size > 1 (skipped under Attention2D mode which is mutually
  exclusive with Ulysses per VisualGenMapping). set_ulysses_active
  toggles between the wrapper and the base plain backend at runtime
  (used by configure_audio_ulysses and Stage 2 fall-back).

- VanillaAttention gains an optional `key_padding_mask` kwarg (expanded
  to [B,1,1,S_kv] as SDPA attn_mask, non-causal only). The
  `audio_to_video_attn` call site passes the mask so attention zeros
  out audio pad slots. `video_to_audio_attn` does not need a mask
  (video K/V is unpadded; pad audio Q is stripped on output).

- `_shard_transformer_args` passes the full-seq audio_padding_mask
  through unchanged; the mask is identical across Ulysses ranks.

# Tests (CPU + gloo, no GPU required)

- test_ulysses_cross_attention.py: wrapper init, forward shape at
  world_size in {2, 4}, world_size==1 bit-exact fast path, and parity
  vs full SDPA on shared seed inputs.
- test_ulysses_cross_rope_equivariance.py: verifies the
  rope-then-a2a == a2a-then-full-rope identity that the wrapper relies
  on (since rope is applied locally on sharded inputs in
  LTX2Attention.forward and then a2a'd, the per-position rotation must
  be preserved by the redistribution).

# Performance (8x B200, dit_cfg_size=2 x dit_ulysses_size=4,
                768x1280, 121 frames, num_inference_steps=40,
                3 timed runs after warmup)

Wall clock e2e:

  Baseline (no wrapper):    9.497s
  v2a wrapper:              9.091s    (-0.406s, -4.3%)

NCCL kernel time, per rank, per step (1 denoise step capture via
TLLM_PROFILE_VISUAL_GEN_START_STOP=5-5 + nsys --capture-range=
cudaProfilerApi --cuda-graph-trace=node):

  Baseline:  ~37 ms NCCL (1640 SendRecv + 1653 AllGather summed
             across 8 ranks; per rank ~205 + ~207 kernels)
  v2a Uly:   ~27 ms NCCL average (528 SendRecv + 196 AllGather per
             rank; rank-0 outlier 37.7 ms, rest 24-29 ms)

  Per-rank per-step savings: ~10 ms.
  40 steps * 10 ms = 400 ms ~ measured 406 ms e2e.

Audio padding cost is negligible at typical T_a (126%4=2 padded to 128;
1.6% padded tokens, masked out in attention).

Disabling the wrapper at runtime is supported via
`parallel.audio_pad_for_ulysses: false` — `_audio_is_sharded` then only
becomes True when T_a%U==0, and the wrapper falls back to the existing
all-gather path otherwise.

Signed-off-by: Yiyun Lu <55233584+luyiyun1021@users.noreply.github.com>
…change)

Addresses post-review feedback on the v2a Ulysses wrapper. Pure
refactoring; no functional changes.

1. PE pad moved into prepare_text_cache (done once per generate()
   instead of per forward step).
2. _pad_pe takes seq_dim explicitly (was fragile auto-detect by
   shape match); raises on out-of-range seq_dim.
3. Forward audio pad guard reduced from 4 conjuncts to 2.
4. Removed defensive ``if model_config is not None else True`` fallback.
5. PE pad pre-assign + overwrite block removed from forward (subsumed
   by item 1 above).
6. configure_audio_ulysses two branches flattened to a single formula.
7. v2a fallback restructured to flat if/elif/else (was nested).

Plus a stale comment fix: ``Ulysses wrap (a2v in LTX-2)`` →
``(v2a in LTX-2)``.

Diff: ~44 insertions, ~57 deletions in transformer_ltx2.py only.
Signed-off-by: Yiyun Lu <55233584+luyiyun1021@users.noreply.github.com>