11from __future__ import annotations
2+ import threading
23import torch
34from torch import nn
45from functools import partial
56import math
67from einops import rearrange
78from typing import List , Optional , Tuple , Union
89from .conv_nd_factory import make_conv_nd , make_linear_nd
10+ from .causal_conv3d import CausalConv3d
911from .pixel_norm import PixelNorm
1012from ..model import PixArtAlphaCombinedTimestepSizeEmbeddings
1113import comfy .ops
14+ from comfy .ldm .modules .diffusionmodules .model import torch_cat_if_needed
1215
1316ops = comfy .ops .disable_weight_init
1417
def mark_conv3d_ended(module):
    """Flag every CausalConv3d under ``module`` as having seen its final chunk.

    Each matching conv keeps a per-thread ``temporal_cache_state`` entry of the
    form ``(cache_tensor, ended)``; this sets the ``ended`` flag to True while
    leaving the cached tensor untouched, so the conv stops expecting further
    temporal input on the current thread.
    """
    thread_id = threading.get_ident()
    causal_convs = (m for _, m in module.named_modules() if isinstance(m, CausalConv3d))
    for conv in causal_convs:
        entry = conv.temporal_cache_state.get(thread_id, (None, False))
        conv.temporal_cache_state[thread_id] = (entry[0], True)
24+
def split2(tensor, split_point, dim=2):
    """Split ``tensor`` along ``dim`` into a leading view of length
    ``split_point`` and a view of the remainder.

    Both returned tensors are views into ``tensor`` (no copy), so in-place
    writes to either piece are visible in the original.
    """
    remainder = tensor.shape[dim] - split_point
    head = tensor.narrow(dim, 0, split_point)
    tail = tensor.narrow(dim, split_point, remainder)
    return head, tail
27+
def add_exchange_cache(dest, cache_in, new_input, dim=2):
    """Overlap-add a cached carry-over into ``dest`` and build the next carry-over.

    Used by the chunked temporal decode path: the tail of the previous chunk
    (``cache_in``) overlaps the head of the current chunk's residual target
    (``dest``), and whatever part of ``new_input`` extends past ``dest``
    becomes the cache handed to the next chunk.

    Args:
        dest: current chunk's tensor, mutated IN PLACE via ``add_``; may be None
            (then nothing is accumulated and inputs pass through to the cache).
        cache_in: carry-over tensor from the previous call, or None.
        new_input: tensor whose leading frames are added into ``dest`` and whose
            trailing frames seed the next cache.
        dim: dimension along which chunks are concatenated (2 = temporal/frames).

    Returns:
        The new carry-over: the unconsumed remainders of ``cache_in`` and
        ``new_input`` concatenated along ``dim``. NOTE(review):
        ``torch_cat_if_needed`` presumably tolerates None entries — confirm
        against its definition in comfy.ldm.modules.diffusionmodules.model.
    """
    if dest is not None:
        if cache_in is not None:
            # Length of the region where the cached tail overlaps dest's head.
            cache_to_dest = min(dest.shape[dim], cache_in.shape[dim])
            lead_in_dest, dest = split2(dest, cache_to_dest, dim=dim)
            lead_in_source, cache_in = split2(cache_in, cache_to_dest, dim=dim)
            # split2 returns views, so this in-place add writes through to the
            # caller's original ``dest`` tensor.
            lead_in_dest.add_(lead_in_source)
        # Fold the leading part of new_input into the remaining (non-overlap)
        # part of dest; frames of new_input beyond dest become the next cache.
        body, new_input = split2(new_input, dest.shape[dim], dim)
        dest.add_(body)
    return torch_cat_if_needed([cache_in, new_input], dim=dim)
