Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions comfy/ldm/lightricks/vae/causal_conv3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,11 @@ def __init__(
self.in_channels = in_channels
self.out_channels = out_channels

if isinstance(stride, int):
self.time_stride = stride
else:
self.time_stride = stride[0]

kernel_size = (kernel_size, kernel_size, kernel_size)
self.time_kernel_size = kernel_size[0]

Expand Down Expand Up @@ -58,18 +63,23 @@ def forward(self, x, causal: bool = True):
pieces = [ cached, x ]
if is_end and not causal:
pieces.append(x[:, :, -1:, :, :].repeat((1, 1, (self.time_kernel_size - 1) // 2, 1, 1)))
input_length = sum([piece.shape[2] for piece in pieces])
cache_length = (self.time_kernel_size - self.time_stride) + ((input_length - self.time_kernel_size) % self.time_stride)

needs_caching = not is_end
if needs_caching and x.shape[2] >= self.time_kernel_size - 1:
if needs_caching and cache_length == 0:
self.temporal_cache_state[tid] = (x[:, :, :0, :, :], False)
needs_caching = False
if needs_caching and x.shape[2] >= cache_length:
needs_caching = False
self.temporal_cache_state[tid] = (x[:, :, -(self.time_kernel_size - 1):, :, :], False)
self.temporal_cache_state[tid] = (x[:, :, -cache_length:, :, :], False)

x = torch.cat(pieces, dim=2)
del pieces
del cached

if needs_caching:
self.temporal_cache_state[tid] = (x[:, :, -(self.time_kernel_size - 1):, :, :], False)
self.temporal_cache_state[tid] = (x[:, :, -cache_length:, :, :], False)
elif is_end:
self.temporal_cache_state[tid] = (None, True)

Expand Down
91 changes: 71 additions & 20 deletions comfy/ldm/lightricks/vae/causal_video_autoencoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,10 +233,7 @@ def __init__(

self.gradient_checkpointing = False

def forward_orig(self, sample: torch.FloatTensor) -> torch.FloatTensor:
r"""The forward method of the `Encoder` class."""

sample = patchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
def _forward_chunk(self, sample: torch.FloatTensor) -> Optional[torch.FloatTensor]:
sample = self.conv_in(sample)

checkpoint_fn = (
Expand All @@ -247,10 +244,14 @@ def forward_orig(self, sample: torch.FloatTensor) -> torch.FloatTensor:

for down_block in self.down_blocks:
sample = checkpoint_fn(down_block)(sample)
if sample is None or sample.shape[2] == 0:
return None

sample = self.conv_norm_out(sample)
sample = self.conv_act(sample)
sample = self.conv_out(sample)
if sample is None or sample.shape[2] == 0:
return None

if self.latent_log_var == "uniform":
last_channel = sample[:, -1:, ...]
Expand Down Expand Up @@ -282,9 +283,35 @@ def forward_orig(self, sample: torch.FloatTensor) -> torch.FloatTensor:

return sample

def forward_orig(self, sample: torch.FloatTensor, device=None) -> torch.FloatTensor:
    r"""Run the `Encoder` over `sample` in chunks along dim 2.

    The input is split into its first frame (index 0 of dim 2) followed by
    groups of `chunk_t` frames. Each piece is patchified, moved to `device`
    and passed through `_forward_chunk`; the non-``None`` results are joined
    along dim 2 with `torch_cat_if_needed`.

    Args:
        sample: Input tensor. It is indexed with five indices and chunked on
            dim 2.
        device: Target passed to `.to(device=...)` for every chunk. When
            ``None``, `sample.device` is used to query the chunk budget.

    Returns:
        Whatever `torch_cat_if_needed` returns for the collected chunk
        outputs.
    """

    budget_device = device if device is not None else sample.device
    # encoder is more memory-efficient than decoder
    max_chunk_size = get_max_chunk_size(budget_device) * 2

    # Estimate the byte size of one frame after conv_in by scaling the raw
    # frame size with the conv_in channel ratio.
    first_frame = sample[:, :, :1, :, :]
    channel_ratio = self.conv_in.out_channels / self.conv_in.in_channels
    frame_size = int(first_frame.numel() * sample.element_size() * channel_ratio)

    samples = [first_frame]
    if sample.shape[2] > 1:
        # Snap the frames-per-chunk count to 2, 4, or a multiple of 8.
        chunk_t = max(2, max_chunk_size // frame_size)
        if chunk_t >= 8:
            chunk_t = (chunk_t // 8) * 8
        elif chunk_t >= 4:
            chunk_t = 4
        else:
            chunk_t = 2
        samples.extend(torch.split(sample[:, :, 1:, :, :], chunk_t, dim=2))

    outputs = []
    last_idx = len(samples) - 1
    for idx, piece in enumerate(samples):
        if idx == last_idx:
            # Flag the end before the final chunk is processed.
            mark_conv3d_ended(self)
        piece = patchify(piece, patch_size_hw=self.patch_size, patch_size_t=1).to(device=device)
        encoded = self._forward_chunk(piece)
        if encoded is None:
            # _forward_chunk produced nothing for this piece; skip it.
            continue
        outputs.append(encoded)

    return torch_cat_if_needed(outputs, dim=2)

def forward(self, *args, **kwargs):
#No encoder support so just flag the end so it doesnt use the cache.
mark_conv3d_ended(self)
try:
return self.forward_orig(*args, **kwargs)
finally:
Expand Down Expand Up @@ -737,12 +764,25 @@ def __init__(self, dims, in_channels, out_channels, stride, spatial_padding_mode
causal=True,
spatial_padding_mode=spatial_padding_mode,
)
self.temporal_cache_state = {}

def forward(self, x, causal: bool = True):
if self.stride[0] == 2:
tid = threading.get_ident()
cached, pad_first, cached_x, cached_input = self.temporal_cache_state.get(tid, (None, True, None, None))
if cached_input is not None:
x = torch_cat_if_needed([cached_input, x], dim=2)
cached_input = None

if self.stride[0] == 2 and pad_first:
x = torch.cat(
[x[:, :, :1, :, :], x], dim=2
) # duplicate first frames for padding
pad_first = False

if x.shape[2] < self.stride[0]:
cached_input = x
self.temporal_cache_state[tid] = (cached, pad_first, cached_x, cached_input)
return None

# skip connection
x_in = rearrange(
Expand All @@ -757,15 +797,26 @@ def forward(self, x, causal: bool = True):

# conv
x = self.conv(x, causal=causal)
x = rearrange(
x,
"b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w",
p1=self.stride[0],
p2=self.stride[1],
p3=self.stride[2],
)
if self.stride[0] == 2 and x.shape[2] == 1:
if cached_x is not None:
x = torch_cat_if_needed([cached_x, x], dim=2)
cached_x = None
else:
cached_x = x
x = None

x = x + x_in
if x is not None:
x = rearrange(
x,
"b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w",
p1=self.stride[0],
p2=self.stride[1],
p3=self.stride[2],
)

cached = add_exchange_cache(x, cached, x_in, dim=2)

self.temporal_cache_state[tid] = (cached, pad_first, cached_x, cached_input)

return x

Expand Down Expand Up @@ -1098,6 +1149,8 @@ def normalize(self, x):
return (x - self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x)) / self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x)

class VideoVAE(nn.Module):
comfy_has_chunked_io = True

def __init__(self, version=0, config=None):
super().__init__()

Expand Down Expand Up @@ -1240,11 +1293,9 @@ def get_default_config(self, version):
}
return config

def encode(self, x):
frames_count = x.shape[2]
if ((frames_count - 1) % 8) != 0:
raise ValueError("Invalid number of frames: Encode input must have 1 + 8 * x frames (e.g., 1, 9, 17, ...). Please check your input.")
means, logvar = torch.chunk(self.encoder(x), 2, dim=1)
def encode(self, x, device=None):
x = x[:, :, :max(1, 1 + ((x.shape[2] - 1) // 8) * 8), :, :]
means, logvar = torch.chunk(self.encoder(x, device=device), 2, dim=1)
return self.per_channel_statistics.normalize(means)

def decode_output_shape(self, input_shape):
Expand Down
5 changes: 4 additions & 1 deletion comfy/memory_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,10 @@ def read_tensor_file_slice_into(tensor, destination):
if (destination.device.type != "cpu"
or file_obj is None
or threading.get_ident() != info.thread_id
or destination.numel() * destination.element_size() < info.size):
or destination.numel() * destination.element_size() < info.size
or tensor.numel() * tensor.element_size() != info.size
or tensor.storage_offset() != 0
or not tensor.is_contiguous()):
return False

if info.size == 0:
Expand Down
2 changes: 1 addition & 1 deletion comfy/model_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -1003,7 +1003,7 @@ def text_encoder_offload_device():
def text_encoder_device():
if args.gpu_only:
return get_torch_device()
elif vram_state in (VRAMState.HIGH_VRAM, VRAMState.NORMAL_VRAM, VRAMState.SHARED) or comfy.memory_management.aimdo_enabled:
elif vram_state in (VRAMState.HIGH_VRAM, VRAMState.NORMAL_VRAM) or comfy.memory_management.aimdo_enabled:
if should_use_fp16(prioritize_performance=False):
return get_torch_device()
else:
Expand Down
11 changes: 8 additions & 3 deletions comfy/sd.py
Original file line number Diff line number Diff line change
Expand Up @@ -953,7 +953,7 @@ def decode(self, samples_in, vae_options={}):

# Pre-allocate output for VAEs that support direct buffer writes
preallocated = False
if hasattr(self.first_stage_model, 'decode_output_shape'):
if getattr(self.first_stage_model, 'comfy_has_chunked_io', False):
pixel_samples = torch.empty(self.first_stage_model.decode_output_shape(samples_in.shape), device=self.output_device, dtype=self.vae_output_dtype())
preallocated = True

Expand Down Expand Up @@ -1038,8 +1038,13 @@ def encode(self, pixel_samples):
batch_number = max(1, batch_number)
samples = None
for x in range(0, pixel_samples.shape[0], batch_number):
pixels_in = self.process_input(pixel_samples[x:x + batch_number]).to(self.vae_dtype).to(self.device)
out = self.first_stage_model.encode(pixels_in).to(self.output_device).to(dtype=self.vae_output_dtype())
pixels_in = self.process_input(pixel_samples[x:x + batch_number]).to(self.vae_dtype)
if getattr(self.first_stage_model, 'comfy_has_chunked_io', False):
out = self.first_stage_model.encode(pixels_in, device=self.device)
else:
pixels_in = pixels_in.to(self.device)
out = self.first_stage_model.encode(pixels_in)
out = out.to(self.output_device).to(dtype=self.vae_output_dtype())
if samples is None:
samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
samples[x:x + batch_number] = out
Expand Down
8 changes: 4 additions & 4 deletions comfy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1135,8 +1135,8 @@ def mult_list_upscale(a):
pbar.update(1)
continue

out = torch.zeros([s.shape[0], out_channels] + mult_list_upscale(s.shape[2:]), device=output_device)
out_div = torch.zeros([s.shape[0], out_channels] + mult_list_upscale(s.shape[2:]), device=output_device)
out = output[b:b+1].zero_()
out_div = torch.zeros([s.shape[0], 1] + mult_list_upscale(s.shape[2:]), device=output_device)

positions = [range(0, s.shape[d+2] - overlap[d], tile[d] - overlap[d]) if s.shape[d+2] > tile[d] else [0] for d in range(dims)]

Expand All @@ -1151,7 +1151,7 @@ def mult_list_upscale(a):
upscaled.append(round(get_pos(d, pos)))

ps = function(s_in).to(output_device)
mask = torch.ones_like(ps)
mask = torch.ones([1, 1] + list(ps.shape[2:]), device=output_device)

for d in range(2, dims + 2):
feather = round(get_scale(d - 2, overlap[d - 2]))
Expand All @@ -1174,7 +1174,7 @@ def mult_list_upscale(a):
if pbar is not None:
pbar.update(1)

output[b:b+1] = out/out_div
out.div_(out_div)
return output

def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None):
Expand Down
Loading