Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 53 additions & 1 deletion comfy/context_windows.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,50 @@ def init_callbacks(self):
return {}


def slice_cond(cond_value, window: "IndexListContextWindow", x_in: torch.Tensor, device, temporal_dim: int, temporal_scale: int=1, temporal_offset: int=0, retain_index_list: "list[int] | None"=None):
    """Slice a tensor-backed cond along its temporal axis for one context window.

    Args:
        cond_value: Object expected to expose a tensor as ``.cond`` and a
            ``_copy_with(tensor)`` method used to build the result.
        window: Context window; ``index_list``, ``dim`` and ``get_tensor`` are used.
        x_in: Latent tensor whose size along ``window.dim`` is the reference length.
        device: Device the sliced tensor is moved to.
        temporal_dim: Axis of ``cond_value.cond`` that is sliced.
        temporal_scale: Number of cond positions per window index. When it is 1,
            the cond length must equal the latent length minus ``temporal_offset``.
        temporal_offset: Number of leading window entries that have no
            corresponding conditioning (e.g. reference frames).
        retain_index_list: Forwarded to ``window.get_tensor`` on the plain
            (scale 1, offset 0) path. ``None`` is treated as an empty list.

    Returns:
        ``cond_value._copy_with(sliced)``, or ``None`` when the cond is not
        tensor-backed, ``temporal_dim`` is out of range, the sizes do not line
        up, or no index survives filtering. ``None`` lets the caller fall
        through to its default handling.
    """
    # A None sentinel replaces the former shared mutable default ([]).
    if retain_index_list is None:
        retain_index_list = []

    if not (hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor)):
        return None
    cond_tensor = cond_value.cond
    if temporal_dim >= cond_tensor.ndim:
        return None

    cond_size = cond_tensor.size(temporal_dim)

    if temporal_scale == 1:
        # Only an unscaled cond can be validated against the latent length.
        expected_size = x_in.size(window.dim) - temporal_offset
        if cond_size != expected_size:
            return None
        if temporal_offset == 0:
            # Plain case: the window's own slicing logic applies directly.
            sliced = window.get_tensor(cond_tensor, device, dim=temporal_dim, retain_index_list=retain_index_list)
            return cond_value._copy_with(sliced)

    # skip leading latent positions that have no corresponding conditioning (e.g. reference frames)
    if temporal_offset > 0:
        indices = [
            i - temporal_offset
            for i in window.index_list[temporal_offset:]
            if i >= temporal_offset
        ]
    else:
        indices = list(window.index_list)

    if not indices:
        return None

    if temporal_scale > 1:
        # Each window index expands to `temporal_scale` consecutive cond
        # positions; positions at or beyond the end of the cond are dropped.
        scaled = []
        for i in indices:
            start = i * temporal_scale
            scaled.extend(range(start, min(start + temporal_scale, cond_size)))
        indices = scaled
        if not indices:
            return None

    # Full slices on every axis before temporal_dim, then the index list on it.
    idx = tuple([slice(None)] * temporal_dim + [indices])
    sliced = cond_tensor[idx].to(device)
    return cond_value._copy_with(sliced)


