Skip to content

Commit 589228e

Browse files
authored
Add slice_cond and per-model context window cond resizing (Comfy-Org#12645)
* Add slice_cond and per-model context window cond resizing * Fix cond_value.size() call in context window cond resizing * Expose additional advanced inputs for ContextWindowsManualNode Necessary for WanAnimate context windows workflow, which needs cond_retain_index_list = 0 to work properly with its reference input. ---------
1 parent e4455fd commit 589228e

File tree

3 files changed

+87
-3
lines changed

3 files changed

+87
-3
lines changed

comfy/context_windows.py

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,50 @@ def init_callbacks(self):
9393
return {}
9494

9595

96+
def slice_cond(cond_value, window: IndexListContextWindow, x_in: torch.Tensor, device, temporal_dim: int, temporal_scale: int=1, temporal_offset: int=0, retain_index_list: list[int]=[]):
97+
if not (hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor)):
98+
return None
99+
cond_tensor = cond_value.cond
100+
if temporal_dim >= cond_tensor.ndim:
101+
return None
102+
103+
cond_size = cond_tensor.size(temporal_dim)
104+
105+
if temporal_scale == 1:
106+
expected_size = x_in.size(window.dim) - temporal_offset
107+
if cond_size != expected_size:
108+
return None
109+
110+
if temporal_offset == 0 and temporal_scale == 1:
111+
sliced = window.get_tensor(cond_tensor, device, dim=temporal_dim, retain_index_list=retain_index_list)
112+
return cond_value._copy_with(sliced)
113+
114+
# skip leading latent positions that have no corresponding conditioning (e.g. reference frames)
115+
if temporal_offset > 0:
116+
indices = [i - temporal_offset for i in window.index_list[temporal_offset:]]
117+
indices = [i for i in indices if 0 <= i]
118+
else:
119+
indices = list(window.index_list)
120+
121+
if not indices:
122+
return None
123+
124+
if temporal_scale > 1:
125+
scaled = []
126+
for i in indices:
127+
for k in range(temporal_scale):
128+
si = i * temporal_scale + k
129+
if si < cond_size:
130+
scaled.append(si)
131+
indices = scaled
132+
if not indices:
133+
return None
134+
135+
idx = tuple([slice(None)] * temporal_dim + [indices])
136+
sliced = cond_tensor[idx].to(device)
137+
return cond_value._copy_with(sliced)
138+
139+
96140
@dataclass
97141
class ContextSchedule:
98142
name: str
@@ -177,10 +221,17 @@ def get_resized_cond(self, cond_in: list[dict], x_in: torch.Tensor, window: Inde
177221
new_cond_item[cond_key] = result
178222
handled = True
179223
break
224+
if not handled and self._model is not None:
225+
result = self._model.resize_cond_for_context_window(
226+
cond_key, cond_value, window, x_in, device,
227+
retain_index_list=self.cond_retain_index_list)
228+
if result is not None:
229+
new_cond_item[cond_key] = result
230+
handled = True
180231
if handled:
181232
continue
182233
if isinstance(cond_value, torch.Tensor):
183-
if (self.dim < cond_value.ndim and cond_value(self.dim) == x_in.size(self.dim)) or \
234+
if (self.dim < cond_value.ndim and cond_value.size(self.dim) == x_in.size(self.dim)) or \
184235
(cond_value.ndim < self.dim and cond_value.size(0) == x_in.size(self.dim)):
185236
new_cond_item[cond_key] = window.get_tensor(cond_value, device)
186237
# Handle audio_embed (temporal dim is 1)
@@ -224,6 +275,7 @@ def get_context_windows(self, model: BaseModel, x_in: torch.Tensor, model_option
224275
return context_windows
225276

226277
def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
278+
self._model = model
227279
self.set_step(timestep, model_options)
228280
context_windows = self.get_context_windows(model, x_in, model_options)
229281
enumerated_context_windows = list(enumerate(context_windows))

comfy/model_base.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,12 @@ def concat_cond(self, **kwargs):
285285
return data
286286
return None
287287

288+
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
289+
"""Override in subclasses to handle model-specific cond slicing for context windows.
290+
Return a sliced cond object, or None to fall through to default handling.
291+
Use comfy.context_windows.slice_cond() for common cases."""
292+
return None
293+
288294
def extra_conds(self, **kwargs):
289295
out = {}
290296
concat_cond = self.concat_cond(**kwargs)
@@ -1375,6 +1381,12 @@ def extra_conds(self, **kwargs):
13751381
out['vace_strength'] = comfy.conds.CONDConstant(vace_strength)
13761382
return out
13771383

1384+
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
1385+
if cond_key == "vace_context":
1386+
import comfy.context_windows
1387+
return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=3, retain_index_list=retain_index_list)
1388+
return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list)
1389+
13781390
class WAN21_Camera(WAN21):
13791391
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
13801392
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.CameraWanModel)
@@ -1427,6 +1439,12 @@ def extra_conds(self, **kwargs):
14271439

14281440
return out
14291441

1442+
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
1443+
if cond_key == "audio_embed":
1444+
import comfy.context_windows
1445+
return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=1)
1446+
return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list)
1447+
14301448
class WAN22_Animate(WAN21):
14311449
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
14321450
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model_animate.AnimateWanModel)
@@ -1444,6 +1462,14 @@ def extra_conds(self, **kwargs):
14441462
out['pose_latents'] = comfy.conds.CONDRegular(self.process_latent_in(pose_latents))
14451463
return out
14461464

1465+
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
1466+
import comfy.context_windows
1467+
if cond_key == "face_pixel_values":
1468+
return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=2, temporal_scale=4, temporal_offset=1)
1469+
if cond_key == "pose_latents":
1470+
return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=2, temporal_offset=1)
1471+
return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list)
1472+
14471473
class WAN22_S2V(WAN21):
14481474
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
14491475
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel_S2V)
@@ -1480,6 +1506,12 @@ def extra_conds_shapes(self, **kwargs):
14801506
out['reference_motion'] = reference_motion.shape
14811507
return out
14821508

1509+
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
1510+
if cond_key == "audio_embed":
1511+
import comfy.context_windows
1512+
return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=1)
1513+
return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list)
1514+
14831515
class WAN22(WAN21):
14841516
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
14851517
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel)

comfy_extras/nodes_context_windows.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@ def define_schema(cls) -> io.Schema:
2727
io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."),
2828
io.Int.Input("dim", min=0, max=5, default=0, tooltip="The dimension to apply the context windows to."),
2929
io.Boolean.Input("freenoise", default=False, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending."),
30-
#io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window, for example setting this to '0' will use the initial start image for each window."),
31-
#io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index."),
30+
io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window, for example setting this to '0' will use the initial start image for each window."),
31+
io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index."),
3232
],
3333
outputs=[
3434
io.Model.Output(tooltip="The model with context windows applied during sampling."),

0 commit comments

Comments
 (0)