Skip to content

Commit 0fd1b78

Browse files
authored
Reduce LTX2 VAE VRAM consumption (Comfy-Org#12028)
* causal_video_ae: Remove attention ResNet This attention_head_dim argument does not exist on this constructor so this is dead code. Remove, as a generic attention mid VAE conflicts with temporal roll. * ltx-vae: consolidate causal/non-causal code paths * ltx-vae: add cache rolling adder * ltx-vae: use cached adder for resnet * ltx-vae: Implement rolling VAE Implement a temporal rolling VAE for the LTX2 VAE. Usually when doing temporal rolling VAEs you can just chunk on time, relying on causality, and cache behind you as you go. The LTX VAE is however non-causal. So go whole hog and implement per-layer run-ahead and backpressure between the decoder layers using recursive state between the layers. Operations are amended with temporal_cache_state{} which they can use to hold any state they need for partial execution. Convolutions cache up to N-1 frames of their inputs behind them, and skip connections need to cache the mismatch between convolution input and output that happens due to missing future (non-causal) input. Each call to run_up() processes a layer across a range of input that may or may not be complete. It goes depth first to process as much as possible, to try to digest frames to the final output ASAP. If layers run out of input due to convolution losses, they simply return without action, effectively applying back-pressure to the earlier layers. As the earlier layers do more work and call deeper, the partial states are reconciled and output continues to digest depth first as much as possible. Chunking is done using a size quota rather than a fixed frame length; any layer can initiate chunking, and multiple layers can chunk at different granularities. This removes the old limitation of always having to process 1 latent frame to entirety and having to hold 8 full decoded frames as the VRAM peak.
1 parent 8490eed commit 0fd1b78

3 files changed

Lines changed: 160 additions & 64 deletions

File tree

comfy/ldm/lightricks/vae/causal_conv3d.py

Lines changed: 27 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
from typing import Tuple, Union
22

3+
import threading
34
import torch
45
import torch.nn as nn
56
import comfy.ops
67
ops = comfy.ops.disable_weight_init
78

8-
99
class CausalConv3d(nn.Module):
1010
def __init__(
1111
self,
@@ -42,23 +42,34 @@ def __init__(
4242
padding_mode=spatial_padding_mode,
4343
groups=groups,
4444
)
45+
self.temporal_cache_state={}
4546

4647
def forward(self, x, causal: bool = True):
47-
if causal:
48-
first_frame_pad = x[:, :, :1, :, :].repeat(
49-
(1, 1, self.time_kernel_size - 1, 1, 1)
50-
)
51-
x = torch.concatenate((first_frame_pad, x), dim=2)
52-
else:
53-
first_frame_pad = x[:, :, :1, :, :].repeat(
54-
(1, 1, (self.time_kernel_size - 1) // 2, 1, 1)
55-
)
56-
last_frame_pad = x[:, :, -1:, :, :].repeat(
57-
(1, 1, (self.time_kernel_size - 1) // 2, 1, 1)
58-
)
59-
x = torch.concatenate((first_frame_pad, x, last_frame_pad), dim=2)
60-
x = self.conv(x)
61-
return x
48+
tid = threading.get_ident()
49+
50+
cached, is_end = self.temporal_cache_state.get(tid, (None, False))
51+
if cached is None:
52+
padding_length = self.time_kernel_size - 1
53+
if not causal:
54+
padding_length = padding_length // 2
55+
if x.shape[2] == 0:
56+
return x
57+
cached = x[:, :, :1, :, :].repeat((1, 1, padding_length, 1, 1))
58+
pieces = [ cached, x ]
59+
if is_end and not causal:
60+
pieces.append(x[:, :, -1:, :, :].repeat((1, 1, (self.time_kernel_size - 1) // 2, 1, 1)))
61+
62+
needs_caching = not is_end
63+
if needs_caching and x.shape[2] >= self.time_kernel_size - 1:
64+
needs_caching = False
65+
self.temporal_cache_state[tid] = (x[:, :, -(self.time_kernel_size - 1):, :, :], False)
66+
67+
x = torch.cat(pieces, dim=2)
68+
69+
if needs_caching:
70+
self.temporal_cache_state[tid] = (x[:, :, -(self.time_kernel_size - 1):, :, :], False)
71+
72+
return self.conv(x) if x.shape[2] >= self.time_kernel_size else x[:, :, :0, :, :]
6273

6374
@property
6475
def weight(self):

comfy/ldm/lightricks/vae/causal_video_autoencoder.py

Lines changed: 129 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,41 @@
11
from __future__ import annotations
2+
import threading
23
import torch
34
from torch import nn
45
from functools import partial
56
import math
67
from einops import rearrange
78
from typing import List, Optional, Tuple, Union
89
from .conv_nd_factory import make_conv_nd, make_linear_nd
10+
from .causal_conv3d import CausalConv3d
911
from .pixel_norm import PixelNorm
1012
from ..model import PixArtAlphaCombinedTimestepSizeEmbeddings
1113
import comfy.ops
14+
from comfy.ldm.modules.diffusionmodules.model import torch_cat_if_needed
1215

1316
ops = comfy.ops.disable_weight_init
1417

18+
def mark_conv3d_ended(module):
19+
tid = threading.get_ident()
20+
for _, m in module.named_modules():
21+
if isinstance(m, CausalConv3d):
22+
current = m.temporal_cache_state.get(tid, (None, False))
23+
m.temporal_cache_state[tid] = (current[0], True)
24+
25+
def split2(tensor, split_point, dim=2):
26+
return torch.split(tensor, [split_point, tensor.shape[dim] - split_point], dim=dim)
27+
28+
def add_exchange_cache(dest, cache_in, new_input, dim=2):
29+
if dest is not None:
30+
if cache_in is not None:
31+
cache_to_dest = min(dest.shape[dim], cache_in.shape[dim])
32+
lead_in_dest, dest = split2(dest, cache_to_dest, dim=dim)
33+
lead_in_source, cache_in = split2(cache_in, cache_to_dest, dim=dim)
34+
lead_in_dest.add_(lead_in_source)
35+
body, new_input = split2(new_input, dest.shape[dim], dim)
36+
dest.add_(body)
37+
return torch_cat_if_needed([cache_in, new_input], dim=dim)
38+
1539
class Encoder(nn.Module):
1640
r"""
1741
The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation.
@@ -205,7 +229,7 @@ def __init__(
205229

206230
self.gradient_checkpointing = False
207231

208-
def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
232+
def forward_orig(self, sample: torch.FloatTensor) -> torch.FloatTensor:
209233
r"""The forward method of the `Encoder` class."""
210234

211235
sample = patchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
@@ -254,6 +278,22 @@ def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
254278

255279
return sample
256280

281+
def forward(self, *args, **kwargs):
282+
#No encoder support so just flag the end so it doesnt use the cache.
283+
mark_conv3d_ended(self)
284+
try:
285+
return self.forward_orig(*args, **kwargs)
286+
finally:
287+
tid = threading.get_ident()
288+
for _, module in self.named_modules():
289+
# ComfyUI doesn't thread this kind of stuff today, but just in case
290+
# we key on the thread to make it thread safe.
291+
tid = threading.get_ident()
292+
if hasattr(module, "temporal_cache_state"):
293+
module.temporal_cache_state.pop(tid, None)
294+
295+
296+
MAX_CHUNK_SIZE=(128 * 1024 ** 2)
257297

258298
class Decoder(nn.Module):
259299
r"""
@@ -341,18 +381,6 @@ def __init__(
341381
timestep_conditioning=timestep_conditioning,
342382
spatial_padding_mode=spatial_padding_mode,
343383
)
344-
elif block_name == "attn_res_x":
345-
block = UNetMidBlock3D(
346-
dims=dims,
347-
in_channels=input_channel,
348-
num_layers=block_params["num_layers"],
349-
resnet_groups=norm_num_groups,
350-
norm_layer=norm_layer,
351-
inject_noise=block_params.get("inject_noise", False),
352-
timestep_conditioning=timestep_conditioning,
353-
attention_head_dim=block_params["attention_head_dim"],
354-
spatial_padding_mode=spatial_padding_mode,
355-
)
356384
elif block_name == "res_x_y":
357385
output_channel = output_channel // block_params.get("multiplier", 2)
358386
block = ResnetBlock3D(
@@ -428,15 +456,17 @@ def __init__(
428456
)
429457
self.last_scale_shift_table = nn.Parameter(torch.empty(2, output_channel))
430458

459+
431460
# def forward(self, sample: torch.FloatTensor, target_shape) -> torch.FloatTensor:
432-
def forward(
461+
def forward_orig(
433462
self,
434463
sample: torch.FloatTensor,
435464
timestep: Optional[torch.Tensor] = None,
436465
) -> torch.FloatTensor:
437466
r"""The forward method of the `Decoder` class."""
438467
batch_size = sample.shape[0]
439468

469+
mark_conv3d_ended(self.conv_in)
440470
sample = self.conv_in(sample, causal=self.causal)
441471

442472
checkpoint_fn = (
@@ -445,24 +475,12 @@ def forward(
445475
else lambda x: x
446476
)
447477

448-
scaled_timestep = None
478+
timestep_shift_scale = None
449479
if self.timestep_conditioning:
450480
assert (
451481
timestep is not None
452482
), "should pass timestep with timestep_conditioning=True"
453483
scaled_timestep = timestep * self.timestep_scale_multiplier.to(dtype=sample.dtype, device=sample.device)
454-
455-
for up_block in self.up_blocks:
456-
if self.timestep_conditioning and isinstance(up_block, UNetMidBlock3D):
457-
sample = checkpoint_fn(up_block)(
458-
sample, causal=self.causal, timestep=scaled_timestep
459-
)
460-
else:
461-
sample = checkpoint_fn(up_block)(sample, causal=self.causal)
462-
463-
sample = self.conv_norm_out(sample)
464-
465-
if self.timestep_conditioning:
466484
embedded_timestep = self.last_time_embedder(
467485
timestep=scaled_timestep.flatten(),
468486
resolution=None,
@@ -483,16 +501,62 @@ def forward(
483501
embedded_timestep.shape[-2],
484502
embedded_timestep.shape[-1],
485503
)
486-
shift, scale = ada_values.unbind(dim=1)
487-
sample = sample * (1 + scale) + shift
504+
timestep_shift_scale = ada_values.unbind(dim=1)
505+
506+
output = []
507+
508+
def run_up(idx, sample, ended):
509+
if idx >= len(self.up_blocks):
510+
sample = self.conv_norm_out(sample)
511+
if timestep_shift_scale is not None:
512+
shift, scale = timestep_shift_scale
513+
sample = sample * (1 + scale) + shift
514+
sample = self.conv_act(sample)
515+
if ended:
516+
mark_conv3d_ended(self.conv_out)
517+
sample = self.conv_out(sample, causal=self.causal)
518+
if sample is not None and sample.shape[2] > 0:
519+
output.append(sample)
520+
return
521+
522+
up_block = self.up_blocks[idx]
523+
if (ended):
524+
mark_conv3d_ended(up_block)
525+
if self.timestep_conditioning and isinstance(up_block, UNetMidBlock3D):
526+
sample = checkpoint_fn(up_block)(
527+
sample, causal=self.causal, timestep=scaled_timestep
528+
)
529+
else:
530+
sample = checkpoint_fn(up_block)(sample, causal=self.causal)
488531

489-
sample = self.conv_act(sample)
490-
sample = self.conv_out(sample, causal=self.causal)
532+
if sample is None or sample.shape[2] == 0:
533+
return
534+
535+
total_bytes = sample.numel() * sample.element_size()
536+
num_chunks = (total_bytes + MAX_CHUNK_SIZE - 1) // MAX_CHUNK_SIZE
537+
samples = torch.chunk(sample, chunks=num_chunks, dim=2)
538+
539+
for chunk_idx, sample1 in enumerate(samples):
540+
run_up(idx + 1, sample1, ended and chunk_idx == len(samples) - 1)
541+
542+
run_up(0, sample, True)
543+
sample = torch.cat(output, dim=2)
491544

492545
sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
493546

494547
return sample
495548

549+
def forward(self, *args, **kwargs):
550+
try:
551+
return self.forward_orig(*args, **kwargs)
552+
finally:
553+
for _, module in self.named_modules():
554+
#ComfyUI doesn't thread this kind of stuff today, but just incase
555+
#we key on the thread to make it thread safe.
556+
tid = threading.get_ident()
557+
if hasattr(module, "temporal_cache_state"):
558+
module.temporal_cache_state.pop(tid, None)
559+
496560

497561
class UNetMidBlock3D(nn.Module):
498562
"""
@@ -663,8 +727,22 @@ def __init__(
663727
)
664728
self.residual = residual
665729
self.out_channels_reduction_factor = out_channels_reduction_factor
730+
self.temporal_cache_state = {}
666731

667732
def forward(self, x, causal: bool = True, timestep: Optional[torch.Tensor] = None):
733+
tid = threading.get_ident()
734+
cached, drop_first_conv, drop_first_res = self.temporal_cache_state.get(tid, (None, True, True))
735+
y = self.conv(x, causal=causal)
736+
y = rearrange(
737+
y,
738+
"b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)",
739+
p1=self.stride[0],
740+
p2=self.stride[1],
741+
p3=self.stride[2],
742+
)
743+
if self.stride[0] == 2 and y.shape[2] > 0 and drop_first_conv:
744+
y = y[:, :, 1:, :, :]
745+
drop_first_conv = False
668746
if self.residual:
669747
# Reshape and duplicate the input to match the output shape
670748
x_in = rearrange(
@@ -676,21 +754,20 @@ def forward(self, x, causal: bool = True, timestep: Optional[torch.Tensor] = Non
676754
)
677755
num_repeat = math.prod(self.stride) // self.out_channels_reduction_factor
678756
x_in = x_in.repeat(1, num_repeat, 1, 1, 1)
679-
if self.stride[0] == 2:
757+
if self.stride[0] == 2 and x_in.shape[2] > 0 and drop_first_res:
680758
x_in = x_in[:, :, 1:, :, :]
681-
x = self.conv(x, causal=causal)
682-
x = rearrange(
683-
x,
684-
"b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)",
685-
p1=self.stride[0],
686-
p2=self.stride[1],
687-
p3=self.stride[2],
688-
)
689-
if self.stride[0] == 2:
690-
x = x[:, :, 1:, :, :]
691-
if self.residual:
692-
x = x + x_in
693-
return x
759+
drop_first_res = False
760+
761+
if y.shape[2] == 0:
762+
y = None
763+
764+
cached = add_exchange_cache(y, cached, x_in, dim=2)
765+
self.temporal_cache_state[tid] = (cached, drop_first_conv, drop_first_res)
766+
767+
else:
768+
self.temporal_cache_state[tid] = (None, drop_first_conv, False)
769+
770+
return y
694771

695772
class LayerNorm(nn.Module):
696773
def __init__(self, dim, eps, elementwise_affine=True) -> None:
@@ -807,6 +884,8 @@ def __init__(
807884
torch.randn(4, in_channels) / in_channels**0.5
808885
)
809886

887+
self.temporal_cache_state={}
888+
810889
def _feed_spatial_noise(
811890
self, hidden_states: torch.FloatTensor, per_channel_scale: torch.FloatTensor
812891
) -> torch.FloatTensor:
@@ -880,9 +959,12 @@ def forward(
880959

881960
input_tensor = self.conv_shortcut(input_tensor)
882961

883-
output_tensor = input_tensor + hidden_states
962+
tid = threading.get_ident()
963+
cached = self.temporal_cache_state.get(tid, None)
964+
cached = add_exchange_cache(hidden_states, cached, input_tensor, dim=2)
965+
self.temporal_cache_state[tid] = cached
884966

885-
return output_tensor
967+
return hidden_states
886968

887969

888970
def patchify(x, patch_size_hw, patch_size_t=1):

comfy/ldm/modules/diffusionmodules/model.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,13 @@
1414
import xformers.ops
1515

1616
def torch_cat_if_needed(xl, dim):
17+
xl = [x for x in xl if x is not None and x.shape[dim] > 0]
1718
if len(xl) > 1:
1819
return torch.cat(xl, dim)
19-
else:
20+
elif len(xl) == 1:
2021
return xl[0]
22+
else:
23+
return None
2124

2225
def get_timestep_embedding(timesteps, embedding_dim):
2326
"""

0 commit comments

Comments
 (0)