Skip to content

Commit e80a14a

Browse files
Support the Wan2.2 5B Fun control model. (Comfy-Org#9611)
Use the Wan22FunControlToVideo node.
1 parent d28b39d commit e80a14a

File tree

2 files changed

+18
-16
lines changed

2 files changed

+18
-16
lines changed

comfy/model_base.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1110,9 +1110,10 @@ def concat_cond(self, **kwargs):
11101110
shape_image[1] = extra_channels
11111111
image = torch.zeros(shape_image, dtype=noise.dtype, layout=noise.layout, device=noise.device)
11121112
else:
1113+
latent_dim = self.latent_format.latent_channels
11131114
image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
1114-
for i in range(0, image.shape[1], 16):
1115-
image[:, i: i + 16] = self.process_latent_in(image[:, i: i + 16])
1115+
for i in range(0, image.shape[1], latent_dim):
1116+
image[:, i: i + latent_dim] = self.process_latent_in(image[:, i: i + latent_dim])
11161117
image = utils.resize_to_batch_size(image, noise.shape[0])
11171118

11181119
if extra_channels != image.shape[1] + 4:
@@ -1245,18 +1246,14 @@ def extra_conds_shapes(self, **kwargs):
12451246
out['reference_motion'] = reference_motion.shape
12461247
return out
12471248

1248-
class WAN22(BaseModel):
1249+
class WAN22(WAN21):
12491250
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
1250-
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel)
1251+
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel)
12511252
self.image_to_video = image_to_video
12521253

12531254
def extra_conds(self, **kwargs):
12541255
out = super().extra_conds(**kwargs)
1255-
cross_attn = kwargs.get("cross_attn", None)
1256-
if cross_attn is not None:
1257-
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
1258-
1259-
denoise_mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
1256+
denoise_mask = kwargs.get("denoise_mask", None)
12601257
if denoise_mask is not None:
12611258
out["denoise_mask"] = comfy.conds.CONDRegular(denoise_mask)
12621259
return out

comfy_extras/nodes_wan.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -139,16 +139,21 @@ def define_schema(cls):
139139

140140
@classmethod
141141
def execute(cls, positive, negative, vae, width, height, length, batch_size, ref_image=None, start_image=None, control_video=None) -> io.NodeOutput:
142-
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
143-
concat_latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
144-
concat_latent = comfy.latent_formats.Wan21().process_out(concat_latent)
142+
spacial_scale = vae.spacial_compression_encode()
143+
latent_channels = vae.latent_channels
144+
latent = torch.zeros([batch_size, latent_channels, ((length - 1) // 4) + 1, height // spacial_scale, width // spacial_scale], device=comfy.model_management.intermediate_device())
145+
concat_latent = torch.zeros([batch_size, latent_channels, ((length - 1) // 4) + 1, height // spacial_scale, width // spacial_scale], device=comfy.model_management.intermediate_device())
146+
if latent_channels == 48:
147+
concat_latent = comfy.latent_formats.Wan22().process_out(concat_latent)
148+
else:
149+
concat_latent = comfy.latent_formats.Wan21().process_out(concat_latent)
145150
concat_latent = concat_latent.repeat(1, 2, 1, 1, 1)
146151
mask = torch.ones((1, 1, latent.shape[2] * 4, latent.shape[-2], latent.shape[-1]))
147152

148153
if start_image is not None:
149154
start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
150155
concat_latent_image = vae.encode(start_image[:, :, :, :3])
151-
concat_latent[:,16:,:concat_latent_image.shape[2]] = concat_latent_image[:,:,:concat_latent.shape[2]]
156+
concat_latent[:,latent_channels:,:concat_latent_image.shape[2]] = concat_latent_image[:,:,:concat_latent.shape[2]]
152157
mask[:, :, :start_image.shape[0] + 3] = 0.0
153158

154159
ref_latent = None
@@ -159,11 +164,11 @@ def execute(cls, positive, negative, vae, width, height, length, batch_size, ref
159164
if control_video is not None:
160165
control_video = comfy.utils.common_upscale(control_video[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
161166
concat_latent_image = vae.encode(control_video[:, :, :, :3])
162-
concat_latent[:,:16,:concat_latent_image.shape[2]] = concat_latent_image[:,:,:concat_latent.shape[2]]
167+
concat_latent[:,:latent_channels,:concat_latent_image.shape[2]] = concat_latent_image[:,:,:concat_latent.shape[2]]
163168

164169
mask = mask.view(1, mask.shape[2] // 4, 4, mask.shape[3], mask.shape[4]).transpose(1, 2)
165-
positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent, "concat_mask": mask, "concat_mask_index": 16})
166-
negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent, "concat_mask": mask, "concat_mask_index": 16})
170+
positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent, "concat_mask": mask, "concat_mask_index": latent_channels})
171+
negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent, "concat_mask": mask, "concat_mask_index": latent_channels})
167172

168173
if ref_latent is not None:
169174
positive = node_helpers.conditioning_set_values(positive, {"reference_latents": [ref_latent]}, append=True)

0 commit comments

Comments (0)