Skip to content

Commit 79d17ba

Browse files
kijaiKosinkadinkdrozbay
authored
Context windows fixes and features (Comfy-Org#10975)
* Apply cond slice fix * Add FreeNoise * Update context_windows.py * Add option to retain condition by indexes for each window This allows for example Wan/HunyuanVideo image to video to "work" by using the initial start frame for each window, otherwise windows beyond first will be pure T2V generations. * Update context_windows.py * Allow splitting multiple conds into different windows * Add handling for audio_embed * whitespace * Allow freenoise to work on other dims, handle 4D batch timestep Refactor Freenoise function. And fix batch handling as timesteps seem to be expanded to batch size now. * Disable experimental options for now So that the Freenoise and bugfixes can be merged first --------- Co-authored-by: Jedrzej Kosinski <kosinkadink1@gmail.com> Co-authored-by: ozbayb <17261091+ozbayb@users.noreply.github.com>
1 parent 6fd463a commit 79d17ba

File tree

2 files changed

+108
-18
lines changed

2 files changed

+108
-18
lines changed

comfy/context_windows.py

Lines changed: 90 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -51,26 +51,36 @@ def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[
5151

5252

5353
class IndexListContextWindow(ContextWindowABC):
54-
def __init__(self, index_list: list[int], dim: int=0):
54+
def __init__(self, index_list: list[int], dim: int=0, total_frames: int=0):
5555
self.index_list = index_list
5656
self.context_length = len(index_list)
5757
self.dim = dim
58+
self.total_frames = total_frames
59+
self.center_ratio = (min(index_list) + max(index_list)) / (2 * total_frames)
5860

59-
def get_tensor(self, full: torch.Tensor, device=None, dim=None) -> torch.Tensor:
61+
def get_tensor(self, full: torch.Tensor, device=None, dim=None, retain_index_list=[]) -> torch.Tensor:
6062
if dim is None:
6163
dim = self.dim
6264
if dim == 0 and full.shape[dim] == 1:
6365
return full
64-
idx = [slice(None)] * dim + [self.index_list]
65-
return full[idx].to(device)
66+
idx = tuple([slice(None)] * dim + [self.index_list])
67+
window = full[idx]
68+
if retain_index_list:
69+
idx = tuple([slice(None)] * dim + [retain_index_list])
70+
window[idx] = full[idx]
71+
return window.to(device)
6672

6773
def add_window(self, full: torch.Tensor, to_add: torch.Tensor, dim=None) -> torch.Tensor:
6874
if dim is None:
6975
dim = self.dim
70-
idx = [slice(None)] * dim + [self.index_list]
76+
idx = tuple([slice(None)] * dim + [self.index_list])
7177
full[idx] += to_add
7278
return full
7379

80+
def get_region_index(self, num_regions: int) -> int:
81+
region_idx = int(self.center_ratio * num_regions)
82+
return min(max(region_idx, 0), num_regions - 1)
83+
7484

7585
class IndexListCallbacks:
7686
EVALUATE_CONTEXT_WINDOWS = "evaluate_context_windows"
@@ -94,7 +104,8 @@ class ContextFuseMethod:
94104

95105
ContextResults = collections.namedtuple("ContextResults", ['window_idx', 'sub_conds_out', 'sub_conds', 'window'])
96106
class IndexListContextHandler(ContextHandlerABC):
97-
def __init__(self, context_schedule: ContextSchedule, fuse_method: ContextFuseMethod, context_length: int=1, context_overlap: int=0, context_stride: int=1, closed_loop=False, dim=0):
107+
def __init__(self, context_schedule: ContextSchedule, fuse_method: ContextFuseMethod, context_length: int=1, context_overlap: int=0, context_stride: int=1,
108+
closed_loop: bool=False, dim:int=0, freenoise: bool=False, cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False):
98109
self.context_schedule = context_schedule
99110
self.fuse_method = fuse_method
100111
self.context_length = context_length
@@ -103,13 +114,18 @@ def __init__(self, context_schedule: ContextSchedule, fuse_method: ContextFuseMe
103114
self.closed_loop = closed_loop
104115
self.dim = dim
105116
self._step = 0
117+
self.freenoise = freenoise
118+
self.cond_retain_index_list = [int(x.strip()) for x in cond_retain_index_list.split(",")] if cond_retain_index_list else []
119+
self.split_conds_to_windows = split_conds_to_windows
106120

107121
self.callbacks = {}
108122

109123
def should_use_context(self, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]) -> bool:
110124
# for now, assume first dim is batch - should have stored on BaseModel in actual implementation
111125
if x_in.size(self.dim) > self.context_length:
112-
logging.info(f"Using context windows {self.context_length} for {x_in.size(self.dim)} frames.")
126+
logging.info(f"Using context windows {self.context_length} with overlap {self.context_overlap} for {x_in.size(self.dim)} frames.")
127+
if self.cond_retain_index_list:
128+
logging.info(f"Retaining original cond for indexes: {self.cond_retain_index_list}")
113129
return True
114130
return False
115131

@@ -123,6 +139,11 @@ def get_resized_cond(self, cond_in: list[dict], x_in: torch.Tensor, window: Inde
123139
return None
124140
# reuse or resize cond items to match context requirements
125141
resized_cond = []
142+
# if multiple conds, split based on primary region
143+
if self.split_conds_to_windows and len(cond_in) > 1:
144+
region = window.get_region_index(len(cond_in))
145+
logging.info(f"Splitting conds to windows; using region {region} for window {window[0]}-{window[-1]} with center ratio {window.center_ratio:.3f}")
146+
cond_in = [cond_in[region]]
126147
# cond object is a list containing a dict - outer list is irrelevant, so just loop through it
127148
for actual_cond in cond_in:
128149
resized_actual_cond = actual_cond.copy()
@@ -146,12 +167,19 @@ def get_resized_cond(self, cond_in: list[dict], x_in: torch.Tensor, window: Inde
146167
# when in dictionary, look for tensors and CONDCrossAttn [comfy/conds.py] (has cond attr that is a tensor)
147168
for cond_key, cond_value in new_cond_item.items():
148169
if isinstance(cond_value, torch.Tensor):
149-
if cond_value.ndim < self.dim and cond_value.size(0) == x_in.size(self.dim):
170+
if (self.dim < cond_value.ndim and cond_value(self.dim) == x_in.size(self.dim)) or \
171+
(cond_value.ndim < self.dim and cond_value.size(0) == x_in.size(self.dim)):
150172
new_cond_item[cond_key] = window.get_tensor(cond_value, device)
173+
# Handle audio_embed (temporal dim is 1)
174+
elif cond_key == "audio_embed" and hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor):
175+
audio_cond = cond_value.cond
176+
if audio_cond.ndim > 1 and audio_cond.size(1) == x_in.size(self.dim):
177+
new_cond_item[cond_key] = cond_value._copy_with(window.get_tensor(audio_cond, device, dim=1))
151178
# if has cond that is a Tensor, check if needs to be subset
152179
elif hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor):
153-
if cond_value.cond.ndim < self.dim and cond_value.cond.size(0) == x_in.size(self.dim):
154-
new_cond_item[cond_key] = cond_value._copy_with(window.get_tensor(cond_value.cond, device))
180+
if (self.dim < cond_value.cond.ndim and cond_value.cond.size(self.dim) == x_in.size(self.dim)) or \
181+
(cond_value.cond.ndim < self.dim and cond_value.cond.size(0) == x_in.size(self.dim)):
182+
new_cond_item[cond_key] = cond_value._copy_with(window.get_tensor(cond_value.cond, device, retain_index_list=self.cond_retain_index_list))
155183
elif cond_key == "num_video_frames": # for SVD
156184
new_cond_item[cond_key] = cond_value._copy_with(cond_value.cond)
157185
new_cond_item[cond_key].cond = window.context_length
@@ -164,7 +192,7 @@ def get_resized_cond(self, cond_in: list[dict], x_in: torch.Tensor, window: Inde
164192
return resized_cond
165193

166194
def set_step(self, timestep: torch.Tensor, model_options: dict[str]):
167-
mask = torch.isclose(model_options["transformer_options"]["sample_sigmas"], timestep, rtol=0.0001)
195+
mask = torch.isclose(model_options["transformer_options"]["sample_sigmas"], timestep[0], rtol=0.0001)
168196
matches = torch.nonzero(mask)
169197
if torch.numel(matches) == 0:
170198
raise Exception("No sample_sigmas matched current timestep; something went wrong.")
@@ -173,7 +201,7 @@ def set_step(self, timestep: torch.Tensor, model_options: dict[str]):
173201
def get_context_windows(self, model: BaseModel, x_in: torch.Tensor, model_options: dict[str]) -> list[IndexListContextWindow]:
174202
full_length = x_in.size(self.dim) # TODO: choose dim based on model
175203
context_windows = self.context_schedule.func(full_length, self, model_options)
176-
context_windows = [IndexListContextWindow(window, dim=self.dim) for window in context_windows]
204+
context_windows = [IndexListContextWindow(window, dim=self.dim, total_frames=full_length) for window in context_windows]
177205
return context_windows
178206

179207
def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
@@ -250,8 +278,8 @@ def combine_context_window_results(self, x_in: torch.Tensor, sub_conds_out, sub_
250278
prev_weight = (bias_total / (bias_total + bias))
251279
new_weight = (bias / (bias_total + bias))
252280
# account for dims of tensors
253-
idx_window = [slice(None)] * self.dim + [idx]
254-
pos_window = [slice(None)] * self.dim + [pos]
281+
idx_window = tuple([slice(None)] * self.dim + [idx])
282+
pos_window = tuple([slice(None)] * self.dim + [pos])
255283
# apply new values
256284
conds_final[i][idx_window] = conds_final[i][idx_window] * prev_weight + sub_conds_out[i][pos_window] * new_weight
257285
biases_final[i][idx] = bias_total + bias
@@ -287,6 +315,28 @@ def create_prepare_sampling_wrapper(model: ModelPatcher):
287315
)
288316

289317

318+
def _sampler_sample_wrapper(executor, guider, sigmas, extra_args, callback, noise, *args, **kwargs):
319+
model_options = extra_args.get("model_options", None)
320+
if model_options is None:
321+
raise Exception("model_options not found in sampler_sample_wrapper; this should never happen, something went wrong.")
322+
handler: IndexListContextHandler = model_options.get("context_handler", None)
323+
if handler is None:
324+
raise Exception("context_handler not found in sampler_sample_wrapper; this should never happen, something went wrong.")
325+
if not handler.freenoise:
326+
return executor(guider, sigmas, extra_args, callback, noise, *args, **kwargs)
327+
noise = apply_freenoise(noise, handler.dim, handler.context_length, handler.context_overlap, extra_args["seed"])
328+
329+
return executor(guider, sigmas, extra_args, callback, noise, *args, **kwargs)
330+
331+
332+
def create_sampler_sample_wrapper(model: ModelPatcher):
333+
model.add_wrapper_with_key(
334+
comfy.patcher_extension.WrappersMP.SAMPLER_SAMPLE,
335+
"ContextWindows_sampler_sample",
336+
_sampler_sample_wrapper
337+
)
338+
339+
290340
def match_weights_to_dim(weights: list[float], x_in: torch.Tensor, dim: int, device=None) -> torch.Tensor:
291341
total_dims = len(x_in.shape)
292342
weights_tensor = torch.Tensor(weights).to(device=device)
@@ -538,3 +588,29 @@ def shift_window_to_end(window: list[int], num_frames: int):
538588
for i in range(len(window)):
539589
# 2) add end_delta to each val to slide windows to end
540590
window[i] = window[i] + end_delta
591+
592+
593+
# https://github.com/Kosinkadink/ComfyUI-AnimateDiff-Evolved/blob/90fb1331201a4b29488089e4fbffc0d82cc6d0a9/animatediff/sample_settings.py#L465
594+
def apply_freenoise(noise: torch.Tensor, dim: int, context_length: int, context_overlap: int, seed: int):
595+
logging.info("Context windows: Applying FreeNoise")
596+
generator = torch.Generator(device='cpu').manual_seed(seed)
597+
latent_video_length = noise.shape[dim]
598+
delta = context_length - context_overlap
599+
600+
for start_idx in range(0, latent_video_length - context_length, delta):
601+
place_idx = start_idx + context_length
602+
603+
actual_delta = min(delta, latent_video_length - place_idx)
604+
if actual_delta <= 0:
605+
break
606+
607+
list_idx = torch.randperm(actual_delta, generator=generator, device='cpu') + start_idx
608+
609+
source_slice = [slice(None)] * noise.ndim
610+
source_slice[dim] = list_idx
611+
target_slice = [slice(None)] * noise.ndim
612+
target_slice[dim] = slice(place_idx, place_idx + actual_delta)
613+
614+
noise[tuple(target_slice)] = noise[tuple(source_slice)]
615+
616+
return noise

comfy_extras/nodes_context_windows.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@ def define_schema(cls) -> io.Schema:
2626
io.Boolean.Input("closed_loop", default=False, tooltip="Whether to close the context window loop; only applicable to looped schedules."),
2727
io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."),
2828
io.Int.Input("dim", min=0, max=5, default=0, tooltip="The dimension to apply the context windows to."),
29+
io.Boolean.Input("freenoise", default=False, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending."),
30+
#io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window, for example setting this to '0' will use the initial start image for each window."),
31+
#io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index."),
2932
],
3033
outputs=[
3134
io.Model.Output(tooltip="The model with context windows applied during sampling."),
@@ -34,7 +37,8 @@ def define_schema(cls) -> io.Schema:
3437
)
3538

3639
@classmethod
37-
def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, dim: int) -> io.Model:
40+
def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, dim: int, freenoise: bool,
41+
cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False) -> io.Model:
3842
model = model.clone()
3943
model.model_options["context_handler"] = comfy.context_windows.IndexListContextHandler(
4044
context_schedule=comfy.context_windows.get_matching_context_schedule(context_schedule),
@@ -43,9 +47,15 @@ def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int
4347
context_overlap=context_overlap,
4448
context_stride=context_stride,
4549
closed_loop=closed_loop,
46-
dim=dim)
50+
dim=dim,
51+
freenoise=freenoise,
52+
cond_retain_index_list=cond_retain_index_list,
53+
split_conds_to_windows=split_conds_to_windows
54+
)
4755
# make memory usage calculation only take into account the context window latents
4856
comfy.context_windows.create_prepare_sampling_wrapper(model)
57+
if freenoise: # no other use for this wrapper at this time
58+
comfy.context_windows.create_sampler_sample_wrapper(model)
4959
return io.NodeOutput(model)
5060

