@@ -1255,6 +1255,7 @@ def forward_orig(
12551255 audio_emb = None
12561256
12571257 # embeddings
1258+ bs , _ , time , height , width = x .shape
12581259 x = self .patch_embedding (x .float ()).to (x .dtype )
12591260 if control_video is not None :
12601261 x = x + self .cond_encoder (control_video )
@@ -1272,7 +1273,7 @@ def forward_orig(
12721273 if reference_latent is not None :
12731274 ref = self .patch_embedding (reference_latent .float ()).to (x .dtype )
12741275 ref = ref .flatten (2 ).transpose (1 , 2 )
1275- freqs_ref = self .rope_encode (reference_latent .shape [- 3 ], reference_latent .shape [- 2 ], reference_latent .shape [- 1 ], t_start = 30 , device = x .device , dtype = x .dtype )
1276+ freqs_ref = self .rope_encode (reference_latent .shape [- 3 ], reference_latent .shape [- 2 ], reference_latent .shape [- 1 ], t_start = max ( 30 , time + 9 ) , device = x .device , dtype = x .dtype )
12761277 ref = ref + cond_mask_weight [1 ]
12771278 x = torch .cat ([x , ref ], dim = 1 )
12781279 freqs = torch .cat ([freqs , freqs_ref ], dim = 1 )
@@ -1296,7 +1297,6 @@ def forward_orig(
12961297 # context
12971298 context = self .text_embedding (context )
12981299
1299-
13001300 patches_replace = transformer_options .get ("patches_replace" , {})
13011301 blocks_replace = patches_replace .get ("dit" , {})
13021302 for i , block in enumerate (self .blocks ):
0 commit comments