@dataclass
class ContextSchedule:
name: str
Expand Down Expand Up @@ -177,10 +221,17 @@ def get_resized_cond(self, cond_in: list[dict], x_in: torch.Tensor, window: Inde
new_cond_item[cond_key] = result
handled = True
break
if not handled and self._model is not None:
result = self._model.resize_cond_for_context_window(
cond_key, cond_value, window, x_in, device,
retain_index_list=self.cond_retain_index_list)
if result is not None:
new_cond_item[cond_key] = result
handled = True
if handled:
continue
if isinstance(cond_value, torch.Tensor):
if (self.dim < cond_value.ndim and cond_value(self.dim) == x_in.size(self.dim)) or \
if (self.dim < cond_value.ndim and cond_value.size(self.dim) == x_in.size(self.dim)) or \
(cond_value.ndim < self.dim and cond_value.size(0) == x_in.size(self.dim)):
new_cond_item[cond_key] = window.get_tensor(cond_value, device)
# Handle audio_embed (temporal dim is 1)
Expand Down Expand Up @@ -224,6 +275,7 @@ def get_context_windows(self, model: BaseModel, x_in: torch.Tensor, model_option
return context_windows

def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
self._model = model
self.set_step(timestep, model_options)
context_windows = self.get_context_windows(model, x_in, model_options)
enumerated_context_windows = list(enumerate(context_windows))
Expand Down
97 changes: 49 additions & 48 deletions comfy/ldm/lightricks/vae/causal_video_autoencoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -536,6 +536,53 @@ def decode_output_shape(self, input_shape):
c, (ts, hs, ws), to = self._output_scale
return (input_shape[0], c, input_shape[2] * ts - to, input_shape[3] * hs, input_shape[4] * ws)

def run_up(self, idx, sample_ref, ended, timestep_shift_scale, scaled_timestep, checkpoint_fn, output_buffer, output_offset, max_chunk_size):
    """Run up-block ``idx`` and, recursively, everything after it.

    Args:
        idx: Index into ``self.up_blocks``; once past the last block the
            output head (norm, optional shift/scale, activation, conv_out)
            runs instead.
        sample_ref: One-element list holding the input tensor. The slot is
            cleared immediately so this frame holds the only reference.
        ended: Truthy when this call carries the last piece of the stream;
            in that case the block (or ``conv_out``) is passed to
            ``mark_conv3d_ended`` before it runs.
        timestep_shift_scale: ``None`` or a ``(shift, scale)`` pair applied
            after ``conv_norm_out``.
        scaled_timestep: Passed as ``timestep`` to ``UNetMidBlock3D`` blocks
            when ``self.timestep_conditioning`` is set.
        checkpoint_fn: Called with a block; returns the callable actually run.
        output_buffer: Destination tensor, written along dim 2.
        output_offset: One-element list used as a mutable write cursor into
            ``output_buffer``.
        max_chunk_size: Byte threshold; larger intermediate results are
            split along dim 2 before being passed on.
    """
    # Take ownership of the tensor and empty the caller's slot in one step.
    hidden, sample_ref[0] = sample_ref[0], None

    if idx >= len(self.up_blocks):
        # Output head.
        hidden = self.conv_norm_out(hidden)
        if timestep_shift_scale is not None:
            shift, scale = timestep_shift_scale
            hidden = hidden * (1 + scale) + shift
        hidden = self.conv_act(hidden)
        if ended:
            mark_conv3d_ended(self.conv_out)
        hidden = self.conv_out(hidden, causal=self.causal)
        if hidden is None or hidden.shape[2] <= 0:
            return
        hidden = unpatchify(hidden, patch_size_hw=self.patch_size, patch_size_t=1)
        start = output_offset[0]
        stop = start + hidden.shape[2]
        output_buffer[:, :, start:stop].copy_(hidden)
        output_offset[0] = stop
        return

    block = self.up_blocks[idx]
    if ended:
        mark_conv3d_ended(block)

    block_kwargs = {"causal": self.causal}
    if self.timestep_conditioning and isinstance(block, UNetMidBlock3D):
        block_kwargs["timestep"] = scaled_timestep
    hidden = checkpoint_fn(block)(hidden, **block_kwargs)

    if hidden is None or hidden.shape[2] == 0:
        return

    size_in_bytes = hidden.numel() * hidden.element_size()
    num_chunks = (size_in_bytes + max_chunk_size - 1) // max_chunk_size

    if num_chunks != 1:
        pieces = torch.chunk(hidden, chunks=num_chunks, dim=2)
        last = len(pieces) - 1
        for position, piece in enumerate(pieces):
            # Only the final piece of an ended stream is itself "ended".
            self.run_up(idx + 1, [piece], ended and position == last, timestep_shift_scale, scaled_timestep, checkpoint_fn, output_buffer, output_offset, max_chunk_size)
        return

    # when we are not chunking, drop our own reference so the callee can free
    # the tensor as soon as it is done with it
    handoff = [hidden]
    del hidden
    self.run_up(idx + 1, handoff, ended, timestep_shift_scale, scaled_timestep, checkpoint_fn, output_buffer, output_offset, max_chunk_size)

def forward_orig(
self,
sample: torch.FloatTensor,
Expand All @@ -555,6 +602,7 @@ def forward_orig(
)

timestep_shift_scale = None
scaled_timestep = None
if self.timestep_conditioning:
assert (
timestep is not None
Expand Down Expand Up @@ -591,54 +639,7 @@ def forward_orig(

max_chunk_size = get_max_chunk_size(sample.device)

def run_up(idx, sample_ref, ended):
sample = sample_ref[0]
sample_ref[0] = None
if idx >= len(self.up_blocks):
sample = self.conv_norm_out(sample)
if timestep_shift_scale is not None:
shift, scale = timestep_shift_scale
sample = sample * (1 + scale) + shift
sample = self.conv_act(sample)
if ended:
mark_conv3d_ended(self.conv_out)
sample = self.conv_out(sample, causal=self.causal)
if sample is not None and sample.shape[2] > 0:
sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
t = sample.shape[2]
output_buffer[:, :, output_offset[0]:output_offset[0] + t].copy_(sample)
output_offset[0] += t
return

up_block = self.up_blocks[idx]
if (ended):
mark_conv3d_ended(up_block)
if self.timestep_conditioning and isinstance(up_block, UNetMidBlock3D):
sample = checkpoint_fn(up_block)(
sample, causal=self.causal, timestep=scaled_timestep
)
else:
sample = checkpoint_fn(up_block)(sample, causal=self.causal)

if sample is None or sample.shape[2] == 0:
return

total_bytes = sample.numel() * sample.element_size()
num_chunks = (total_bytes + max_chunk_size - 1) // max_chunk_size

if num_chunks == 1:
# when we are not chunking, detach our x so the callee can free it as soon as they are done
next_sample_ref = [sample]
del sample
run_up(idx + 1, next_sample_ref, ended)
return
else:
samples = torch.chunk(sample, chunks=num_chunks, dim=2)

for chunk_idx, sample1 in enumerate(samples):
run_up(idx + 1, [sample1], ended and chunk_idx == len(samples) - 1)

run_up(0, [sample], True)
self.run_up(0, [sample], True, timestep_shift_scale, scaled_timestep, checkpoint_fn, output_buffer, output_offset, max_chunk_size)

return output_buffer

Expand Down
74 changes: 38 additions & 36 deletions comfy/ldm/wan/vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,43 @@ def __init__(self,
RMS_norm(out_dim, images=False), nn.SiLU(),
CausalConv3d(out_dim, output_channels, 3, padding=1))

def run_up(self, layer_idx, x_ref, feat_cache, feat_idx, out_chunks):
    """Run upsample layer ``layer_idx`` and, recursively, everything after it.

    Args:
        layer_idx: Index into ``self.upsamples``; once past the last layer
            the layers in ``self.head`` run and the result is appended to
            ``out_chunks``.
        x_ref: One-element list holding the input tensor. The slot is
            cleared immediately so this frame holds the only reference.
        feat_cache: ``None``, or the cache passed to layers and indexed by
            ``feat_idx[0]`` for ``CausalConv3d`` head layers.
        feat_idx: One-element list used as a mutable cache cursor.
        out_chunks: List that collects the finished output tensors.
    """
    # Take ownership of the tensor and empty the caller's slot in one step.
    hidden, x_ref[0] = x_ref[0], None

    if layer_idx < len(self.upsamples):
        stage = self.upsamples[layer_idx]

        if isinstance(stage, Resample) and stage.mode == 'upsample3d' and hidden.shape[2] > 1:
            # Feed the 3D upsampler one frame at a time. Each frame re-enters
            # this same layer (now with a single frame, so this branch is not
            # taken again) and gets its own copy of the cache cursor.
            for t in range(hidden.shape[2]):
                self.run_up(
                    layer_idx,
                    [hidden[:, :, t:t + 1, :, :]],
                    feat_cache,
                    feat_idx.copy(),
                    out_chunks,
                )
            del hidden
            return

        if feat_cache is None:
            hidden = stage(hidden)
        else:
            hidden = stage(hidden, feat_cache, feat_idx)

        # Drop our own reference before recursing so the callee controls the
        # tensor's lifetime.
        handoff = [hidden]
        del hidden
        self.run_up(layer_idx + 1, handoff, feat_cache, feat_idx, out_chunks)
        return

    # Past the last upsample layer: apply the head.
    for head_layer in self.head:
        if feat_cache is not None and isinstance(head_layer, CausalConv3d):
            # Remember the trailing frames of the input before the conv
            # consumes the previously cached ones.
            tail = hidden[:, :, -CACHE_T:, :, :]
            hidden = head_layer(hidden, feat_cache[feat_idx[0]])
            feat_cache[feat_idx[0]] = tail
            feat_idx[0] += 1
        else:
            hidden = head_layer(hidden)
    out_chunks.append(hidden)

def forward(self, x, feat_cache=None, feat_idx=[0]):
## conv1
if feat_cache is not None:
Expand All @@ -380,42 +417,7 @@ def forward(self, x, feat_cache=None, feat_idx=[0]):

out_chunks = []

def run_up(layer_idx, x_ref, feat_idx):
x = x_ref[0]
x_ref[0] = None
if layer_idx >= len(self.upsamples):
for layer in self.head:
if isinstance(layer, CausalConv3d) and feat_cache is not None:
cache_x = x[:, :, -CACHE_T:, :, :]
x = layer(x, feat_cache[feat_idx[0]])
feat_cache[feat_idx[0]] = cache_x
feat_idx[0] += 1
else:
x = layer(x)
out_chunks.append(x)
return

layer = self.upsamples[layer_idx]
if isinstance(layer, Resample) and layer.mode == 'upsample3d' and x.shape[2] > 1:
for frame_idx in range(x.shape[2]):
run_up(
layer_idx,
[x[:, :, frame_idx:frame_idx + 1, :, :]],
feat_idx.copy(),
)
del x
return

if feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)

next_x_ref = [x]
del x
run_up(layer_idx + 1, next_x_ref, feat_idx)

run_up(0, [x], feat_idx)
self.run_up(0, [x], feat_cache, feat_idx, out_chunks)
return out_chunks


Expand Down
32 changes: 32 additions & 0 deletions comfy/model_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,12 @@ def concat_cond(self, **kwargs):
return data
return None

def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
    """Hook for model-specific slicing of a cond to a context window.

    This base implementation handles nothing: it always returns None, which
    signals the caller to fall through to its default handling. Subclasses
    override it and return a sliced cond object for the keys they recognise;
    comfy.context_windows.slice_cond() covers the common cases.
    """
    return None

def extra_conds(self, **kwargs):
out = {}
concat_cond = self.concat_cond(**kwargs)
Expand Down Expand Up @@ -1375,6 +1381,12 @@ def extra_conds(self, **kwargs):
out['vace_strength'] = comfy.conds.CONDConstant(vace_strength)
return out

def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
    """Slice ``vace_context`` along dim 3 for a context window.

    Every other key is delegated to the parent implementation.
    """
    if cond_key != "vace_context":
        return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list)

    # Imported at call time rather than at module level.
    import comfy.context_windows
    return comfy.context_windows.slice_cond(
        cond_value, window, x_in, device,
        temporal_dim=3,
        retain_index_list=retain_index_list,
    )

class WAN21_Camera(WAN21):
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.CameraWanModel)
Expand Down Expand Up @@ -1427,6 +1439,12 @@ def extra_conds(self, **kwargs):

return out

def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
    """Slice ``audio_embed`` along dim 1 for a context window.

    ``retain_index_list`` is not forwarded for ``audio_embed``; it is only
    passed on when delegating other keys to the parent implementation.
    """
    if cond_key != "audio_embed":
        return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list)

    # Imported at call time rather than at module level.
    import comfy.context_windows
    return comfy.context_windows.slice_cond(
        cond_value, window, x_in, device,
        temporal_dim=1,
    )

class WAN22_Animate(WAN21):
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model_animate.AnimateWanModel)
Expand All @@ -1444,6 +1462,14 @@ def extra_conds(self, **kwargs):
out['pose_latents'] = comfy.conds.CONDRegular(self.process_latent_in(pose_latents))
return out

