Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 33 additions & 9 deletions comfy/ldm/lightricks/vae/causal_video_autoencoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,6 +473,17 @@ def __init__(

self.gradient_checkpointing = False

# Precompute output scale factors: (channels, (t_scale, h_scale, w_scale), t_offset)
ts, hs, ws, to = 1, 1, 1, 0
for block in self.up_blocks:
if isinstance(block, DepthToSpaceUpsample):
ts *= block.stride[0]
hs *= block.stride[1]
ws *= block.stride[2]
if block.stride[0] > 1:
to = to * block.stride[0] + 1
self._output_scale = (out_channels // (patch_size ** 2), (ts, hs * patch_size, ws * patch_size), to)

self.timestep_conditioning = timestep_conditioning

if timestep_conditioning:
Expand All @@ -494,11 +505,15 @@ def __init__(
)


# def forward(self, sample: torch.FloatTensor, target_shape) -> torch.FloatTensor:
def decode_output_shape(self, input_shape):
c, (ts, hs, ws), to = self._output_scale
return (input_shape[0], c, input_shape[2] * ts - to, input_shape[3] * hs, input_shape[4] * ws)

def forward_orig(
self,
sample: torch.FloatTensor,
timestep: Optional[torch.Tensor] = None,
output_buffer: Optional[torch.Tensor] = None,
) -> torch.FloatTensor:
r"""The forward method of the `Decoder` class."""
batch_size = sample.shape[0]
Expand Down Expand Up @@ -540,7 +555,13 @@ def forward_orig(
)
timestep_shift_scale = ada_values.unbind(dim=1)

output = []
if output_buffer is None:
output_buffer = torch.empty(
self.decode_output_shape(sample.shape),
dtype=sample.dtype, device=comfy.model_management.intermediate_device(),
)
output_offset = [0]

max_chunk_size = get_max_chunk_size(sample.device)

def run_up(idx, sample_ref, ended):
Expand All @@ -556,7 +577,10 @@ def run_up(idx, sample_ref, ended):
mark_conv3d_ended(self.conv_out)
sample = self.conv_out(sample, causal=self.causal)
if sample is not None and sample.shape[2] > 0:
output.append(sample.to(comfy.model_management.intermediate_device()))
sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
t = sample.shape[2]
output_buffer[:, :, output_offset[0]:output_offset[0] + t].copy_(sample)
output_offset[0] += t
return

up_block = self.up_blocks[idx]
Expand Down Expand Up @@ -588,11 +612,8 @@ def run_up(idx, sample_ref, ended):
run_up(idx + 1, [sample1], ended and chunk_idx == len(samples) - 1)

run_up(0, [sample], True)
sample = torch.cat(output, dim=2)

sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)

return sample
return output_buffer

def forward(self, *args, **kwargs):
try:
Expand Down Expand Up @@ -1226,7 +1247,10 @@ def encode(self, x):
means, logvar = torch.chunk(self.encoder(x), 2, dim=1)
return self.per_channel_statistics.normalize(means)

def decode(self, x):
def decode_output_shape(self, input_shape):
return self.decoder.decode_output_shape(input_shape)

def decode(self, x, output_buffer=None):
if self.timestep_conditioning: #TODO: seed
x = torch.randn_like(x) * self.decode_noise_scale + (1.0 - self.decode_noise_scale) * x
return self.decoder(self.per_channel_statistics.un_normalize(x), timestep=self.decode_timestep)
return self.decoder(self.per_channel_statistics.un_normalize(x), timestep=self.decode_timestep, output_buffer=output_buffer)
19 changes: 15 additions & 4 deletions comfy/sd.py
Original file line number Diff line number Diff line change
Expand Up @@ -951,12 +951,23 @@ def decode(self, samples_in, vae_options={}):
batch_number = int(free_memory / memory_used)
batch_number = max(1, batch_number)

# Pre-allocate output for VAEs that support direct buffer writes
preallocated = False
if hasattr(self.first_stage_model, 'decode_output_shape'):
pixel_samples = torch.empty(self.first_stage_model.decode_output_shape(samples_in.shape), device=self.output_device, dtype=self.vae_output_dtype())
preallocated = True

for x in range(0, samples_in.shape[0], batch_number):
samples = samples_in[x:x + batch_number].to(device=self.device, dtype=self.vae_dtype)
out = self.process_output(self.first_stage_model.decode(samples, **vae_options).to(device=self.output_device, dtype=self.vae_output_dtype(), copy=True))
if pixel_samples is None:
pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
pixel_samples[x:x+batch_number] = out
if preallocated:
self.first_stage_model.decode(samples, output_buffer=pixel_samples[x:x+batch_number], **vae_options)
else:
out = self.first_stage_model.decode(samples, **vae_options).to(device=self.output_device, dtype=self.vae_output_dtype(), copy=True)
if pixel_samples is None:
pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
pixel_samples[x:x+batch_number].copy_(out)
del out
self.process_output(pixel_samples[x:x+batch_number])
except Exception as e:
model_management.raise_non_oom(e)
logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
Expand Down
8 changes: 4 additions & 4 deletions comfy/sd1_clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def encode_token_weights(self, token_weight_pairs):
out, pooled = o[:2]

if pooled is not None:
first_pooled = pooled[0:1].to(device=model_management.intermediate_device(), dtype=model_management.intermediate_dtype())
first_pooled = pooled[0:1].to(device=model_management.intermediate_device())
else:
first_pooled = pooled

Expand All @@ -63,16 +63,16 @@ def encode_token_weights(self, token_weight_pairs):
output.append(z)

if (len(output) == 0):
r = (out[-1:].to(device=model_management.intermediate_device(), dtype=model_management.intermediate_dtype()), first_pooled)
r = (out[-1:].to(device=model_management.intermediate_device()), first_pooled)
else:
r = (torch.cat(output, dim=-2).to(device=model_management.intermediate_device(), dtype=model_management.intermediate_dtype()), first_pooled)
r = (torch.cat(output, dim=-2).to(device=model_management.intermediate_device()), first_pooled)

if len(o) > 2:
extra = {}
for k in o[2]:
v = o[2][k]
if k == "attention_mask":
v = v[:sections].flatten().unsqueeze(dim=0).to(device=model_management.intermediate_device(), dtype=model_management.intermediate_dtype())
v = v[:sections].flatten().unsqueeze(dim=0).to(device=model_management.intermediate_device())
extra[k] = v

r = r + (extra,)
Expand Down
Loading