@@ -293,13 +293,15 @@ def __init__(
293293 guidance_embeds : bool = False ,
294294 axes_dims_rope : Tuple [int , int , int ] = (16 , 56 , 56 ),
295295 image_model = None ,
296+ final_layer = True ,
296297 dtype = None ,
297298 device = None ,
298299 operations = None ,
299300 ):
300301 super ().__init__ ()
301302 self .dtype = dtype
302303 self .patch_size = patch_size
304+ self .in_channels = in_channels
303305 self .out_channels = out_channels or in_channels
304306 self .inner_dim = num_attention_heads * attention_head_dim
305307
@@ -329,9 +331,9 @@ def __init__(
329331 for _ in range (num_layers )
330332 ])
331333
332- self . norm_out = LastLayer ( self . inner_dim , self . inner_dim , dtype = dtype , device = device , operations = operations )
333- self .proj_out = operations . Linear (self .inner_dim , patch_size * patch_size * self .out_channels , bias = True , dtype = dtype , device = device )
334- self .gradient_checkpointing = False
334+ if final_layer :
335+ self .norm_out = LastLayer (self .inner_dim , self .inner_dim , dtype = dtype , device = device , operations = operations )
336+ self .proj_out = operations . Linear ( self . inner_dim , patch_size * patch_size * self . out_channels , bias = True , dtype = dtype , device = device )
335337
336338 def process_img (self , x , index = 0 , h_offset = 0 , w_offset = 0 ):
337339 bs , c , t , h , w = x .shape
@@ -362,6 +364,7 @@ def forward(
362364 guidance : torch .Tensor = None ,
363365 ref_latents = None ,
364366 transformer_options = {},
367+ control = None ,
365368 ** kwargs
366369 ):
367370 timestep = timesteps
@@ -443,6 +446,13 @@ def block_wrap(args):
443446 hidden_states = out ["img" ]
444447 encoder_hidden_states = out ["txt" ]
445448
449+ if control is not None : # Controlnet
450+ control_i = control .get ("input" )
451+ if i < len (control_i ):
452+ add = control_i [i ]
453+ if add is not None :
454+ hidden_states += add
455+
446456 hidden_states = self .norm_out (hidden_states , temb )
447457 hidden_states = self .proj_out (hidden_states )
448458
0 commit comments