48 changes: 44 additions & 4 deletions src/diffusers/models/attention_dispatch.py
@@ -1111,6 +1111,42 @@ def _sage_attention_backward_op(
raise NotImplementedError("Backward pass is not implemented for Sage attention.")


def _maybe_modify_attn_mask_npu(
query: torch.Tensor,
key: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None
):
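    """
    Adapt a keep-style attention mask (nonzero = attend) to the layout expected by
    ``npu_fusion_attention`` (see the Ascend doc link below). A minimal summary of the shape handling:

      - mask with no masked positions (all values non-zero)  -> returned as ``None``
      - ``[batch_size, seq_len_k]`` key-padding mask          -> ``[batch_size, 1, seq_len_q, seq_len_k]``
      - ``[batch_size, 1, 1, seq_len_k]`` mask                -> ``[batch_size, 1, seq_len_q, seq_len_k]``

    In the reshape branches the mask is also inverted, so ``True`` marks positions to be masked out.
    """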
    # Drop the attention mask when every position is attended (all values non-zero); passing `None` speeds up the computation
if (attn_mask is not None and torch.all(attn_mask != 0)):
attn_mask = None

    # Reshape attention mask: [batch_size, seq_len_k] -> [batch_size, 1, seq_len_q, seq_len_k]
# https://www.hiascend.com/document/detail/zh/Pytorch/730/apiref/torchnpuCustomsapi/docs/context/torch_npu-npu_fusion_attention.md
if (
attn_mask is not None
and attn_mask.ndim == 2
and attn_mask.shape[0] == query.shape[0]
and attn_mask.shape[1] == key.shape[1]
):
B, Sq, Skv = attn_mask.shape[0], query.shape[1], key.shape[1]
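        # Invert: the incoming mask uses nonzero = attend, while `atten_mask` marks positions to be masked out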
attn_mask = ~attn_mask.to(torch.bool)
attn_mask = attn_mask.unsqueeze(1).expand(B, Sq, Skv).unsqueeze(1).contiguous()

    # Reshape attention mask: [batch_size, 1, 1, seq_len_k] -> [batch_size, 1, seq_len_q, seq_len_k]
if (
attn_mask is not None
and attn_mask.ndim == 4
and attn_mask.shape[0] == query.shape[0]
and attn_mask.shape[-1] == key.shape[1]
and attn_mask.shape[-2] == 1
):
B, Sq, Skv = attn_mask.shape[0], query.shape[1], key.shape[1]
attn_mask = ~attn_mask.to(torch.bool)
attn_mask = attn_mask.expand(B, 1, Sq, Skv)

return attn_mask


def _npu_attention_forward_op(
ctx: torch.autograd.function.FunctionCtx,
query: torch.Tensor,
@@ -1126,13 +1162,16 @@ def _npu_attention_forward_op(
_parallel_config: Optional["ParallelConfig"] = None,
):
if return_lse:
        raise ValueError("NPU attention backend does not support setting `return_lse=True`.")

attn_mask = _maybe_modify_attn_mask_npu(query, key, attn_mask)

out = npu_fusion_attention(
query,
key,
value,
query.size(2), # num_heads
atten_mask=attn_mask,
input_layout="BSND",
pse=None,
scale=1.0 / math.sqrt(query.shape[-1]) if scale is None else scale,
@@ -2421,16 +2460,17 @@ def _native_npu_attention(
return_lse: bool = False,
_parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
if return_lse:
raise ValueError("NPU attention backend does not support setting `return_lse=True`.")
if _parallel_config is None:
attn_mask = _maybe_modify_attn_mask_npu(query, key, attn_mask)

out = npu_fusion_attention(
query,
key,
value,
query.size(2), # num_heads
atten_mask=attn_mask,
input_layout="BSND",
pse=None,
scale=1.0 / math.sqrt(query.shape[-1]) if scale is None else scale,
@@ -2445,7 +2485,7 @@
query,
key,
value,
            attn_mask,
dropout_p,
None,
scale,
6 changes: 5 additions & 1 deletion src/diffusers/models/transformers/transformer_qwenimage.py
@@ -164,7 +164,11 @@ def compute_text_seq_len_from_mask(
position_ids = torch.arange(text_seq_len, device=encoder_hidden_states.device, dtype=torch.long)
active_positions = torch.where(encoder_hidden_states_mask, position_ids, position_ids.new_zeros(()))
has_active = encoder_hidden_states_mask.any(dim=1)
per_sample_len = torch.where(
has_active,
active_positions.max(dim=1).values + 1,
torch.as_tensor(text_seq_len, device=encoder_hidden_states.device)
)
return text_seq_len, per_sample_len, encoder_hidden_states_mask
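    # Hedged sanity check (hypothetical input, not part of the module): for
    #   encoder_hidden_states_mask = torch.tensor([[1, 1, 0, 0], [1, 1, 1, 0]], dtype=torch.bool)
    # active_positions becomes [[0, 1, 0, 0], [0, 1, 2, 0]], so per_sample_len is tensor([2, 3]),
    # while text_seq_len remains the padded length, 4.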


72 changes: 54 additions & 18 deletions src/diffusers/models/transformers/transformer_z_image.py
@@ -25,6 +25,7 @@
from ...models.attention_processor import Attention
from ...models.modeling_utils import ModelMixin
from ...models.normalization import RMSNorm
from ...utils import is_torch_npu_available
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention_dispatch import dispatch_attention_fn
from ..modeling_outputs import Transformer2DModelOutput
@@ -323,37 +324,72 @@ def __init__(
self.axes_lens = axes_lens
assert len(axes_dims) == len(axes_lens), "axes_dims and axes_lens must have the same length"
self.freqs_cis = None
self.freqs_real = None
self.freqs_imag = None

@staticmethod
def precompute_freqs_cis(dim: List[int], end: List[int], theta: float = 256.0):
        if is_torch_npu_available():
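            # NPU path: cache cos/sin separately; complex dtypes (torch.polar / complex64) may be
            # poorly supported on NPU, so complex values are only assembled at lookup time.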
freqs_real_list = []
freqs_imag_list = []
for i, (d, e) in enumerate(zip(dim, end)):
freqs = 1.0 / (theta ** (torch.arange(0, d, 2, dtype=torch.float64, device="cpu") / d))
timestep = torch.arange(e, device=freqs.device, dtype=torch.float64)
freqs = torch.outer(timestep, freqs).float()
freqs_real = torch.cos(freqs)
freqs_imag = torch.sin(freqs)
freqs_real_list.append(freqs_real.to(torch.float32))
freqs_imag_list.append(freqs_imag.to(torch.float32))

return freqs_real_list, freqs_imag_list
else:
freqs_cis = []
for i, (d, e) in enumerate(zip(dim, end)):
freqs = 1.0 / (theta ** (torch.arange(0, d, 2, dtype=torch.float64, device="cpu") / d))
timestep = torch.arange(e, device=freqs.device, dtype=torch.float64)
freqs = torch.outer(timestep, freqs).float()
freqs_cis_i = torch.polar(torch.ones_like(freqs), freqs).to(torch.complex64) # complex64
freqs_cis.append(freqs_cis_i)
return freqs_cis

def __call__(self, ids: torch.Tensor):
assert ids.ndim == 2
assert ids.shape[-1] == len(self.axes_dims)
device = ids.device

        if is_torch_npu_available():
if self.freqs_real is None or self.freqs_imag is None:
freqs_real, freqs_imag = self.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta)
self.freqs_real = [fr.to(device) for fr in freqs_real]
self.freqs_imag = [fi.to(device) for fi in freqs_imag]
            else:
                # Ensure cached freqs_real/freqs_imag are on the same device as ids
                if self.freqs_real[0].device != device:
                    self.freqs_real = [fr.to(device) for fr in self.freqs_real]
                    self.freqs_imag = [fi.to(device) for fi in self.freqs_imag]

result = []
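            # Gather the per-axis cos/sin rows for each position id and recombine them into complex rotary factors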
for i in range(len(self.axes_dims)):
index = ids[:, i]
real_part = self.freqs_real[i][index]
imag_part = self.freqs_imag[i][index]
complex_part = torch.complex(real_part, imag_part)
result.append(complex_part)
        else:
if self.freqs_cis is None:
self.freqs_cis = self.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta)
self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis]
else:
# Ensure freqs_cis are on the same device as ids
if self.freqs_cis[0].device != device:
self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis]

result = []
for i in range(len(self.axes_dims)):
index = ids[:, i]
result.append(self.freqs_cis[i][index])

return torch.cat(result, dim=-1)

