Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions comfy/ldm/lightricks/vae/causal_conv3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,13 @@ def forward(self, x, causal: bool = True):
self.temporal_cache_state[tid] = (x[:, :, -(self.time_kernel_size - 1):, :, :], False)

x = torch.cat(pieces, dim=2)
del pieces
del cached

if needs_caching:
self.temporal_cache_state[tid] = (x[:, :, -(self.time_kernel_size - 1):, :, :], False)
elif is_end:
self.temporal_cache_state[tid] = (None, True)

return self.conv(x) if x.shape[2] >= self.time_kernel_size else x[:, :, :0, :, :]

Expand Down
41 changes: 34 additions & 7 deletions comfy/ldm/lightricks/vae/causal_video_autoencoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,23 @@ def forward(self, *args, **kwargs):
module.temporal_cache_state.pop(tid, None)


MAX_CHUNK_SIZE=(128 * 1024 ** 2)
MIN_VRAM_FOR_CHUNK_SCALING = 6 * 1024 ** 3
MAX_VRAM_FOR_CHUNK_SCALING = 24 * 1024 ** 3
MIN_CHUNK_SIZE = 32 * 1024 ** 2
MAX_CHUNK_SIZE = 128 * 1024 ** 2

def get_max_chunk_size(device: torch.device) -> int:
total_memory = comfy.model_management.get_total_memory(dev=device)

if total_memory <= MIN_VRAM_FOR_CHUNK_SCALING:
return MIN_CHUNK_SIZE
if total_memory >= MAX_VRAM_FOR_CHUNK_SCALING:
return MAX_CHUNK_SIZE

interp = (total_memory - MIN_VRAM_FOR_CHUNK_SCALING) / (
MAX_VRAM_FOR_CHUNK_SCALING - MIN_VRAM_FOR_CHUNK_SCALING
)
return int(MIN_CHUNK_SIZE + interp * (MAX_CHUNK_SIZE - MIN_CHUNK_SIZE))

