107 changes: 107 additions & 0 deletions tensorrt_llm/_torch/visual_gen/attention_backend/parallel.py
@@ -188,6 +188,113 @@ def support_fused_qkv(cls) -> bool:
return True


class UlyssesCrossAttention(AttentionBackend):
"""
Ulysses Sequence Parallelism wrapper for cross-attention where
``S_q != S_kv`` and K/V are pre-projected.

Uses three collectives total:
1. Q all-to-all: [B, S_q/U, H, D] -> [B, S_q, H/U, D]
2. Fused K|V 5D all-to-all: stack K, V on a new dim of size 2 so both
tensors travel in one collective:
[B, S_kv/U, 2, H_kv, D] -> [B, S_kv, 2, H_kv/U, D]
3. Output all-to-all: [B, S_q, H/U, D] -> [B, S_q/U, H, D]

Per-rank receive volume is ``((U-1)/U^2) * total_bytes`` per collective,
i.e. a factor of ``U`` less than an all-gather of the same total tensor.

The inner backend must be instantiated at sharded head counts
``(H/U, H_kv/U)`` so its forward sees the post-a2a shape correctly.

Caller contract:
- ``q``, ``k``, ``v`` are seq-sharded, full-head NHD tensors. The Q-side
and KV-side seq dims are independent.
- ``S_q`` and ``S_kv`` must each be divisible by ``world_size``.
- All other kwargs (``attention_mask``, etc.) are forwarded transparently
to the inner backend after the a2a phase.
- ``world_size == 1`` is a fast path that skips all three collectives.
- ``support_fused_qkv() == False`` (S_q != S_kv precludes Q+KV fusion).
"""

def __init__(
self,
inner_backend: AttentionBackend,
process_group: Optional[torch.distributed.ProcessGroup] = None,
):
self.inner_backend = inner_backend
self.process_group = process_group
self._preferred_layout = AttentionTensorLayout.NHD

self.head_dim = inner_backend.head_dim
self.sharded_num_heads = inner_backend.num_heads
self.sharded_num_kv_heads = getattr(inner_backend, "num_kv_heads", self.sharded_num_heads)

try:
self.world_size = torch.distributed.get_world_size(group=process_group)
except (RuntimeError, ValueError):
self.world_size = 1

# Exposed head counts reflect the full, unsharded model width. The inner
# backend is constructed at (H/U, H_kv/U) heads so that each rank's
# post-a2a tensor shape matches what it expects.
self.num_heads = self.sharded_num_heads * self.world_size
self.num_kv_heads = self.sharded_num_kv_heads * self.world_size

def forward(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
**kwargs,
) -> torch.Tensor:
assert q.dim() == 4, f"q must be 4D [B, S_q/U, H, D], got {q.dim()}D"
assert k.dim() == 4, f"k must be 4D [B, S_kv/U, H_kv, D], got {k.dim()}D"
assert v.dim() == 4, f"v must be 4D [B, S_kv/U, H_kv, D], got {v.dim()}D"
assert q.shape[0] == k.shape[0] == v.shape[0], (
f"q/k/v batch mismatch: {q.shape[0]}, {k.shape[0]}, {v.shape[0]}"
)
assert k.shape[1] == v.shape[1], f"k/v seq shard mismatch: k={k.shape[1]}, v={v.shape[1]}"

if self.world_size > 1:
# Q: [B, S_q/U, H, D] -> [B, S_q, H/U, D]
q = all_to_all_4d(q, scatter_dim=2, gather_dim=1, process_group=self.process_group)
# Fused K|V: stack on new dim=2 so both travel in one collective.
# [B, S_kv/U, 2, H_kv, D] -> [B, S_kv, 2, H_kv/U, D]
kv = torch.stack([k, v], dim=2)
kv = all_to_all_5d(kv, scatter_dim=3, gather_dim=1, process_group=self.process_group)
k, v = kv.unbind(dim=2)
k = k.contiguous()
v = v.contiguous()

if self.inner_backend.preferred_layout == AttentionTensorLayout.HND:
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)

out = self.inner_backend.forward(q=q, k=k, v=v, **kwargs)

if self.inner_backend.preferred_layout == AttentionTensorLayout.HND:
out = out.transpose(1, 2).contiguous()
else:
out = out.contiguous()

if self.world_size > 1:
# Output: [B, S_q, H/U, D] -> [B, S_q/U, H, D]
out = all_to_all_4d(out, scatter_dim=1, gather_dim=2, process_group=self.process_group)

return out

@property
def preferred_layout(self) -> AttentionTensorLayout:
"""Preferred tensor layout: [B, S, H, D]"""
return self._preferred_layout

@classmethod
def support_fused_qkv(cls) -> bool:
# S_q != S_kv precludes stacking Q with K/V in a single collective.
return False


class Attention2DAttention(AttentionBackend):
"""
Attention2D Context Parallelism wrapper for video-generation inference.
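For illustration, here is a minimal per-rank usage sketch of the UlyssesCrossAttention wrapper added above. It assumes torch.distributed is already initialized with U ranks (e.g. via torchrun); the inner backend class and its constructor signature are placeholders that are not part of this diff, and the concrete sizes are arbitrary. As a worked instance of the docstring's volume formula, with U = 4 each collective's per-rank receive volume is (4-1)/4^2 = 3/16 of the total tensor, versus 3/4 for an all-gather.

# Hedged per-rank sketch; SdpaBackend is a hypothetical inner backend exposing
# num_heads / num_kv_heads / head_dim, which the wrapper's __init__ relies on.
import torch

U = torch.distributed.get_world_size()                 # Ulysses degree
B, S_q, S_kv, H, H_kv, D = 2, 1024, 256, 16, 8, 64     # S_q % U == 0, S_kv % U == 0

# The inner backend is built at the *sharded* head counts (H/U, H_kv/U).
inner = SdpaBackend(num_heads=H // U, num_kv_heads=H_kv // U, head_dim=D)  # hypothetical
attn = UlyssesCrossAttention(inner)                     # default process group

# Each rank holds a seq shard with the full head count (NHD layout).
q = torch.randn(B, S_q // U, H, D, device="cuda")
k = torch.randn(B, S_kv // U, H_kv, D, device="cuda")
v = torch.randn(B, S_kv // U, H_kv, D, device="cuda")

out = attn.forward(q, k, v)                             # [B, S_q/U, H, D], same sharding as q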
18 changes: 18 additions & 0 deletions tensorrt_llm/_torch/visual_gen/attention_backend/vanilla.py
@@ -71,6 +71,7 @@ def forward(
v: torch.Tensor,
*,
attention_mask: PredefinedAttentionMask = PredefinedAttentionMask.FULL,
key_padding_mask: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
"""
@@ -83,6 +84,10 @@
k: Key tensor [batch_size, num_kv_heads, seq_len_kv, head_dim]
v: Value tensor [batch_size, num_kv_heads, seq_len_kv, head_dim]
attention_mask: Attention mask type (CAUSAL or FULL)
key_padding_mask: Optional ``[B, S_kv]`` bool tensor; True = valid,
False = pad. Expanded internally to ``[B, 1, 1, S_kv]`` and
passed as ``attn_mask`` to SDPA. Only the non-causal branch is
supported (LTX-2 cross-attn is non-causal).

Returns:
Output tensor [batch_size, num_heads, seq_len, head_dim]
@@ -99,6 +104,19 @@
f"Invalid v shape: expected [B={q.shape[0]}, H_kv, S_kv, D={self.head_dim}], got {v.shape}"
)

if key_padding_mask is not None:
assert not is_causal, "key_padding_mask is not supported with causal attention"
assert key_padding_mask.dim() == 2 and key_padding_mask.shape == (
q.shape[0],
k.shape[2],
), (
f"Invalid key_padding_mask shape: expected [B={q.shape[0]}, "
f"S_kv={k.shape[2]}], got {tuple(key_padding_mask.shape)}"
)
# [B, S_kv] -> [B, 1, 1, S_kv] so SDPA broadcasts over H and S_q.
attn_mask = key_padding_mask[:, None, None, :]
return F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, scale=self.scale)

return F.scaled_dot_product_attention(q, k, v, is_causal=is_causal, scale=self.scale)

@property
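A brief standalone sketch of the masking convention the new key_padding_mask argument uses: a [B, S_kv] bool mask (True = valid) expanded to [B, 1, 1, S_kv] so SDPA broadcasts it over heads and query positions. Only F.scaled_dot_product_attention and the shapes come from the diff; the concrete sizes below are illustrative.

# Standalone sketch of the key_padding_mask convention (illustrative sizes).
import torch
import torch.nn.functional as F

B, H, S_q, S_kv, D = 2, 8, 128, 77, 64
q = torch.randn(B, H, S_q, D)
k = torch.randn(B, H, S_kv, D)
v = torch.randn(B, H, S_kv, D)

# True = valid key position, False = padding; here sample 1 has 13 padded keys.
key_padding_mask = torch.ones(B, S_kv, dtype=torch.bool)
key_padding_mask[1, -13:] = False

# [B, S_kv] -> [B, 1, 1, S_kv]; SDPA broadcasts over H and S_q and excludes
# False positions from the softmax (they receive -inf logits).
attn_mask = key_padding_mask[:, None, None, :]
out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
print(out.shape)  # torch.Size([2, 8, 128, 64])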
7 changes: 7 additions & 0 deletions tensorrt_llm/_torch/visual_gen/config.py
@@ -107,6 +107,13 @@ class ParallelConfig(StrictBaseModel):
dit_attn2d_col_size: int = PydanticField(1, ge=1) # Supported
dit_cfg_size: int = PydanticField(1, ge=1) # Supported
dit_fsdp_size: int = PydanticField(1, ge=1)
# When True, the model layer pads the audio sequence at LTXModel.forward
# entry so T_a is divisible by ulysses_size, attaches a [B, T_a_padded] bool
# mask consumed by audio self-attn + a2v cross-attn (zero attention on pad),
# and strips the padded tail on exit. Required to keep the Ulysses cross-attn
# wrapper engaged when T_a is not divisible (the typical LTX-2 case at
# num_frames = 8k+1). Currently only honored by the LTX-2 model.
audio_pad_for_ulysses: bool = PydanticField(True)
dit_dim_order: str = PydanticField(
DEFAULT_DIM_ORDER,
description=(
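A sketch of the pad/mask/strip behaviour this flag describes. The LTXModel.forward integration itself is not part of this diff, so the helper below is hypothetical; it only mirrors the contract stated in the comment: pad T_a up to a multiple of ulysses_size, build a [B, T_a_padded] bool mask with False on the padded tail, and slice the tail off again on exit.

# Hypothetical helper mirroring the documented contract; not from the diff.
import torch
import torch.nn.functional as F

def pad_audio_for_ulysses(audio: torch.Tensor, ulysses_size: int):
    """audio: [B, T_a, C] -> (padded audio, [B, T_a_padded] bool mask, original T_a)."""
    B, T_a, _ = audio.shape
    pad = (-T_a) % ulysses_size                      # 0 when already divisible
    if pad:
        audio = F.pad(audio, (0, 0, 0, pad))         # zero-pad the time dim
    mask = torch.ones(B, T_a + pad, dtype=torch.bool, device=audio.device)
    mask[:, T_a:] = False                            # padded tail = invalid keys
    return audio, mask, T_a

# E.g. an audio length that does not divide by ulysses_size=4:
audio = torch.randn(2, 121, 32)
audio_padded, audio_padding_mask, T_a = pad_audio_for_ulysses(audio, ulysses_size=4)
# ... run the transformer with audio_padding_mask attached ...
audio_out = audio_padded[:, :T_a]                    # strip the padded tail on exit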
@@ -32,6 +32,11 @@ class TransformerArgs:
cross_scale_shift_timestep: torch.Tensor | None
cross_gate_timestep: torch.Tensor | None
enabled: bool
# Optional [B, S_full_padded] bool mask (True=valid, False=pad) for the
# audio modality when LTXModel.forward pads it on entry to satisfy
# T_a % ulysses_size == 0. Identical across Ulysses ranks (full-seq).
# None when no padding is applied.
audio_padding_mask: torch.Tensor | None = None


class TransformerArgsPreprocessor:
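Finally, a hedged sketch of how the new field might be consumed. The wiring between TransformerArgs and the attention backends is not shown in this diff, so the call site below is illustrative; only the field name audio_padding_mask and the backend's key_padding_mask keyword come from the changes above.

# Illustrative call site (not from the diff): the full-sequence mask is
# identical across Ulysses ranks, so it can be forwarded unchanged through
# the wrapper's **kwargs down to the inner backend's key_padding_mask.
import torch

def run_a2v_cross_attention(attn_backend, q, k_audio, v_audio, args) -> torch.Tensor:
    return attn_backend.forward(
        q=q,
        k=k_audio,
        v=v_audio,
        key_padding_mask=args.audio_padding_mask,   # [B, T_a_padded] bool or None
    )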