Skip to content

Commit 9642e44

Browse files
Add pre attention and post input patches to qwen image model. (Comfy-Org#12879)
1 parent 3ad36d6 commit 9642e44

File tree

1 file changed

+29
-6
lines changed

1 file changed

+29
-6
lines changed

comfy/ldm/qwen_image/model.py

Lines changed: 29 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -149,6 +149,9 @@ def forward(
149149
seq_img = hidden_states.shape[1]
150150
seq_txt = encoder_hidden_states.shape[1]
151151

152+
transformer_patches = transformer_options.get("patches", {})
153+
extra_options = transformer_options.copy()
154+
152155
# Project and reshape to BHND format (batch, heads, seq, dim)
153156
img_query = self.to_q(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous()
154157
img_key = self.to_k(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous()
@@ -167,15 +170,22 @@ def forward(
167170
joint_key = torch.cat([txt_key, img_key], dim=2)
168171
joint_value = torch.cat([txt_value, img_value], dim=2)
169172

170-
joint_query = apply_rope1(joint_query, image_rotary_emb)
171-
joint_key = apply_rope1(joint_key, image_rotary_emb)
172-
173173
if encoder_hidden_states_mask is not None:
174174
attn_mask = torch.zeros((batch_size, 1, seq_txt + seq_img), dtype=hidden_states.dtype, device=hidden_states.device)
175175
attn_mask[:, 0, :seq_txt] = encoder_hidden_states_mask
176176
else:
177177
attn_mask = None
178178

179+
extra_options["img_slice"] = [txt_query.shape[2], joint_query.shape[2]]
180+
if "attn1_patch" in transformer_patches:
181+
patch = transformer_patches["attn1_patch"]
182+
for p in patch:
183+
out = p(joint_query, joint_key, joint_value, pe=image_rotary_emb, attn_mask=encoder_hidden_states_mask, extra_options=extra_options)
184+
joint_query, joint_key, joint_value, image_rotary_emb, encoder_hidden_states_mask = out.get("q", joint_query), out.get("k", joint_key), out.get("v", joint_value), out.get("pe", image_rotary_emb), out.get("attn_mask", encoder_hidden_states_mask)
185+
186+
joint_query = apply_rope1(joint_query, image_rotary_emb)
187+
joint_key = apply_rope1(joint_key, image_rotary_emb)
188+
179189
joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads,
180190
attn_mask, transformer_options=transformer_options,
181191
skip_reshape=True)
@@ -444,6 +454,7 @@ def _forward(
444454

445455
timestep_zero_index = None
446456
if ref_latents is not None:
457+
ref_num_tokens = []
447458
h = 0
448459
w = 0
449460
index = 0
@@ -474,16 +485,16 @@ def _forward(
474485
kontext, kontext_ids, _ = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset)
475486
hidden_states = torch.cat([hidden_states, kontext], dim=1)
476487
img_ids = torch.cat([img_ids, kontext_ids], dim=1)
488+
ref_num_tokens.append(kontext.shape[1])
477489
if timestep_zero:
478490
if index > 0:
479491
timestep = torch.cat([timestep, timestep * 0], dim=0)
480492
timestep_zero_index = num_embeds
493+
transformer_options = transformer_options.copy()
494+
transformer_options["reference_image_num_tokens"] = ref_num_tokens
481495

482496
txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2))
483497
txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
484-
ids = torch.cat((txt_ids, img_ids), dim=1)
485-
image_rotary_emb = self.pe_embedder(ids).to(x.dtype).contiguous()
486-
del ids, txt_ids, img_ids
487498

488499
hidden_states = self.img_in(hidden_states)
489500
encoder_hidden_states = self.txt_norm(encoder_hidden_states)
@@ -495,6 +506,18 @@ def _forward(
495506
patches = transformer_options.get("patches", {})
496507
blocks_replace = patches_replace.get("dit", {})
497508

509+
if "post_input" in patches:
510+
for p in patches["post_input"]:
511+
out = p({"img": hidden_states, "txt": encoder_hidden_states, "img_ids": img_ids, "txt_ids": txt_ids, "transformer_options": transformer_options})
512+
hidden_states = out["img"]
513+
encoder_hidden_states = out["txt"]
514+
img_ids = out["img_ids"]
515+
txt_ids = out["txt_ids"]
516+
517+
ids = torch.cat((txt_ids, img_ids), dim=1)
518+
image_rotary_emb = self.pe_embedder(ids).to(x.dtype).contiguous()
519+
del ids, txt_ids, img_ids
520+
498521
transformer_options["total_blocks"] = len(self.transformer_blocks)
499522
transformer_options["block_type"] = "double"
500523
for i, block in enumerate(self.transformer_blocks):

0 commit comments

Comments (0)