274 changes: 239 additions & 35 deletions src/diffusers/models/attention_dispatch.py
@@ -352,6 +352,8 @@ class _HubKernelConfig:
AttentionBackendName.FLASH_VARLEN_HUB: _HubKernelConfig(
repo_id="kernels-community/flash-attn2",
function_attr="flash_attn_varlen_func",
wrapped_forward_attr="flash_attn_interface._wrapped_flash_attn_varlen_forward",
wrapped_backward_attr="flash_attn_interface._wrapped_flash_attn_varlen_backward",
version=1,
),
AttentionBackendName.SAGE_HUB: _HubKernelConfig(
@@ -636,6 +638,13 @@ def _prepare_for_flash_attn_or_sage_varlen(
return _prepare_for_flash_attn_or_sage_varlen_with_mask(batch_size, seq_len_q, attn_mask, device)


def _unpad_to_padded(packed: torch.Tensor, indices: torch.Tensor, batch_size: int, seq_len: int) -> torch.Tensor:
"""scatter a packed `(nnz, ...)` tensor back to padded `(batch_size, seq_len, ...)`."""
output = torch.zeros(batch_size * seq_len, *packed.shape[1:], dtype=packed.dtype, device=packed.device)
output[indices] = packed
return output.view(batch_size, seq_len, *packed.shape[1:])

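For intuition, here is a minimal sketch of the pack/scatter round trip this helper reverses (illustrative only; the mask and tensors below are made up):

# Hypothetical example: pack valid tokens with a boolean mask, then scatter back.
mask = torch.tensor([[True, True, False], [True, False, False]])     # (batch, seq_len)
x = torch.randn(2, 3, 4)                                             # (batch, seq_len, dim)
indices = mask.flatten().nonzero(as_tuple=False).flatten()           # flat positions of valid tokens
packed = x.flatten(0, 1)[indices]                                    # (nnz, dim) = (3, 4)
padded = _unpad_to_padded(packed, indices, batch_size=2, seq_len=3)  # (2, 3, 4)
# `padded` equals `x` at the valid positions and is zero elsewhere.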

def _normalize_attn_mask(attn_mask: torch.Tensor, batch_size: int, seq_len_k: int) -> torch.Tensor:
"""
Normalize an attention mask to shape [batch_size, seq_len_k] (bool) suitable for inferring seqlens_[q|k] in
@@ -1292,6 +1301,178 @@ def _flash_attention_hub_backward_op(
return grad_query, grad_key, grad_value


def _flash_varlen_attention_hub_forward_op(
ctx: torch.autograd.function.FunctionCtx,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: torch.Tensor | None = None,
dropout_p: float = 0.0,
is_causal: bool = False,
scale: float | None = None,
enable_gqa: bool = False,
return_lse: bool = False,
_save_ctx: bool = True,
_parallel_config: "ParallelConfig" | None = None,
*,
window_size: tuple[int, int] = (-1, -1),
):
if enable_gqa:
raise ValueError("`enable_gqa` is not yet supported for flash-attn varlen hub kernels.")

config = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_VARLEN_HUB]
wrapped_forward_fn = config.wrapped_forward_fn
wrapped_backward_fn = config.wrapped_backward_fn
if wrapped_forward_fn is None or wrapped_backward_fn is None:
raise RuntimeError(
"Flash attention varlen hub kernels must expose `_wrapped_flash_attn_varlen_forward` and "
"`_wrapped_flash_attn_varlen_backward` for context parallel execution."
)

if scale is None:
scale = query.shape[-1] ** (-0.5)

softcap = 0.0
alibi_slopes = None
deterministic = False
grad_enabled = any(x.requires_grad for x in (query, key, value))

if grad_enabled or (_parallel_config is not None and _parallel_config.context_parallel_config._world_size > 1):
dropout_p = dropout_p if dropout_p > 0 else 1e-30

batch_size, seq_len_q, num_heads, _ = query.shape
_, seq_len_kv, _, _ = key.shape

if attn_mask is not None:
attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)
(_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (_, max_seqlen_k) = (
_prepare_for_flash_attn_or_sage_varlen_with_mask(batch_size, seq_len_q, attn_mask, query.device)
)
indices_k = attn_mask.flatten().nonzero(as_tuple=False).flatten()
query_packed = query.flatten(0, 1)
key_packed = key.reshape(-1, *key.shape[2:])[indices_k]
value_packed = value.reshape(-1, *value.shape[2:])[indices_k]
max_seqlen_q = seq_len_q
else:
(_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
_prepare_for_flash_attn_or_sage_varlen_without_mask(batch_size, seq_len_q, seq_len_kv, query.device)
)
query_packed = query.flatten(0, 1)
key_packed = key.flatten(0, 1)
value_packed = value.flatten(0, 1)
seqlens_k = None

with torch.set_grad_enabled(grad_enabled):
out_packed, lse, _, rng_state = wrapped_forward_fn(
query_packed,
key_packed,
value_packed,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
dropout_p,
scale,
is_causal,
window_size[0],
window_size[1],
softcap,
alibi_slopes,
return_lse,
)

out = out_packed.view(batch_size, seq_len_q, *out_packed.shape[1:])

if _save_ctx:
ctx.save_for_backward(
query_packed, key_packed, value_packed, out_packed, lse, rng_state, cu_seqlens_q, cu_seqlens_k
)
ctx.seqlens_k = seqlens_k # None if unmasked
ctx.indices_k = indices_k if attn_mask is not None else None
ctx.max_seqlen_q = max_seqlen_q
ctx.max_seqlen_k = max_seqlen_k
ctx.batch_size = batch_size
ctx.seq_len_q = seq_len_q
ctx.seq_len_kv = seq_len_kv
ctx.num_heads = num_heads
ctx.dropout_p = dropout_p
ctx.scale = scale
ctx.is_causal = is_causal
ctx.window_size = window_size
ctx.softcap = softcap
ctx.alibi_slopes = alibi_slopes
ctx.deterministic = deterministic

# (num_heads, batch_size * seq_len_q) -> (batch_size, seq_len_q, num_heads)
lse_sp = lse.view(num_heads, batch_size, seq_len_q).permute(1, 2, 0).contiguous()

return (out, lse_sp) if return_lse else out


def _flash_varlen_attention_hub_backward_op(
ctx: torch.autograd.function.FunctionCtx,
grad_out: torch.Tensor,
*args,
**kwargs,
):
config = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_VARLEN_HUB]
wrapped_backward_fn = config.wrapped_backward_fn
if wrapped_backward_fn is None:
raise RuntimeError(
"Flash attention varlen hub kernels must expose `_wrapped_flash_attn_varlen_backward` "
"for context parallel execution."
)

query_packed, key_packed, value_packed, out_packed, lse, rng_state, cu_seqlens_q, cu_seqlens_k = ctx.saved_tensors

grad_out_packed = grad_out.flatten(0, 1)
grad_query, grad_key, grad_value = (
torch.empty_like(query_packed),
torch.empty_like(key_packed),
torch.empty_like(value_packed),
)

_ = wrapped_backward_fn(
grad_out_packed,
query_packed,
key_packed,
value_packed,
out_packed,
lse,
grad_query,
grad_key,
grad_value,
cu_seqlens_q,
cu_seqlens_k,
ctx.max_seqlen_q,
ctx.max_seqlen_k,
ctx.dropout_p,
ctx.scale,
ctx.is_causal,
ctx.window_size[0],
ctx.window_size[1],
ctx.softcap,
ctx.alibi_slopes,
ctx.deterministic,
rng_state,
)

grad_query = grad_query.view(ctx.batch_size, ctx.seq_len_q, *grad_query.shape[1:])

if ctx.seqlens_k is not None:
grad_key = _unpad_to_padded(grad_key, ctx.indices_k, ctx.batch_size, ctx.seq_len_kv)
grad_value = _unpad_to_padded(grad_value, ctx.indices_k, ctx.batch_size, ctx.seq_len_kv)
else:
grad_key = grad_key.view(ctx.batch_size, ctx.seq_len_kv, *grad_key.shape[1:])
grad_value = grad_value.view(ctx.batch_size, ctx.seq_len_kv, *grad_value.shape[1:])

grad_query = grad_query[..., : grad_out.shape[-1]]
grad_key = grad_key[..., : grad_out.shape[-1]]
grad_value = grad_value[..., : grad_out.shape[-1]]

return grad_query, grad_key, grad_value


def _flash_attention_3_hub_forward_op(
ctx: torch.autograd.function.FunctionCtx,
query: torch.Tensor,
@@ -2557,7 +2738,7 @@ def _flash_attention_hub(
@_AttentionBackendRegistry.register(
AttentionBackendName.FLASH_VARLEN_HUB,
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
supports_context_parallel=False,
supports_context_parallel=True,
)
def _flash_varlen_attention_hub(
query: torch.Tensor,
@@ -2571,46 +2752,69 @@ def _flash_varlen_attention_hub(
return_lse: bool = False,
_parallel_config: "ParallelConfig" | None = None,
) -> torch.Tensor:
if _parallel_config is not None and _parallel_config.context_parallel_config.ring_degree > 1:
raise NotImplementedError("`ring_degree > 1` is not yet supported for the FLASH_VARLEN_HUB backend.")

lse = None
batch_size, seq_len_q, _, _ = query.shape
_, seq_len_kv, _, _ = key.shape

if attn_mask is not None:
attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)

(_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
_prepare_for_flash_attn_or_sage_varlen(
batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
)
)
if _parallel_config is None:
if attn_mask is not None:
attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)
(_, _), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
_prepare_for_flash_attn_or_sage_varlen_with_mask(batch_size, seq_len_q, attn_mask, query.device)
)
indices_k = attn_mask.flatten().nonzero(as_tuple=False).flatten()
key_packed = key.reshape(-1, *key.shape[2:])[indices_k]
value_packed = value.reshape(-1, *value.shape[2:])[indices_k]
else:
(_, _), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
_prepare_for_flash_attn_or_sage_varlen_without_mask(batch_size, seq_len_q, seq_len_kv, query.device)
)
key_packed = key.flatten(0, 1)
value_packed = value.flatten(0, 1)

key_valid, value_valid = [], []
for b in range(batch_size):
valid_len = seqlens_k[b]
key_valid.append(key[b, :valid_len])
value_valid.append(value[b, :valid_len])
query_packed = query.flatten(0, 1)
Member: Is this not supposed to be guarded?

Contributor Author: It should be guarded implicitly at the top:

batch_size, seq_len_q, _, _ = query.shape
_, seq_len_kv, _, _ = key.shape

query_packed = query.flatten(0, 1)
key_packed = torch.cat(key_valid, dim=0)
value_packed = torch.cat(value_valid, dim=0)

func = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_VARLEN_HUB].kernel_fn
out = func(
q=query_packed,
k=key_packed,
v=value_packed,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
dropout_p=dropout_p,
softmax_scale=scale,
causal=is_causal,
window_size=window_size,
return_attn_probs=return_lse,
)
out = out.unflatten(0, (batch_size, -1))
func = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_VARLEN_HUB].kernel_fn
out = func(
q=query_packed,
k=key_packed,
v=value_packed,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
dropout_p=dropout_p,
softmax_scale=scale,
causal=is_causal,
window_size=window_size,
return_attn_probs=return_lse,
)
if return_lse:
out, lse, *_ = out
out = out.unflatten(0, (batch_size, -1))
else:
forward_op = functools.partial(_flash_varlen_attention_hub_forward_op, window_size=window_size)
out = _templated_context_parallel_attention(
query,
key,
value,
attn_mask,
dropout_p,
is_causal,
scale,
False,
return_lse,
forward_op=forward_op,
backward_op=_flash_varlen_attention_hub_backward_op,
_parallel_config=_parallel_config,
)
if return_lse:
out, lse = out

return out
return (out, lse) if return_lse else out
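For reference, this is how per-sample lengths relate to the `cu_seqlens` offsets the varlen kernels consume (a hedged sketch with made-up numbers, not taken from the PR):

# Two samples with 3 and 1 valid KV tokens respectively.
seqlens_k = torch.tensor([3, 1], dtype=torch.int32)
cu_seqlens_k = torch.nn.functional.pad(torch.cumsum(seqlens_k, dim=0, dtype=torch.int32), (1, 0))
# cu_seqlens_k == tensor([0, 3, 4]); sample b occupies packed rows cu_seqlens_k[b]:cu_seqlens_k[b + 1],
# and max_seqlen_k == 3 in this example.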


@_AttentionBackendRegistry.register(
Expand Down
11 changes: 11 additions & 0 deletions tests/models/testing_utils/parallelism.py
@@ -374,6 +374,8 @@ def test_context_parallel_custom_mesh(self, cp_type, mesh_shape, mesh_dim_names)
@is_context_parallel
@require_torch_multi_accelerator
class ContextParallelAttentionBackendsTesterMixin:
unsupported_attn_backends: list[str] = []

@pytest.mark.parametrize("cp_type", ["ulysses_degree", "ring_degree"])
@pytest.mark.parametrize(
"attention_backend",
@@ -383,6 +385,10 @@ class ContextParallelAttentionBackendsTesterMixin:
"flash_hub",
marks=pytest.mark.skipif(not is_kernels_available(), reason="`kernels` is not available."),
),
pytest.param(
"flash_varlen_hub",
Member: Should varlen tests get their own testing mixin class?

Contributor Author: I think the varlen kernel can handle all the cases supported by the non-varlen kernel. Personally, I prefer to put them together.

marks=pytest.mark.skipif(not is_kernels_available(), reason="`kernels` is not available."),
),
pytest.param(
"_flash_3_hub",
marks=pytest.mark.skipif(not is_kernels_available(), reason="`kernels` is not available."),
@@ -398,9 +404,14 @@ def test_context_parallel_attn_backend_inference(self, cp_type, attention_backen
if getattr(self.model_class, "_cp_plan", None) is None:
pytest.skip("Model does not have a _cp_plan defined for context parallel inference.")

if attention_backend in self.unsupported_attn_backends:
pytest.skip(f"{attention_backend} is not supported for this model.")

if cp_type == "ring_degree":
if attention_backend == AttentionBackendName.NATIVE:
pytest.skip("Skipping test because ring isn't supported with native attention backend.")
elif attention_backend in ("flash_varlen_hub",):
pytest.skip("`ring_degree` is not yet supported for varlen attention hub kernels.")

if ulysses_anything and "ulysses" not in cp_type:
pytest.skip("Skipping test as ulysses anything needs the ulysses degree set.")
1 change: 1 addition & 0 deletions tests/models/testing_utils/utils.py
@@ -6,6 +6,7 @@
_BF16_REQUIRED_BACKENDS = {
AttentionBackendName._NATIVE_CUDNN,
AttentionBackendName.FLASH_HUB,
AttentionBackendName.FLASH_VARLEN_HUB,
AttentionBackendName._FLASH_3_HUB,
}

11 changes: 11 additions & 0 deletions tests/models/transformers/test_models_transformer_qwenimage.py
@@ -25,6 +25,7 @@
AttentionTesterMixin,
BaseModelTesterConfig,
BitsAndBytesTesterMixin,
ContextParallelAttentionBackendsTesterMixin,
ContextParallelTesterMixin,
LoraHotSwappingForModelTesterMixin,
LoraTesterMixin,
@@ -253,6 +254,16 @@ class TestQwenImageTransformerContextParallel(QwenImageTransformerTesterConfig,
"""Context Parallel inference tests for QwenImage Transformer."""


class TestQwenImageTransformerContextParallelAttnBackends(
QwenImageTransformerTesterConfig, ContextParallelAttentionBackendsTesterMixin
):
"""Context Parallel inference x attention backends tests for QwenImage Transformer"""

# QwenImage always passes a joint attention mask (text + image), which flash_hub and
# _flash_3_hub do not support.
unsupported_attn_backends = ["flash_hub", "_flash_3_hub"]
Member: Wouldn't any non-varlen attention backend fail here? If so, I would rather do something like

if "varlen" not in attention_backend:
    pytest.skip(...)

Contributor Author: Like FluxPipeline, it can also support varlen kernels after this change. I'm not sure what the most suitable place is to put this.

Member: Is this because Qwen uses masks? Also, do the underlying tests use non-contiguous masks?

Contributor Author: Yes, it is because Qwen uses masks; I've made the comment clearer. As for the underlying tests: no, it is simply an all-true mask. Do I need to update it?
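As an aside, if the underlying test were ever extended beyond an all-true mask, a non-contiguous mask that exercises the varlen packing could look like this (hypothetical sketch, not part of this PR):

# Hypothetical: batch of 2, seq_len_kv of 6, with holes so the valid KV tokens are non-contiguous.
attn_mask = torch.tensor(
    [
        [True, True, False, True, False, False],
        [True, False, True, True, True, False],
    ]
)
# _normalize_attn_mask reduces richer mask shapes to this [batch_size, seq_len_k] bool layout
# before the valid positions are gathered with nonzero().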



class TestQwenImageTransformerLoRA(QwenImageTransformerTesterConfig, LoraTesterMixin):
"""LoRA adapter tests for QwenImage Transformer."""
