Skip to content

Commit f7bd5e5

Browse files
Make it easier to implement future qwen controlnets. (Comfy-Org#9485)
1 parent 7ed73d1 commit f7bd5e5

File tree

3 files changed

+17
-5
lines changed

3 files changed

+17
-5
lines changed

comfy/controlnet.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -236,11 +236,11 @@ def get_control(self, x_noisy, t, cond, batched_number, transformer_options):
236236
self.cond_hint = None
237237
compression_ratio = self.compression_ratio
238238
if self.vae is not None:
239-
compression_ratio *= self.vae.downscale_ratio
239+
compression_ratio *= self.vae.spacial_compression_encode()
240240
else:
241241
if self.latent_format is not None:
242242
raise ValueError("This Controlnet needs a VAE but none was provided, please use a ControlNetApply node with a VAE input and connect it.")
243-
self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * compression_ratio, x_noisy.shape[2] * compression_ratio, self.upscale_algorithm, "center")
243+
self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[-1] * compression_ratio, x_noisy.shape[-2] * compression_ratio, self.upscale_algorithm, "center")
244244
self.cond_hint = self.preprocess_image(self.cond_hint)
245245
if self.vae is not None:
246246
loaded_models = comfy.model_management.loaded_models(only_currently_used=True)

comfy/ldm/qwen_image/model.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -293,13 +293,15 @@ def __init__(
293293
guidance_embeds: bool = False,
294294
axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
295295
image_model=None,
296+
final_layer=True,
296297
dtype=None,
297298
device=None,
298299
operations=None,
299300
):
300301
super().__init__()
301302
self.dtype = dtype
302303
self.patch_size = patch_size
304+
self.in_channels = in_channels
303305
self.out_channels = out_channels or in_channels
304306
self.inner_dim = num_attention_heads * attention_head_dim
305307

@@ -329,9 +331,9 @@ def __init__(
329331
for _ in range(num_layers)
330332
])
331333

332-
self.norm_out = LastLayer(self.inner_dim, self.inner_dim, dtype=dtype, device=device, operations=operations)
333-
self.proj_out = operations.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True, dtype=dtype, device=device)
334-
self.gradient_checkpointing = False
334+
if final_layer:
335+
self.norm_out = LastLayer(self.inner_dim, self.inner_dim, dtype=dtype, device=device, operations=operations)
336+
self.proj_out = operations.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True, dtype=dtype, device=device)
335337

336338
def process_img(self, x, index=0, h_offset=0, w_offset=0):
337339
bs, c, t, h, w = x.shape
@@ -362,6 +364,7 @@ def forward(
362364
guidance: torch.Tensor = None,
363365
ref_latents=None,
364366
transformer_options={},
367+
control=None,
365368
**kwargs
366369
):
367370
timestep = timesteps
@@ -443,6 +446,13 @@ def block_wrap(args):
443446
hidden_states = out["img"]
444447
encoder_hidden_states = out["txt"]
445448

449+
if control is not None: # Controlnet
450+
control_i = control.get("input")
451+
if i < len(control_i):
452+
add = control_i[i]
453+
if add is not None:
454+
hidden_states += add
455+
446456
hidden_states = self.norm_out(hidden_states, temb)
447457
hidden_states = self.proj_out(hidden_states)
448458

comfy/model_detection.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -492,6 +492,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
492492
if '{}txt_norm.weight'.format(key_prefix) in state_dict_keys: # Qwen Image
493493
dit_config = {}
494494
dit_config["image_model"] = "qwen_image"
495+
dit_config["in_channels"] = state_dict['{}img_in.weight'.format(key_prefix)].shape[1]
496+
dit_config["num_layers"] = count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.')
495497
return dit_config
496498

497499
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:

0 commit comments

Comments (0)