Skip to content

Commit 1e638a1

Browse files
Tiny wan vae optimizations. (Comfy-Org#9136)
1 parent 4696d74 commit 1e638a1

File tree

2 files changed

+10
-5
lines changed

2 files changed

+10
-5
lines changed

comfy/ldm/wan/vae.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,17 @@ def __init__(self, *args, **kwargs):
2424
self.padding[1], 2 * self.padding[0], 0)
2525
self.padding = (0, 0, 0)
2626

27-
def forward(self, x, cache_x=None):
27+
def forward(self, x, cache_x=None, cache_list=None, cache_idx=None):
28+
if cache_list is not None:
29+
cache_x = cache_list[cache_idx]
30+
cache_list[cache_idx] = None
31+
2832
padding = list(self._padding)
2933
if cache_x is not None and self._padding[4] > 0:
3034
cache_x = cache_x.to(x.device)
3135
x = torch.cat([cache_x, x], dim=2)
3236
padding[4] -= cache_x.shape[2]
37+
del cache_x
3338
x = F.pad(x, padding)
3439

3540
return super().forward(x)
@@ -166,7 +171,7 @@ def __init__(self, in_dim, out_dim, dropout=0.0):
166171
if in_dim != out_dim else nn.Identity()
167172

168173
def forward(self, x, feat_cache=None, feat_idx=[0]):
169-
h = self.shortcut(x)
174+
old_x = x
170175
for layer in self.residual:
171176
if isinstance(layer, CausalConv3d) and feat_cache is not None:
172177
idx = feat_idx[0]
@@ -178,12 +183,12 @@ def forward(self, x, feat_cache=None, feat_idx=[0]):
178183
cache_x.device), cache_x
179184
],
180185
dim=2)
181-
x = layer(x, feat_cache[idx])
186+
x = layer(x, cache_list=feat_cache, cache_idx=idx)
182187
feat_cache[idx] = cache_x
183188
feat_idx[0] += 1
184189
else:
185190
x = layer(x)
186-
return x + h
191+
return x + self.shortcut(old_x)
187192

188193

189194
class AttentionBlock(nn.Module):

comfy/ldm/wan/vae2_2.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ def forward(self, x, feat_cache=None, feat_idx=[0]):
151151
],
152152
dim=2,
153153
)
154-
x = layer(x, feat_cache[idx])
154+
x = layer(x, cache_list=feat_cache, cache_idx=idx)
155155
feat_cache[idx] = cache_x
156156
feat_idx[0] += 1
157157
else:

0 commit comments

Comments
 (0)