55 changes: 55 additions & 0 deletions src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py
@@ -317,6 +317,20 @@ def encode_prompt(
        if prompt_embeds is None:
            prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, image, device)

        # NOTE:
        # - The forward pass of QwenImageTransformer2DModel (and its RoPE embedder) requires an explicit
        #   text sequence length (`max_txt_seq_len`, or the deprecated `txt_seq_lens`) to build the
        #   rotary embeddings.
        # - The pipeline's public API exposes `max_sequence_length`, but the previous implementation
        #   never applied it. We enforce it here by truncating the prompt embeddings and their mask.
        if max_sequence_length is not None:
            try:
                max_seq_len = int(max_sequence_length)
            except (TypeError, ValueError):
                max_seq_len = None
            if max_seq_len is not None and max_seq_len > 0:
                prompt_embeds = prompt_embeds[:, :max_seq_len, :]
                prompt_embeds_mask = prompt_embeds_mask[:, :max_seq_len]

        _, seq_len, _ = prompt_embeds.shape
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
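The slice keeps `prompt_embeds` and `prompt_embeds_mask` aligned on the sequence axis, so the RoPE length later derived from `prompt_embeds.shape[1]` agrees with the mask. A minimal standalone sketch of the same slicing (all shapes and the cap value are illustrative assumptions, not values taken from the pipeline):

import torch

# Illustrative shapes: batch 1, 1024 text tokens, hidden size 3584 (assumed, not asserted by the PR).
prompt_embeds = torch.randn(1, 1024, 3584)
prompt_embeds_mask = torch.ones(1, 1024, dtype=torch.long)

max_sequence_length = 512  # illustrative cap
prompt_embeds = prompt_embeds[:, :max_sequence_length, :]
prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length]

# Embeddings and mask stay the same length after truncation.
assert prompt_embeds.shape[1] == prompt_embeds_mask.shape[1] == 512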
@@ -800,6 +814,29 @@ def __call__(
                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                timestep = t.expand(latents.shape[0]).to(latents.dtype)
                with self.transformer.cache_context("cond"):
                    # The Qwen transformer requires an explicit text sequence length for its RoPE
                    # computation; use the actual encoder_hidden_states length (after truncation).
                    cond_txt_len = prompt_embeds.shape[1]
                    # NOTE: some accelerated/quantized transformer implementations (e.g. nunchaku) may
                    # not accept the `max_txt_seq_len` kwarg, so we adapt to the actual forward signature:
                    # - prefer `max_txt_seq_len` when supported
                    # - fall back to the deprecated `txt_seq_lens`
                    # - otherwise pass nothing and let the implementation handle it internally
                    try:
                        _fwd_params = inspect.signature(self.transformer.forward).parameters
                    except (TypeError, ValueError):
                        _fwd_params = {}
                    _accepts_kwargs = any(
                        p.kind == inspect.Parameter.VAR_KEYWORD for p in _fwd_params.values()
                    )

                    _rope_kwargs = {}
                    if _accepts_kwargs or "max_txt_seq_len" in _fwd_params:
                        _rope_kwargs["max_txt_seq_len"] = cond_txt_len
                    elif "txt_seq_lens" in _fwd_params:
                        _rope_kwargs["txt_seq_lens"] = [int(cond_txt_len)]

                    noise_pred = self.transformer(
                        hidden_states=latent_model_input,
                        timestep=timestep / 1000,
@@ -809,11 +846,28 @@ def __call__(
                        img_shapes=img_shapes,
                        attention_kwargs=self.attention_kwargs,
                        return_dict=False,
                        **_rope_kwargs,
                    )[0]
                    noise_pred = noise_pred[:, : latents.size(1)]
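Aside: the signature probe above is repeated verbatim in the unconditional branch below. It could be factored into a small module-level helper; a sketch under the assumption that only these two kwarg spellings exist (the helper name `_select_rope_kwargs` is hypothetical, not part of this PR):

import inspect

def _select_rope_kwargs(transformer, txt_len):
    # Probe the forward signature once and return the RoPE length kwarg it accepts.
    try:
        params = inspect.signature(transformer.forward).parameters
    except (TypeError, ValueError):
        params = {}
    accepts_var_kwargs = any(p.kind == inspect.Parameter.VAR_KEYWORD for p in params.values())
    if accepts_var_kwargs or "max_txt_seq_len" in params:
        return {"max_txt_seq_len": int(txt_len)}
    if "txt_seq_lens" in params:
        return {"txt_seq_lens": [int(txt_len)]}
    # Unknown signature: pass nothing and let the implementation use its defaults.
    return {}

Both branches would then reduce to `self.transformer(..., **_select_rope_kwargs(self.transformer, embeds.shape[1]))`.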

                if do_true_cfg:
                    with self.transformer.cache_context("uncond"):
                        uncond_txt_len = negative_prompt_embeds.shape[1]
                        # Same forward-signature adaptation as in the conditional branch above.
                        try:
                            _fwd_params = inspect.signature(self.transformer.forward).parameters
                        except (TypeError, ValueError):
                            _fwd_params = {}
                        _accepts_kwargs = any(
                            p.kind == inspect.Parameter.VAR_KEYWORD for p in _fwd_params.values()
                        )

                        _rope_kwargs = {}
                        if _accepts_kwargs or "max_txt_seq_len" in _fwd_params:
                            _rope_kwargs["max_txt_seq_len"] = uncond_txt_len
                        elif "txt_seq_lens" in _fwd_params:
                            _rope_kwargs["txt_seq_lens"] = [int(uncond_txt_len)]

                        neg_noise_pred = self.transformer(
                            hidden_states=latent_model_input,
                            timestep=timestep / 1000,
@@ -823,6 +877,7 @@ def __call__(
                            img_shapes=img_shapes,
                            attention_kwargs=self.attention_kwargs,
                            return_dict=False,
                            **_rope_kwargs,
                        )[0]
                        neg_noise_pred = neg_noise_pred[:, : latents.size(1)]
                    comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
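The last line is the standard true classifier-free-guidance combination, comb = uncond + s * (cond - uncond): at `true_cfg_scale == 1.0` it reduces to the conditional prediction, and larger scales extrapolate further along the conditional direction. A tiny numeric sketch with made-up values:

import torch

noise_pred = torch.tensor([1.0, 2.0])      # stand-in for the conditional prediction
neg_noise_pred = torch.tensor([0.5, 1.0])  # stand-in for the unconditional prediction
true_cfg_scale = 4.0

comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
print(comb_pred)  # tensor([2.5000, 5.0000])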