Skip to content

Commit 496888f

Browse files
Improve s2v performance when generating videos longer than 120 frames. (Comfy-Org#9582)
1 parent b5ac6ed commit 496888f

1 file changed

Lines changed: 2 additions & 2 deletions

File tree

comfy/ldm/wan/model.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1255,6 +1255,7 @@ def forward_orig(
1255 1255
audio_emb = None
1256 1256

1257 1257
# embeddings
1258 +
bs, _, time, height, width = x.shape
1258 1259
x = self.patch_embedding(x.float()).to(x.dtype)
1259 1260
if control_video is not None:
1260 1261
x = x + self.cond_encoder(control_video)
@@ -1272,7 +1273,7 @@ def forward_orig(
1272 1273
if reference_latent is not None:
1273 1274
ref = self.patch_embedding(reference_latent.float()).to(x.dtype)
1274 1275
ref = ref.flatten(2).transpose(1, 2)
1275 -
freqs_ref = self.rope_encode(reference_latent.shape[-3], reference_latent.shape[-2], reference_latent.shape[-1], t_start=30, device=x.device, dtype=x.dtype)
1276 +
freqs_ref = self.rope_encode(reference_latent.shape[-3], reference_latent.shape[-2], reference_latent.shape[-1], t_start=max(30, time + 9), device=x.device, dtype=x.dtype)
1276 1277
ref = ref + cond_mask_weight[1]
1277 1278
x = torch.cat([x, ref], dim=1)
1278 1279
freqs = torch.cat([freqs, freqs_ref], dim=1)
@@ -1296,7 +1297,6 @@ def forward_orig(
1296 1297
# context
1297 1298
context = self.text_embedding(context)
1298 1299

1299 -
1300 1300
patches_replace = transformer_options.get("patches_replace", {})
1301 1301
blocks_replace = patches_replace.get("dit", {})
1302 1302
for i, block in enumerate(self.blocks):

0 commit comments

Comments
 (0)