def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
    """Slice the Animate-specific conds for a context window.

    ``face_pixel_values`` is sliced along dim 2 with a temporal scale of 4 and
    an offset of 1; ``pose_latents`` along dim 2 with an offset of 1. Neither
    forwards ``retain_index_list``. Any other key is delegated to the parent
    implementation.
    """
    # Imported at call time rather than at module level.
    import comfy.context_windows

    slice_options = {
        "face_pixel_values": {"temporal_dim": 2, "temporal_scale": 4, "temporal_offset": 1},
        "pose_latents": {"temporal_dim": 2, "temporal_offset": 1},
    }.get(cond_key)

    if slice_options is None:
        return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list)
    return comfy.context_windows.slice_cond(cond_value, window, x_in, device, **slice_options)

class WAN22_S2V(WAN21):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel_S2V)
Expand Down Expand Up @@ -1480,6 +1506,12 @@ def extra_conds_shapes(self, **kwargs):
out['reference_motion'] = reference_motion.shape
return out

def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
    """Slice ``audio_embed`` along dim 1 for a context window.

    ``retain_index_list`` is not forwarded for ``audio_embed``; it is only
    passed on when delegating other keys to the parent implementation.
    """
    if cond_key != "audio_embed":
        return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list)

    # Imported at call time rather than at module level.
    import comfy.context_windows
    return comfy.context_windows.slice_cond(
        cond_value, window, x_in, device,
        temporal_dim=1,
    )

class WAN22(WAN21):
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel)
Expand Down
2 changes: 2 additions & 0 deletions comfy/sd.py
Original file line number Diff line number Diff line change
Expand Up @@ -978,6 +978,7 @@ def decode(self, samples_in, vae_options={}):
do_tile = True

if do_tile:
comfy.model_management.soft_empty_cache()
dims = samples_in.ndim - 2
if dims == 1 or self.extra_1d_channel is not None:
pixel_samples = self.decode_tiled_1d(samples_in)
Expand Down Expand Up @@ -1059,6 +1060,7 @@ def encode(self, pixel_samples):
do_tile = True

if do_tile:
comfy.model_management.soft_empty_cache()
if self.latent_dim == 3:
tile = 256
overlap = tile // 4
Expand Down
Loading
Loading