Skip to content

Commit 0963493

Browse files
Support for Qwen Diffsynth Controlnets canny and depth. (Comfy-Org#9465)
These are not real controlnets but patches applied to the model itself, so they are treated as model patches. Put them in the models/model_patches/ folder, then use the new ModelPatchLoader and QwenImageDiffsynthControlnet nodes to load and apply them.
1 parent e73a9db commit 0963493

7 files changed

Lines changed: 184 additions & 1 deletion

File tree

comfy/ldm/qwen_image/model.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -416,6 +416,7 @@ def forward(
416416
)
417417

418418
patches_replace = transformer_options.get("patches_replace", {})
419+
patches = transformer_options.get("patches", {})
419420
blocks_replace = patches_replace.get("dit", {})
420421

421422
for i, block in enumerate(self.transformer_blocks):
@@ -436,6 +437,12 @@ def block_wrap(args):
436437
image_rotary_emb=image_rotary_emb,
437438
)
438439

440+
if "double_block" in patches:
441+
for p in patches["double_block"]:
442+
out = p({"img": hidden_states, "txt": encoder_hidden_states, "x": x, "block_index": i})
443+
hidden_states = out["img"]
444+
encoder_hidden_states = out["txt"]
445+
439446
hidden_states = self.norm_out(hidden_states, temb)
440447
hidden_states = self.proj_out(hidden_states)
441448

comfy/model_management.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -593,7 +593,13 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
593593
else:
594594
minimum_memory_required = max(inference_memory, minimum_memory_required + extra_reserved_memory())
595595

596-
models = set(models)
596+
models_temp = set()
597+
for m in models:
598+
models_temp.add(m)
599+
for mm in m.model_patches_models():
600+
models_temp.add(mm)
601+
602+
models = models_temp
597603

598604
models_to_load = []
599605

comfy/model_patcher.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -430,6 +430,9 @@ def set_model_emb_patch(self, patch):
430430
def set_model_forward_timestep_embed_patch(self, patch):
431431
self.set_model_patch(patch, "forward_timestep_embed_patch")
432432

433+
def set_model_double_block_patch(self, patch):
    """Register a callable to run after every double (joint) transformer block.

    The patch is stored under the "double_block" key of
    transformer_options["patches"]; the model's forward calls it once per
    block with a dict of tensors (see the "double_block" handling in
    comfy/ldm/qwen_image/model.py).
    """
    self.set_model_patch(patch, "double_block")
435+
433436
def add_object_patch(self, name, obj):
434437
self.object_patches[name] = obj
435438

@@ -486,6 +489,30 @@ def model_patches_to(self, device):
486489
if hasattr(wrap_func, "to"):
487490
self.model_options["model_function_wrapper"] = wrap_func.to(device)
488491

492+
def model_patches_models(self):
    """Collect extra ModelPatcher objects owned by registered patches.

    Scans transformer_options["patches"], transformer_options["patches_replace"]
    and model_options["model_function_wrapper"] for objects exposing a
    .models() method and concatenates everything they return, so model
    management can load those models together with this one.
    """
    found = []
    t_opts = self.model_options["transformer_options"]
    # "patches": {name: [patch, ...]}
    for patch_seq in t_opts.get("patches", {}).values():
        for patch in patch_seq:
            if hasattr(patch, "models"):
                found += patch.models()
    # "patches_replace": {name: {key: patch, ...}}
    for replace_map in t_opts.get("patches_replace", {}).values():
        for key in replace_map:
            candidate = replace_map[key]
            if hasattr(candidate, "models"):
                found += candidate.models()
    wrapper = self.model_options.get("model_function_wrapper")
    if wrapper is not None and hasattr(wrapper, "models"):
        found += wrapper.models()
    return found
515+
489516
def model_dtype(self):
490517
if hasattr(self.model, "get_dtype"):
491518
return self.model.get_dtype()

comfy_api/latest/_io.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -726,6 +726,10 @@ class SEGS(ComfyTypeIO):
726726
class AnyType(ComfyTypeIO):
727727
Type = Any
728728

729+
@comfytype(io_type="MODEL_PATCH")
730+
class MODEL_PATCH(ComfyTypeIO):
731+
Type = Any
732+
729733
@comfytype(io_type="COMFY_MULTITYPED_V3")
730734
class MultiType:
731735
Type = Any

