Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 29 additions & 6 deletions comfy/ldm/qwen_image/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,9 @@ def forward(
seq_img = hidden_states.shape[1]
seq_txt = encoder_hidden_states.shape[1]

transformer_patches = transformer_options.get("patches", {})
extra_options = transformer_options.copy()

# Project and reshape to BHND format (batch, heads, seq, dim)
img_query = self.to_q(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous()
img_key = self.to_k(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous()
Expand All @@ -167,15 +170,22 @@ def forward(
joint_key = torch.cat([txt_key, img_key], dim=2)
joint_value = torch.cat([txt_value, img_value], dim=2)

joint_query = apply_rope1(joint_query, image_rotary_emb)
joint_key = apply_rope1(joint_key, image_rotary_emb)

if encoder_hidden_states_mask is not None:
attn_mask = torch.zeros((batch_size, 1, seq_txt + seq_img), dtype=hidden_states.dtype, device=hidden_states.device)
attn_mask[:, 0, :seq_txt] = encoder_hidden_states_mask
else:
attn_mask = None

extra_options["img_slice"] = [txt_query.shape[2], joint_query.shape[2]]
if "attn1_patch" in transformer_patches:
patch = transformer_patches["attn1_patch"]
for p in patch:
out = p(joint_query, joint_key, joint_value, pe=image_rotary_emb, attn_mask=encoder_hidden_states_mask, extra_options=extra_options)
joint_query, joint_key, joint_value, image_rotary_emb, encoder_hidden_states_mask = out.get("q", joint_query), out.get("k", joint_key), out.get("v", joint_value), out.get("pe", image_rotary_emb), out.get("attn_mask", encoder_hidden_states_mask)

joint_query = apply_rope1(joint_query, image_rotary_emb)
joint_key = apply_rope1(joint_key, image_rotary_emb)

joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads,
attn_mask, transformer_options=transformer_options,
skip_reshape=True)
Expand Down Expand Up @@ -444,6 +454,7 @@ def _forward(

timestep_zero_index = None
if ref_latents is not None:
ref_num_tokens = []
h = 0
w = 0
index = 0
Expand Down Expand Up @@ -474,16 +485,16 @@ def _forward(
kontext, kontext_ids, _ = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset)
hidden_states = torch.cat([hidden_states, kontext], dim=1)
img_ids = torch.cat([img_ids, kontext_ids], dim=1)
ref_num_tokens.append(kontext.shape[1])
if timestep_zero:
if index > 0:
timestep = torch.cat([timestep, timestep * 0], dim=0)
timestep_zero_index = num_embeds
transformer_options = transformer_options.copy()
transformer_options["reference_image_num_tokens"] = ref_num_tokens

txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2))
txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
ids = torch.cat((txt_ids, img_ids), dim=1)
image_rotary_emb = self.pe_embedder(ids).to(x.dtype).contiguous()
del ids, txt_ids, img_ids

hidden_states = self.img_in(hidden_states)
encoder_hidden_states = self.txt_norm(encoder_hidden_states)
Expand All @@ -495,6 +506,18 @@ def _forward(
patches = transformer_options.get("patches", {})
blocks_replace = patches_replace.get("dit", {})

if "post_input" in patches:
for p in patches["post_input"]:
out = p({"img": hidden_states, "txt": encoder_hidden_states, "img_ids": img_ids, "txt_ids": txt_ids, "transformer_options": transformer_options})
hidden_states = out["img"]
encoder_hidden_states = out["txt"]
img_ids = out["img_ids"]
txt_ids = out["txt_ids"]

ids = torch.cat((txt_ids, img_ids), dim=1)
image_rotary_emb = self.pe_embedder(ids).to(x.dtype).contiguous()
del ids, txt_ids, img_ids

transformer_options["total_blocks"] = len(self.transformer_blocks)
transformer_options["block_type"] = "double"
for i, block in enumerate(self.transformer_blocks):
Expand Down
22 changes: 22 additions & 0 deletions comfy/model_patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -599,6 +599,27 @@ def model_patches_models(self):

return models

def model_patches_call_function(self, function_name="cleanup", arguments={}):
to = self.model_options["transformer_options"]
if "patches" in to:
patches = to["patches"]
for name in patches:
patch_list = patches[name]
for i in range(len(patch_list)):
if hasattr(patch_list[i], function_name):
getattr(patch_list[i], function_name)(**arguments)
if "patches_replace" in to:
patches = to["patches_replace"]
for name in patches:
patch_list = patches[name]
for k in patch_list:
if hasattr(patch_list[k], function_name):
getattr(patch_list[k], function_name)(**arguments)
if "model_function_wrapper" in self.model_options:
wrap_func = self.model_options["model_function_wrapper"]
if hasattr(wrap_func, function_name):
getattr(wrap_func, function_name)(**arguments)

def model_dtype(self):
if hasattr(self.model, "get_dtype"):
return self.model.get_dtype()
Expand Down Expand Up @@ -1062,6 +1083,7 @@ def calculate_weight(self, patches, weight, key, intermediate_dtype=torch.float3
return comfy.lora.calculate_weight(patches, weight, key, intermediate_dtype=intermediate_dtype)

def cleanup(self):
self.model_patches_call_function(function_name="cleanup")
self.clean_hooks()
if hasattr(self.model, "current_patcher"):
self.model.current_patcher = None
Expand Down
Loading