@@ -139,16 +139,21 @@ def define_schema(cls):
139139
140140 @classmethod
141141 def execute (cls , positive , negative , vae , width , height , length , batch_size , ref_image = None , start_image = None , control_video = None ) -> io .NodeOutput :
142- latent = torch .zeros ([batch_size , 16 , ((length - 1 ) // 4 ) + 1 , height // 8 , width // 8 ], device = comfy .model_management .intermediate_device ())
143- concat_latent = torch .zeros ([batch_size , 16 , ((length - 1 ) // 4 ) + 1 , height // 8 , width // 8 ], device = comfy .model_management .intermediate_device ())
144- concat_latent = comfy .latent_formats .Wan21 ().process_out (concat_latent )
142+ spacial_scale = vae .spacial_compression_encode ()
143+ latent_channels = vae .latent_channels
144+ latent = torch .zeros ([batch_size , latent_channels , ((length - 1 ) // 4 ) + 1 , height // spacial_scale , width // spacial_scale ], device = comfy .model_management .intermediate_device ())
145+ concat_latent = torch .zeros ([batch_size , latent_channels , ((length - 1 ) // 4 ) + 1 , height // spacial_scale , width // spacial_scale ], device = comfy .model_management .intermediate_device ())
146+ if latent_channels == 48 :
147+ concat_latent = comfy .latent_formats .Wan22 ().process_out (concat_latent )
148+ else :
149+ concat_latent = comfy .latent_formats .Wan21 ().process_out (concat_latent )
145150 concat_latent = concat_latent .repeat (1 , 2 , 1 , 1 , 1 )
146151 mask = torch .ones ((1 , 1 , latent .shape [2 ] * 4 , latent .shape [- 2 ], latent .shape [- 1 ]))
147152
148153 if start_image is not None :
149154 start_image = comfy .utils .common_upscale (start_image [:length ].movedim (- 1 , 1 ), width , height , "bilinear" , "center" ).movedim (1 , - 1 )
150155 concat_latent_image = vae .encode (start_image [:, :, :, :3 ])
151- concat_latent [:,16 :,:concat_latent_image .shape [2 ]] = concat_latent_image [:,:,:concat_latent .shape [2 ]]
156+ concat_latent [:,latent_channels :,:concat_latent_image .shape [2 ]] = concat_latent_image [:,:,:concat_latent .shape [2 ]]
152157 mask [:, :, :start_image .shape [0 ] + 3 ] = 0.0
153158
154159 ref_latent = None
@@ -159,11 +164,11 @@ def execute(cls, positive, negative, vae, width, height, length, batch_size, ref
159164 if control_video is not None :
160165 control_video = comfy .utils .common_upscale (control_video [:length ].movedim (- 1 , 1 ), width , height , "bilinear" , "center" ).movedim (1 , - 1 )
161166 concat_latent_image = vae .encode (control_video [:, :, :, :3 ])
162- concat_latent [:,:16 ,:concat_latent_image .shape [2 ]] = concat_latent_image [:,:,:concat_latent .shape [2 ]]
167+ concat_latent [:,:latent_channels ,:concat_latent_image .shape [2 ]] = concat_latent_image [:,:,:concat_latent .shape [2 ]]
163168
164169 mask = mask .view (1 , mask .shape [2 ] // 4 , 4 , mask .shape [3 ], mask .shape [4 ]).transpose (1 , 2 )
165- positive = node_helpers .conditioning_set_values (positive , {"concat_latent_image" : concat_latent , "concat_mask" : mask , "concat_mask_index" : 16 })
166- negative = node_helpers .conditioning_set_values (negative , {"concat_latent_image" : concat_latent , "concat_mask" : mask , "concat_mask_index" : 16 })
170+ positive = node_helpers .conditioning_set_values (positive , {"concat_latent_image" : concat_latent , "concat_mask" : mask , "concat_mask_index" : latent_channels })
171+ negative = node_helpers .conditioning_set_values (negative , {"concat_latent_image" : concat_latent , "concat_mask" : mask , "concat_mask_index" : latent_channels })
167172
168173 if ref_latent is not None :
169174 positive = node_helpers .conditioning_set_values (positive , {"reference_latents" : [ref_latent ]}, append = True )
0 commit comments