comfy_extras/nodes_model_patch.py

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
import torch
2+
import folder_paths
3+
import comfy.utils
4+
import comfy.ops
5+
import comfy.model_management
6+
import comfy.ldm.common_dit
7+
import comfy.latent_formats
8+
9+
10+
class BlockWiseControlBlock(torch.nn.Module):
11+
# [linear, gelu, linear]
12+
def __init__(self, dim: int = 3072, device=None, dtype=None, operations=None):
13+
super().__init__()
14+
self.x_rms = operations.RMSNorm(dim, eps=1e-6)
15+
self.y_rms = operations.RMSNorm(dim, eps=1e-6)
16+
self.input_proj = operations.Linear(dim, dim)
17+
self.act = torch.nn.GELU()
18+
self.output_proj = operations.Linear(dim, dim)
19+
20+
def forward(self, x, y):
21+
x, y = self.x_rms(x), self.y_rms(y)
22+
x = self.input_proj(x + y)
23+
x = self.act(x)
24+
x = self.output_proj(x)
25+
return x
26+
27+
28+
class QwenImageBlockWiseControlNet(torch.nn.Module):
    """Stack of per-transformer-block control adapters for Qwen-Image.

    Holds one BlockWiseControlBlock per transformer block plus `img_in`,
    which projects the patchified control latent into the transformer's
    hidden dimension.
    """

    def __init__(
        self,
        num_layers: int = 60,
        in_dim: int = 64,
        additional_in_dim: int = 0,
        dim: int = 3072,
        device=None, dtype=None, operations=None
    ):
        super().__init__()
        self.img_in = operations.Linear(in_dim + additional_in_dim, dim, device=device, dtype=dtype)
        self.controlnet_blocks = torch.nn.ModuleList(
            BlockWiseControlBlock(dim, device=device, dtype=dtype, operations=operations)
            for _ in range(num_layers)
        )

    def process_input_latent_image(self, latent_image):
        """Scale, pad and patchify a control latent, then project it to `dim`.

        Uses the Wan21 latent format scaling and 2x2 patchification, matching
        how the Qwen-Image model tokenizes its own latent input.
        """
        scaled = comfy.latent_formats.Wan21().process_in(latent_image)
        patch_size = 2
        padded = comfy.ldm.common_dit.pad_to_patch_size(scaled, (1, patch_size, patch_size))
        batch, channels = padded.shape[0], padded.shape[1]
        height, width = padded.shape[-2], padded.shape[-1]
        # (b, c, h, w) -> (b, (h/2)*(w/2), c*4): each 2x2 spatial patch is
        # flattened into the channel dimension to form one token.
        tokens = padded.view(batch, channels, height // 2, 2, width // 2, 2)
        tokens = tokens.permute(0, 2, 4, 1, 3, 5)
        tokens = tokens.reshape(batch, (height // 2) * (width // 2), channels * 4)
        return self.img_in(tokens)

    def control_block(self, img, controlnet_conditioning, block_id):
        """Apply the adapter for transformer block `block_id` and return its residual."""
        return self.controlnet_blocks[block_id](img, controlnet_conditioning)
58+
59+
60+
class ModelPatchLoader:
    """Node that loads a model patch checkpoint from models/model_patches/
    and wraps it in a ModelPatcher so model management can handle it."""

    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "name": (folder_paths.get_filename_list("model_patches"), ),
                              }}
    RETURN_TYPES = ("MODEL_PATCH",)
    FUNCTION = "load_model_patch"
    EXPERIMENTAL = True

    CATEGORY = "advanced/loaders"

    def load_model_patch(self, name):
        """Load `name` from the model_patches folder and return (ModelPatcher,).

        Raises whatever folder_paths.get_full_path_or_raise raises when the
        file is missing.
        """
        # Explicit import: this module never imports comfy.model_patcher at the
        # top, so `comfy.model_patcher` below would only resolve if some other
        # module happened to import it first (AttributeError otherwise).
        import comfy.model_patcher
        model_patch_path = folder_paths.get_full_path_or_raise("model_patches", name)
        sd = comfy.utils.load_torch_file(model_patch_path, safe_load=True)
        dtype = comfy.utils.weight_dtype(sd)
        # TODO: this node will work with more types of model patches
        model = QwenImageBlockWiseControlNet(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast)
        model.load_state_dict(sd)
        model = comfy.model_patcher.ModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device())
        return (model,)
