Skip to content

Commit 1702e6d

Browse files
Implement wan2.2 camera model. (Comfy-Org#9357)
Use the old WanCameraImageToVideo node.
1 parent c308a88 commit 1702e6d

4 files changed

Lines changed: 28 additions & 5 deletions

File tree

comfy/ldm/wan/model.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -768,7 +768,12 @@ def __init__(self,
                  operations=None,
                  ):
 
-        super().__init__(model_type='i2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations)
+        if model_type == 'camera':
+            model_type = 'i2v'
+        else:
+            model_type = 't2v'
+
+        super().__init__(model_type=model_type, patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations)
         operation_settings = {"operations": operations, "device": device, "dtype": dtype}
 
         self.control_adapter = WanCamAdapter(in_dim_control_adapter, dim, kernel_size=patch_size[1:], stride=patch_size[1:], operation_settings=operation_settings)

comfy/model_detection.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -364,7 +364,10 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
             dit_config["vace_in_dim"] = state_dict['{}vace_patch_embedding.weight'.format(key_prefix)].shape[1]
             dit_config["vace_layers"] = count_blocks(state_dict_keys, '{}vace_blocks.'.format(key_prefix) + '{}.')
         elif '{}control_adapter.conv.weight'.format(key_prefix) in state_dict_keys:
-            dit_config["model_type"] = "camera"
+            if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys:
+                dit_config["model_type"] = "camera"
+            else:
+                dit_config["model_type"] = "camera_2.2"
         else:
             if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys:
                 dit_config["model_type"] = "i2v"

comfy/supported_models.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1046,6 +1046,18 @@ class WAN21_Camera(WAN21_T2V):
     def get_model(self, state_dict, prefix="", device=None):
         out = model_base.WAN21_Camera(self, image_to_video=False, device=device)
         return out
+
+class WAN22_Camera(WAN21_T2V):
+    unet_config = {
+        "image_model": "wan2.1",
+        "model_type": "camera_2.2",
+        "in_dim": 36,
+    }
+
+    def get_model(self, state_dict, prefix="", device=None):
+        out = model_base.WAN21_Camera(self, image_to_video=False, device=device)
+        return out
+
 class WAN21_Vace(WAN21_T2V):
     unet_config = {
         "image_model": "wan2.1",
@@ -1260,6 +1272,6 @@ def clip_target(self, state_dict={}):
         return supported_models_base.ClipTarget(comfy.text_encoders.qwen_image.QwenImageTokenizer, comfy.text_encoders.qwen_image.te(**hunyuan_detect))
 
 
-models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep, Omnigen2, QwenImage]
+models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep, Omnigen2, QwenImage]
 
 models += [SVD_img2vid]

comfy_extras/nodes_wan.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -422,9 +422,12 @@ def execute(cls, positive, negative, vae, width, height, length, batch_size, sta
             start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
             concat_latent_image = vae.encode(start_image[:, :, :, :3])
             concat_latent[:,:,:concat_latent_image.shape[2]] = concat_latent_image[:,:,:concat_latent.shape[2]]
+            mask = torch.ones((1, 1, latent.shape[2] * 4, latent.shape[-2], latent.shape[-1]))
+            mask[:, :, :start_image.shape[0] + 3] = 0.0
+            mask = mask.view(1, mask.shape[2] // 4, 4, mask.shape[3], mask.shape[4]).transpose(1, 2)
 
-            positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent})
-            negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent})
+            positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent, "concat_mask": mask})
+            negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent, "concat_mask": mask})
 
         if camera_conditions is not None:
             positive = node_helpers.conditioning_set_values(positive, {'camera_conditions': camera_conditions})

0 commit comments

Comments (0)