5161
class WanContextWindowsManualNode(ContextWindowsManualNode):
@@ -68,14 +78,18 @@ def define_schema(cls) -> io.Schema:
6878
io.Int.Input("context_stride", min=1, default=1, tooltip="The stride of the context window; only applicable to uniform schedules."),
6979
io.Boolean.Input("closed_loop", default=False, tooltip="Whether to close the context window loop; only applicable to looped schedules."),
7080
io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."),
81+
io.Boolean.Input("freenoise", default=False, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending."),
82+
#io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window, for example setting this to '0' will use the initial start image for each window."),
83+
#io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index."),
7184
]
7285
return schema
7386

7487
@classmethod
75-
def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str) -> io.Model:
88+
def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, freenoise: bool,
89+
cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False) -> io.Model:
7690
context_length = max(((context_length - 1) // 4) + 1, 1) # at least length 1
7791
context_overlap = max(((context_overlap - 1) // 4) + 1, 0) # at least overlap 0
78-
return super().execute(model, context_length, context_overlap, context_schedule, context_stride, closed_loop, fuse_method, dim=2)
92+
return super().execute(model, context_length, context_overlap, context_schedule, context_stride, closed_loop, fuse_method, dim=2, freenoise=freenoise, cond_retain_index_list=cond_retain_index_list, split_conds_to_windows=split_conds_to_windows)
7993

8094

8195
class ContextWindowsExtension(ComfyExtension):

0 commit comments

Comments
 (0)