1515
1616ops = comfy .ops .disable_weight_init
1717
def in_meta_context():
    """Return True when tensor construction currently defaults to the 'meta' device.

    Allocates a zero-element probe tensor and compares its device against
    ``torch.device("meta")`` — cheap, and works regardless of how the meta
    context was entered (e.g. ``torch.device("meta")`` context manager).
    """
    probe = torch.empty(0)
    return probe.device == torch.device("meta")
20+
1821def mark_conv3d_ended (module ):
1922 tid = threading .get_ident ()
2023 for _ , m in module .named_modules ():
@@ -350,6 +353,10 @@ def __init__(
350353 output_channel = output_channel * block_params .get ("multiplier" , 2 )
351354 if block_name == "compress_all" :
352355 output_channel = output_channel * block_params .get ("multiplier" , 1 )
356+ if block_name == "compress_space" :
357+ output_channel = output_channel * block_params .get ("multiplier" , 1 )
358+ if block_name == "compress_time" :
359+ output_channel = output_channel * block_params .get ("multiplier" , 1 )
353360
354361 self .conv_in = make_conv_nd (
355362 dims ,
@@ -395,17 +402,21 @@ def __init__(
395402 spatial_padding_mode = spatial_padding_mode ,
396403 )
397404 elif block_name == "compress_time" :
405+ output_channel = output_channel // block_params .get ("multiplier" , 1 )
398406 block = DepthToSpaceUpsample (
399407 dims = dims ,
400408 in_channels = input_channel ,
401409 stride = (2 , 1 , 1 ),
410+ out_channels_reduction_factor = block_params .get ("multiplier" , 1 ),
402411 spatial_padding_mode = spatial_padding_mode ,
403412 )
404413 elif block_name == "compress_space" :
414+ output_channel = output_channel // block_params .get ("multiplier" , 1 )
405415 block = DepthToSpaceUpsample (
406416 dims = dims ,
407417 in_channels = input_channel ,
408418 stride = (1 , 2 , 2 ),
419+ out_channels_reduction_factor = block_params .get ("multiplier" , 1 ),
409420 spatial_padding_mode = spatial_padding_mode ,
410421 )
411422 elif block_name == "compress_all" :
@@ -455,6 +466,15 @@ def __init__(
455466 output_channel * 2 , 0 , operations = ops ,
456467 )
457468 self .last_scale_shift_table = nn .Parameter (torch .empty (2 , output_channel ))
469+ else :
470+ self .register_buffer (
471+ "last_scale_shift_table" ,
472+ torch .tensor (
473+ [0.0 , 0.0 ],
474+ device = "cpu" if in_meta_context () else None
475+ ).unsqueeze (1 ).expand (2 , output_channel ),
476+ persistent = False ,
477+ )
458478
459479
460480 # def forward(self, sample: torch.FloatTensor, target_shape) -> torch.FloatTensor:
@@ -883,6 +903,15 @@ def __init__(
883903 self .scale_shift_table = nn .Parameter (
884904 torch .randn (4 , in_channels ) / in_channels ** 0.5
885905 )
906+ else :
907+ self .register_buffer (
908+ "scale_shift_table" ,
909+ torch .tensor (
910+ [0.0 , 0.0 , 0.0 , 0.0 ],
911+ device = "cpu" if in_meta_context () else None
912+ ).unsqueeze (1 ).expand (4 , in_channels ),
913+ persistent = False ,
914+ )
886915
887916 self .temporal_cache_state = {}
888917
@@ -1012,9 +1041,6 @@ def __init__(self):
10121041 super ().__init__ ()
10131042 self .register_buffer ("std-of-means" , torch .empty (128 ))
10141043 self .register_buffer ("mean-of-means" , torch .empty (128 ))
1015- self .register_buffer ("mean-of-stds" , torch .empty (128 ))
1016- self .register_buffer ("mean-of-stds_over_std-of-means" , torch .empty (128 ))
1017- self .register_buffer ("channel" , torch .empty (128 ))
10181044
def un_normalize(self, x):
    """Undo per-channel normalization: ``x * std + mean``.

    Both statistics buffers are reshaped to (1, C, 1, 1, 1) so they
    broadcast over the batch/time/spatial dims of a 5-D video tensor,
    and are moved to x's device/dtype via ``.to(x)``.
    """
    # NOTE: buffer names contain hyphens, so attribute access won't work —
    # they must be fetched through get_buffer().
    std = self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x)
    mean = self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x)
    return x * std + mean
@@ -1027,9 +1053,12 @@ def __init__(self, version=0, config=None):
10271053 super ().__init__ ()
10281054
10291055 if config is None :
1030- config = self .guess_config (version )
1056+ config = self .get_default_config (version )
10311057
1058+ self .config = config
10321059 self .timestep_conditioning = config .get ("timestep_conditioning" , False )
1060+ self .decode_noise_scale = config .get ("decode_noise_scale" , 0.025 )
1061+ self .decode_timestep = config .get ("decode_timestep" , 0.05 )
10331062 double_z = config .get ("double_z" , True )
10341063 latent_log_var = config .get (
10351064 "latent_log_var" , "per_channel" if double_z else "none"
@@ -1044,13 +1073,15 @@ def __init__(self, version=0, config=None):
10441073 latent_log_var = latent_log_var ,
10451074 norm_layer = config .get ("norm_layer" , "group_norm" ),
10461075 spatial_padding_mode = config .get ("spatial_padding_mode" , "zeros" ),
1076+ base_channels = config .get ("encoder_base_channels" , 128 ),
10471077 )
10481078
10491079 self .decoder = Decoder (
10501080 dims = config ["dims" ],
10511081 in_channels = config ["latent_channels" ],
10521082 out_channels = config .get ("out_channels" , 3 ),
10531083 blocks = config .get ("decoder_blocks" , config .get ("decoder_blocks" , config .get ("blocks" ))),
1084+ base_channels = config .get ("decoder_base_channels" , 128 ),
10541085 patch_size = config .get ("patch_size" , 1 ),
10551086 norm_layer = config .get ("norm_layer" , "group_norm" ),
10561087 causal = config .get ("causal_decoder" , False ),
@@ -1060,7 +1091,7 @@ def __init__(self, version=0, config=None):
10601091
10611092 self .per_channel_statistics = processor ()
10621093
1063- def guess_config (self , version ):
1094+ def get_default_config (self , version ):
10641095 if version == 0 :
10651096 config = {
10661097 "_class_name" : "CausalVideoAutoencoder" ,
@@ -1167,8 +1198,7 @@ def encode(self, x):
11671198 means , logvar = torch .chunk (self .encoder (x ), 2 , dim = 1 )
11681199 return self .per_channel_statistics .normalize (means )
11691200
def decode(self, x, timestep=None, noise_scale=None):
    """Decode normalized latents ``x`` back to pixel space.

    Args:
        x: latent tensor (normalized by per_channel_statistics).
        timestep: optional override for the decoder conditioning timestep;
            defaults to ``self.decode_timestep`` from the model config.
            Kept as an optional parameter for backward compatibility with
            callers that previously passed it explicitly.
        noise_scale: optional override for the noise-injection strength used
            when ``self.timestep_conditioning`` is enabled; defaults to
            ``self.decode_noise_scale`` from the model config.

    Returns:
        The decoder output for the un-normalized latents.
    """
    if timestep is None:
        timestep = self.decode_timestep
    if noise_scale is None:
        noise_scale = self.decode_noise_scale
    if self.timestep_conditioning:  # TODO: seed the noise for reproducibility
        # Blend fresh Gaussian noise into the latents before decoding.
        x = torch.randn_like(x) * noise_scale + (1.0 - noise_scale) * x
    return self.decoder(self.per_channel_statistics.un_normalize(x), timestep=timestep)
0 commit comments