38+
1539class Encoder (nn .Module ):
1640 r"""
1741 The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation.
@@ -205,7 +229,7 @@ def __init__(
205229
206230 self .gradient_checkpointing = False
207231
208- def forward (self , sample : torch .FloatTensor ) -> torch .FloatTensor :
232+ def forward_orig (self , sample : torch .FloatTensor ) -> torch .FloatTensor :
209233 r"""The forward method of the `Encoder` class."""
210234
211235 sample = patchify (sample , patch_size_hw = self .patch_size , patch_size_t = 1 )
@@ -254,6 +278,22 @@ def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
254278
255279 return sample
256280
281+ def forward (self , * args , ** kwargs ):
282+ #No encoder support so just flag the end so it doesnt use the cache.
283+ mark_conv3d_ended (self )
284+ try :
285+ return self .forward_orig (* args , ** kwargs )
286+ finally :
287+ tid = threading .get_ident ()
288+ for _ , module in self .named_modules ():
289+ # ComfyUI doesn't thread this kind of stuff today, but just in case
290+ # we key on the thread to make it thread safe.
291+ tid = threading .get_ident ()
292+ if hasattr (module , "temporal_cache_state" ):
293+ module .temporal_cache_state .pop (tid , None )
294+
295+
# Upper bound in bytes (128 MiB) on a single intermediate tensor in the
# decoder's chunked path; larger activations are split along dim 2 (frames).
MAX_CHUNK_SIZE = (128 * 1024 ** 2)
257297
258298class Decoder (nn .Module ):
259299 r"""
@@ -341,18 +381,6 @@ def __init__(
341381 timestep_conditioning = timestep_conditioning ,
342382 spatial_padding_mode = spatial_padding_mode ,
343383 )
344- elif block_name == "attn_res_x" :
345- block = UNetMidBlock3D (
346- dims = dims ,
347- in_channels = input_channel ,
348- num_layers = block_params ["num_layers" ],
349- resnet_groups = norm_num_groups ,
350- norm_layer = norm_layer ,
351- inject_noise = block_params .get ("inject_noise" , False ),
352- timestep_conditioning = timestep_conditioning ,
353- attention_head_dim = block_params ["attention_head_dim" ],
354- spatial_padding_mode = spatial_padding_mode ,
355- )
356384 elif block_name == "res_x_y" :
357385 output_channel = output_channel // block_params .get ("multiplier" , 2 )
358386 block = ResnetBlock3D (
@@ -428,15 +456,17 @@ def __init__(
428456 )
429457 self .last_scale_shift_table = nn .Parameter (torch .empty (2 , output_channel ))
430458
459+
431460 # def forward(self, sample: torch.FloatTensor, target_shape) -> torch.FloatTensor:
432- def forward (
461+ def forward_orig (
433462 self ,
434463 sample : torch .FloatTensor ,
435464 timestep : Optional [torch .Tensor ] = None ,
436465 ) -> torch .FloatTensor :
437466 r"""The forward method of the `Decoder` class."""
438467 batch_size = sample .shape [0 ]
439468
469+ mark_conv3d_ended (self .conv_in )
440470 sample = self .conv_in (sample , causal = self .causal )
441471
442472 checkpoint_fn = (
@@ -445,24 +475,12 @@ def forward(
445475 else lambda x : x
446476 )
447477
448- scaled_timestep = None
478+ timestep_shift_scale = None
449479 if self .timestep_conditioning :
450480 assert (
451481 timestep is not None
452482 ), "should pass timestep with timestep_conditioning=True"
453483 scaled_timestep = timestep * self .timestep_scale_multiplier .to (dtype = sample .dtype , device = sample .device )
454-
455- for up_block in self .up_blocks :
456- if self .timestep_conditioning and isinstance (up_block , UNetMidBlock3D ):
457- sample = checkpoint_fn (up_block )(
458- sample , causal = self .causal , timestep = scaled_timestep
459- )
460- else :
461- sample = checkpoint_fn (up_block )(sample , causal = self .causal )
462-
463- sample = self .conv_norm_out (sample )
464-
465- if self .timestep_conditioning :
466484 embedded_timestep = self .last_time_embedder (
467485 timestep = scaled_timestep .flatten (),
468486 resolution = None ,
@@ -483,16 +501,62 @@ def forward(
483501 embedded_timestep .shape [- 2 ],
484502 embedded_timestep .shape [- 1 ],
485503 )
486- shift , scale = ada_values .unbind (dim = 1 )
487- sample = sample * (1 + scale ) + shift
504+ timestep_shift_scale = ada_values .unbind (dim = 1 )
505+
506+ output = []
507+
508+ def run_up (idx , sample , ended ):
509+ if idx >= len (self .up_blocks ):
510+ sample = self .conv_norm_out (sample )
511+ if timestep_shift_scale is not None :
512+ shift , scale = timestep_shift_scale
513+ sample = sample * (1 + scale ) + shift
514+ sample = self .conv_act (sample )
515+ if ended :
516+ mark_conv3d_ended (self .conv_out )
517+ sample = self .conv_out (sample , causal = self .causal )
518+ if sample is not None and sample .shape [2 ] > 0 :
519+ output .append (sample )
520+ return
521+
522+ up_block = self .up_blocks [idx ]
523+ if (ended ):
524+ mark_conv3d_ended (up_block )
525+ if self .timestep_conditioning and isinstance (up_block , UNetMidBlock3D ):
526+ sample = checkpoint_fn (up_block )(
527+ sample , causal = self .causal , timestep = scaled_timestep
528+ )
529+ else :
530+ sample = checkpoint_fn (up_block )(sample , causal = self .causal )
488531
489- sample = self .conv_act (sample )
490- sample = self .conv_out (sample , causal = self .causal )
532+ if sample is None or sample .shape [2 ] == 0 :
533+ return
534+
535+ total_bytes = sample .numel () * sample .element_size ()
536+ num_chunks = (total_bytes + MAX_CHUNK_SIZE - 1 ) // MAX_CHUNK_SIZE
537+ samples = torch .chunk (sample , chunks = num_chunks , dim = 2 )
538+
539+ for chunk_idx , sample1 in enumerate (samples ):
540+ run_up (idx + 1 , sample1 , ended and chunk_idx == len (samples ) - 1 )
541+
542+ run_up (0 , sample , True )
543+ sample = torch .cat (output , dim = 2 )
491544
492545 sample = unpatchify (sample , patch_size_hw = self .patch_size , patch_size_t = 1 )
493546
494547 return sample
495548
549+ def forward (self , * args , ** kwargs ):
550+ try :
551+ return self .forward_orig (* args , ** kwargs )
552+ finally :
553+ for _ , module in self .named_modules ():
554+ #ComfyUI doesn't thread this kind of stuff today, but just incase
555+ #we key on the thread to make it thread safe.
556+ tid = threading .get_ident ()
557+ if hasattr (module , "temporal_cache_state" ):
558+ module .temporal_cache_state .pop (tid , None )
559+
496560
497561class UNetMidBlock3D (nn .Module ):
498562 """
@@ -663,8 +727,22 @@ def __init__(
663727 )
664728 self .residual = residual
665729 self .out_channels_reduction_factor = out_channels_reduction_factor
730+ self .temporal_cache_state = {}
666731
667732 def forward (self , x , causal : bool = True , timestep : Optional [torch .Tensor ] = None ):
733+ tid = threading .get_ident ()
734+ cached , drop_first_conv , drop_first_res = self .temporal_cache_state .get (tid , (None , True , True ))
735+ y = self .conv (x , causal = causal )
736+ y = rearrange (
737+ y ,
738+ "b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)" ,
739+ p1 = self .stride [0 ],
740+ p2 = self .stride [1 ],
741+ p3 = self .stride [2 ],
742+ )
743+ if self .stride [0 ] == 2 and y .shape [2 ] > 0 and drop_first_conv :
744+ y = y [:, :, 1 :, :, :]
745+ drop_first_conv = False
668746 if self .residual :
669747 # Reshape and duplicate the input to match the output shape
670748 x_in = rearrange (
@@ -676,21 +754,20 @@ def forward(self, x, causal: bool = True, timestep: Optional[torch.Tensor] = Non
676754 )
677755 num_repeat = math .prod (self .stride ) // self .out_channels_reduction_factor
678756 x_in = x_in .repeat (1 , num_repeat , 1 , 1 , 1 )
679- if self .stride [0 ] == 2 :
757+ if self .stride [0 ] == 2 and x_in . shape [ 2 ] > 0 and drop_first_res :
680758 x_in = x_in [:, :, 1 :, :, :]
681- x = self .conv (x , causal = causal )
682- x = rearrange (
683- x ,
684- "b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)" ,
685- p1 = self .stride [0 ],
686- p2 = self .stride [1 ],
687- p3 = self .stride [2 ],
688- )
689- if self .stride [0 ] == 2 :
690- x = x [:, :, 1 :, :, :]
691- if self .residual :
692- x = x + x_in
693- return x
759+ drop_first_res = False
760+
761+ if y .shape [2 ] == 0 :
762+ y = None
763+
764+ cached = add_exchange_cache (y , cached , x_in , dim = 2 )
765+ self .temporal_cache_state [tid ] = (cached , drop_first_conv , drop_first_res )
766+
767+ else :
768+ self .temporal_cache_state [tid ] = (None , drop_first_conv , False )
769+
770+ return y
694771
695772class LayerNorm (nn .Module ):
696773 def __init__ (self , dim , eps , elementwise_affine = True ) -> None :
@@ -807,6 +884,8 @@ def __init__(
807884 torch .randn (4 , in_channels ) / in_channels ** 0.5
808885 )
809886
887+ self .temporal_cache_state = {}
888+
810889 def _feed_spatial_noise (
811890 self , hidden_states : torch .FloatTensor , per_channel_scale : torch .FloatTensor
812891 ) -> torch .FloatTensor :
@@ -880,9 +959,12 @@ def forward(
880959
881960 input_tensor = self .conv_shortcut (input_tensor )
882961
883- output_tensor = input_tensor + hidden_states
962+ tid = threading .get_ident ()
963+ cached = self .temporal_cache_state .get (tid , None )
964+ cached = add_exchange_cache (hidden_states , cached , input_tensor , dim = 2 )
965+ self .temporal_cache_state [tid ] = cached
884966
885- return output_tensor
967+ return hidden_states
886968
887969
888970def patchify (x , patch_size_hw , patch_size_t = 1 ):