class Decoder(nn.Module):
r"""
Expand Down Expand Up @@ -525,8 +541,11 @@ def forward_orig(
timestep_shift_scale = ada_values.unbind(dim=1)

output = []
max_chunk_size = get_max_chunk_size(sample.device)

def run_up(idx, sample, ended):
def run_up(idx, sample_ref, ended):
sample = sample_ref[0]
sample_ref[0] = None
if idx >= len(self.up_blocks):
sample = self.conv_norm_out(sample)
if timestep_shift_scale is not None:
Expand Down Expand Up @@ -554,13 +573,21 @@ def run_up(idx, sample, ended):
return

total_bytes = sample.numel() * sample.element_size()
num_chunks = (total_bytes + MAX_CHUNK_SIZE - 1) // MAX_CHUNK_SIZE
samples = torch.chunk(sample, chunks=num_chunks, dim=2)
num_chunks = (total_bytes + max_chunk_size - 1) // max_chunk_size

if num_chunks == 1:
# when we are not chunking, detach our x so the callee can free it as soon as they are done
next_sample_ref = [sample]
del sample
run_up(idx + 1, next_sample_ref, ended)
return
else:
samples = torch.chunk(sample, chunks=num_chunks, dim=2)

for chunk_idx, sample1 in enumerate(samples):
run_up(idx + 1, sample1, ended and chunk_idx == len(samples) - 1)
for chunk_idx, sample1 in enumerate(samples):
run_up(idx + 1, [sample1], ended and chunk_idx == len(samples) - 1)

run_up(0, sample, True)
run_up(0, [sample], True)
sample = torch.cat(output, dim=2)

sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
Expand Down
176 changes: 79 additions & 97 deletions comfy/ldm/wan/vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def __init__(self, dim, mode):
else:
self.resample = nn.Identity()

def forward(self, x, feat_cache=None, feat_idx=[0]):
def forward(self, x, feat_cache=None, feat_idx=[0], final=False):
b, c, t, h, w = x.size()
if self.mode == 'upsample3d':
if feat_cache is not None:
Expand All @@ -109,22 +109,7 @@ def forward(self, x, feat_cache=None, feat_idx=[0]):
feat_idx[0] += 1
else:

cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[
idx] is not None and feat_cache[idx] != 'Rep':
# cache last frame of last two chunk
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
if cache_x.shape[2] < 2 and feat_cache[
idx] is not None and feat_cache[idx] == 'Rep':
cache_x = torch.cat([
torch.zeros_like(cache_x).to(cache_x.device),
cache_x
],
dim=2)
cache_x = x[:, :, -CACHE_T:, :, :]
if feat_cache[idx] == 'Rep':
x = self.time_conv(x)
else:
Expand All @@ -145,19 +130,24 @@ def forward(self, x, feat_cache=None, feat_idx=[0]):
if feat_cache is not None:
idx = feat_idx[0]
if feat_cache[idx] is None:
feat_cache[idx] = x.clone()
feat_idx[0] += 1
feat_cache[idx] = x
else:

cache_x = x[:, :, -1:, :, :].clone()
# if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx]!='Rep':
# # cache last frame of last two chunk
# cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)

cache_x = x[:, :, -1:, :, :]
x = self.time_conv(
torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
feat_cache[idx] = cache_x
feat_idx[0] += 1

deferred_x = feat_cache[idx + 1]
if deferred_x is not None:
x = torch.cat([deferred_x, x], 2)
feat_cache[idx + 1] = None

if x.shape[2] == 1 and not final:
feat_cache[idx + 1] = x
x = None

feat_idx[0] += 2
return x


Expand All @@ -177,19 +167,12 @@ def __init__(self, in_dim, out_dim, dropout=0.0):
self.shortcut = CausalConv3d(in_dim, out_dim, 1) \
if in_dim != out_dim else nn.Identity()

def forward(self, x, feat_cache=None, feat_idx=[0]):
def forward(self, x, feat_cache=None, feat_idx=[0], final=False):
old_x = x
for layer in self.residual:
if isinstance(layer, CausalConv3d) and feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache last frame of last two chunk
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
cache_x = x[:, :, -CACHE_T:, :, :]
x = layer(x, cache_list=feat_cache, cache_idx=idx)
feat_cache[idx] = cache_x
feat_idx[0] += 1
Expand All @@ -213,7 +196,7 @@ def __init__(self, dim):
self.proj = ops.Conv2d(dim, dim, 1)
self.optimized_attention = vae_attention()

def forward(self, x):
def forward(self, x, feat_cache=None, feat_idx=[0], final=False):
identity = x
b, c, t, h, w = x.size()
x = rearrange(x, 'b c t h w -> (b t) c h w')
Expand Down Expand Up @@ -283,17 +266,10 @@ def __init__(self,
RMS_norm(out_dim, images=False), nn.SiLU(),
CausalConv3d(out_dim, z_dim, 3, padding=1))

def forward(self, x, feat_cache=None, feat_idx=[0]):
def forward(self, x, feat_cache=None, feat_idx=[0], final=False):
if feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache last frame of last two chunk
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
cache_x = x[:, :, -CACHE_T:, :, :]
x = self.conv1(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
Expand All @@ -303,29 +279,24 @@ def forward(self, x, feat_cache=None, feat_idx=[0]):
## downsamples
for layer in self.downsamples:
if feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
x = layer(x, feat_cache, feat_idx, final=final)
if x is None:
return None
else:
x = layer(x)

## middle
for layer in self.middle:
if isinstance(layer, ResidualBlock) and feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
if feat_cache is not None:
x = layer(x, feat_cache, feat_idx, final=final)
else:
x = layer(x)

## head
for layer in self.head:
if isinstance(layer, CausalConv3d) and feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache last frame of last two chunk
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
cache_x = x[:, :, -CACHE_T:, :, :]
x = layer(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
Expand Down Expand Up @@ -393,14 +364,7 @@ def forward(self, x, feat_cache=None, feat_idx=[0]):
## conv1
if feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache last frame of last two chunk
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
cache_x = x[:, :, -CACHE_T:, :, :]
x = self.conv1(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
Expand All @@ -409,42 +373,56 @@ def forward(self, x, feat_cache=None, feat_idx=[0]):

## middle
for layer in self.middle:
if isinstance(layer, ResidualBlock) and feat_cache is not None:
if feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)

## upsamples
for layer in self.upsamples:
out_chunks = []

def run_up(layer_idx, x_ref, feat_idx):
x = x_ref[0]
x_ref[0] = None
if layer_idx >= len(self.upsamples):
for layer in self.head:
if isinstance(layer, CausalConv3d) and feat_cache is not None:
cache_x = x[:, :, -CACHE_T:, :, :]
x = layer(x, feat_cache[feat_idx[0]])
feat_cache[feat_idx[0]] = cache_x
feat_idx[0] += 1
else:
x = layer(x)
out_chunks.append(x)
return

layer = self.upsamples[layer_idx]
if isinstance(layer, Resample) and layer.mode == 'upsample3d' and x.shape[2] > 1:
for frame_idx in range(x.shape[2]):
run_up(
layer_idx,
[x[:, :, frame_idx:frame_idx + 1, :, :]],
feat_idx.copy(),
)
del x
return

if feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)

## head
for layer in self.head:
if isinstance(layer, CausalConv3d) and feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache last frame of last two chunk
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
x = layer(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = layer(x)
return x
next_x_ref = [x]
del x
run_up(layer_idx + 1, next_x_ref, feat_idx)

run_up(0, [x], feat_idx)
return out_chunks

def count_conv3d(model):

def count_cache_layers(model):
count = 0
for m in model.modules():
if isinstance(m, CausalConv3d):
if isinstance(m, CausalConv3d) or (isinstance(m, Resample) and m.mode == 'downsample3d'):
count += 1
return count

Expand Down Expand Up @@ -482,11 +460,12 @@ def encode(self, x):
conv_idx = [0]
## cache
t = x.shape[2]
iter_ = 1 + (t - 1) // 4
t = 1 + ((t - 1) // 4) * 4
iter_ = 1 + (t - 1) // 2
feat_map = None
if iter_ > 1:
feat_map = [None] * count_conv3d(self.encoder)
## Split the encoder input x along the time axis into chunks of 1, 4, 4, 4, ...
feat_map = [None] * count_cache_layers(self.encoder)
## Split the encoder input x along the time axis into chunks of 1, 2, 2, 2, ... (total frame count is first floored to 4N+1)
for i in range(iter_):
conv_idx = [0]
if i == 0:
Expand All @@ -496,20 +475,23 @@ def encode(self, x):
feat_idx=conv_idx)
else:
out_ = self.encoder(
x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
x[:, :, 1 + 2 * (i - 1):1 + 2 * i, :, :],
feat_cache=feat_map,
feat_idx=conv_idx)
feat_idx=conv_idx,
final=(i == (iter_ - 1)))
if out_ is None:
continue
out = torch.cat([out, out_], 2)

mu, log_var = self.conv1(out).chunk(2, dim=1)
return mu

def decode(self, z):
conv_idx = [0]
# z: [b,c,t,h,w]
iter_ = z.shape[2]
iter_ = 1 + z.shape[2] // 2
feat_map = None
if iter_ > 1:
feat_map = [None] * count_conv3d(self.decoder)
feat_map = [None] * count_cache_layers(self.decoder)
x = self.conv2(z)
for i in range(iter_):
conv_idx = [0]
Expand All @@ -520,8 +502,8 @@ def decode(self, z):
feat_idx=conv_idx)
else:
out_ = self.decoder(
x[:, :, i:i + 1, :, :],
x[:, :, 1 + 2 * (i - 1):1 + 2 * i, :, :],
feat_cache=feat_map,
feat_idx=conv_idx)
out = torch.cat([out, out_], 2)
return out
out += out_
return torch.cat(out, 2)
2 changes: 1 addition & 1 deletion manager_requirements.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
comfyui_manager==4.1b5
comfyui_manager==4.1b6
Loading