80+
81+
82+
class DiffSynthCnetPatch:
    """Double-block patch adding DiffSynth controlnet guidance to a model.

    Registered via ModelPatcher.set_model_double_block_patch and called once
    per transformer block with a dict of tensors; it adds a scaled control
    residual to the block's hidden states.
    """
    def __init__(self, model_patch, vae, image, strength):
        # model_patch: ModelPatcher wrapping a QwenImageBlockWiseControlNet.
        # image: control image; presumably channels-last (the movedim(-1, 1)
        # in __call__ suggests NHWC) — TODO confirm against callers.
        # Encode eagerly so the common fixed-resolution case never re-encodes.
        self.encoded_image = model_patch.model.process_input_latent_image(vae.encode(image))
        self.model_patch = model_patch
        self.vae = vae
        self.image = image
        self.strength = strength

    def __call__(self, kwargs):
        # kwargs carries "img" (block hidden states), "x" (model latent input)
        # and "block_index" (index of the current transformer block).
        x = kwargs.get("x")
        img = kwargs.get("img")
        block_index = kwargs.get("block_index")
        # Re-encode the control image when the cached encoding's token/feature
        # shape no longer matches the hidden states (e.g. resolution changed).
        if self.encoded_image is None or self.encoded_image.shape[1:] != img.shape[1:]:
            spacial_compression = self.vae.spacial_compression_encode()
            image_scaled = comfy.utils.common_upscale(self.image.movedim(-1, 1), x.shape[-1] * spacial_compression, x.shape[-2] * spacial_compression, "area", "center")
            # Snapshot currently loaded models and reload them afterwards, so
            # any models the VAE encode evicts are restored before sampling
            # continues.
            loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
            self.encoded_image = self.model_patch.model.process_input_latent_image(self.vae.encode(image_scaled.movedim(1, -1)))
            comfy.model_management.load_models_gpu(loaded_models)

        # Add this block's scaled controlnet residual to the hidden states.
        img = img + (self.model_patch.model.control_block(img, self.encoded_image.to(img.dtype), block_index) * self.strength)
        kwargs['img'] = img
        return kwargs

    def to(self, device_or_dtype):
        # Only device moves are honored here; dtype conversion happens lazily
        # per call via .to(img.dtype) in __call__.
        if isinstance(device_or_dtype, torch.device):
            self.encoded_image = self.encoded_image.to(device_or_dtype)
        return self

    def models(self):
        # Extra ModelPatcher objects that model management must load alongside
        # the patched model (see ModelPatcher.model_patches_models).
        return [self.model_patch]
112+
113+
class QwenImageDiffsynthControlnet:
    """Node that attaches a DiffSynth blockwise controlnet patch to a
    Qwen-Image model."""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("MODEL",),
                "model_patch": ("MODEL_PATCH",),
                "vae": ("VAE",),
                "image": ("IMAGE",),
                "strength": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
            }
        }
    RETURN_TYPES = ("MODEL",)
    FUNCTION = "diffsynth_controlnet"
    EXPERIMENTAL = True

    CATEGORY = "advanced/loaders/qwen"

    def diffsynth_controlnet(self, model, model_patch, vae, image, strength):
        """Return a clone of `model` with the control patch registered on its
        double blocks; the original model is left untouched."""
        rgb = image[:, :, :, :3]  # keep only RGB, dropping any alpha channel
        patched = model.clone()
        patched.set_model_double_block_patch(DiffSynthCnetPatch(model_patch, vae, rgb, strength))
        return (patched,)
133+
134+
135+
# Registry picked up by nodes.py when this extra module is imported:
# maps node type names (as stored in workflows) to their classes.
NODE_CLASS_MAPPINGS = {
    "ModelPatchLoader": ModelPatchLoader,
    "QwenImageDiffsynthControlnet": QwenImageDiffsynthControlnet,
}

models/model_patches/put_model_patches_here

Whitespace-only changes.

nodes.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2322,6 +2322,7 @@ async def init_builtin_extra_nodes():
23222322
"nodes_tcfg.py",
23232323
"nodes_context_windows.py",
23242324
"nodes_qwen.py",
2325+
"nodes_model_patch.py"
23252326
]
23262327

23272328
import_failed = []

0 commit comments

Comments
 (0)