Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 65 additions & 11 deletions comfy/ldm/flux/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,22 @@ class FluxParams:
txt_norm: bool = False


def invert_slices(slices, length):
    """Return the gaps of [0, length) that are not covered by any slice.

    Each input slice is a (start, end) pair; they may be unsorted or
    overlapping. The result is a list of non-overlapping (start, end)
    tuples in ascending order covering everything the slices do not.
    """
    gaps = []
    cursor = 0
    for begin, stop in sorted(slices):
        if cursor < begin:
            gaps.append((cursor, begin))
        if stop > cursor:
            cursor = stop
    if cursor < length:
        gaps.append((cursor, length))
    return gaps


class Flux(nn.Module):
"""
Transformer model for flow matching on sequences.
Expand Down Expand Up @@ -138,6 +154,7 @@ def forward_orig(
y: Tensor,
guidance: Tensor = None,
control = None,
timestep_zero_index=None,
transformer_options={},
attn_mask: Tensor = None,
) -> Tensor:
Expand All @@ -164,10 +181,6 @@ def forward_orig(
txt = self.txt_norm(txt)
txt = self.txt_in(txt)

vec_orig = vec
if self.params.global_modulation:
vec = (self.double_stream_modulation_img(vec_orig), self.double_stream_modulation_txt(vec_orig))

if "post_input" in patches:
for p in patches["post_input"]:
out = p({"img": img, "txt": txt, "img_ids": img_ids, "txt_ids": txt_ids, "transformer_options": transformer_options})
Expand All @@ -182,6 +195,24 @@ def forward_orig(
else:
pe = None

vec_orig = vec
txt_vec = vec
extra_kwargs = {}
if timestep_zero_index is not None:
modulation_dims = []
batch = vec.shape[0] // 2
vec_orig = vec_orig.reshape(2, batch, vec.shape[1]).movedim(0, 1)
invert = invert_slices(timestep_zero_index, img.shape[1])
for s in invert:
modulation_dims.append((s[0], s[1], 0))
for s in timestep_zero_index:
modulation_dims.append((s[0], s[1], 1))
extra_kwargs["modulation_dims_img"] = modulation_dims
txt_vec = vec[:batch]

if self.params.global_modulation:
vec = (self.double_stream_modulation_img(vec_orig), self.double_stream_modulation_txt(txt_vec))

blocks_replace = patches_replace.get("dit", {})
transformer_options["total_blocks"] = len(self.double_blocks)
transformer_options["block_type"] = "double"
Expand All @@ -195,7 +226,8 @@ def block_wrap(args):
vec=args["vec"],
pe=args["pe"],
attn_mask=args.get("attn_mask"),
transformer_options=args.get("transformer_options"))
transformer_options=args.get("transformer_options"),
**extra_kwargs)
return out

out = blocks_replace[("double_block", i)]({"img": img,
Expand All @@ -213,7 +245,8 @@ def block_wrap(args):
vec=vec,
pe=pe,
attn_mask=attn_mask,
transformer_options=transformer_options)
transformer_options=transformer_options,
**extra_kwargs)

if control is not None: # Controlnet
control_i = control.get("input")
Expand All @@ -230,6 +263,12 @@ def block_wrap(args):
if self.params.global_modulation:
vec, _ = self.single_stream_modulation(vec_orig)

extra_kwargs = {}
if timestep_zero_index is not None:
lambda a: 0 if a == 0 else a + txt.shape[1]
modulation_dims_combined = list(map(lambda x: (0 if x[0] == 0 else x[0] + txt.shape[1], x[1] + txt.shape[1], x[2]), modulation_dims))
extra_kwargs["modulation_dims"] = modulation_dims_combined

transformer_options["total_blocks"] = len(self.single_blocks)
transformer_options["block_type"] = "single"
transformer_options["img_slice"] = [txt.shape[1], img.shape[1]]
Expand All @@ -242,7 +281,8 @@ def block_wrap(args):
vec=args["vec"],
pe=args["pe"],
attn_mask=args.get("attn_mask"),
transformer_options=args.get("transformer_options"))
transformer_options=args.get("transformer_options"),
**extra_kwargs)
return out

out = blocks_replace[("single_block", i)]({"img": img,
Expand All @@ -253,7 +293,7 @@ def block_wrap(args):
{"original_block": block_wrap})
img = out["img"]
else:
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options)
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options, **extra_kwargs)

if control is not None: # Controlnet
control_o = control.get("output")
Expand All @@ -264,7 +304,11 @@ def block_wrap(args):

img = img[:, txt.shape[1] :, ...]

img = self.final_layer(img, vec_orig) # (N, T, patch_size ** 2 * out_channels)
extra_kwargs = {}
if timestep_zero_index is not None:
extra_kwargs["modulation_dims"] = modulation_dims

img = self.final_layer(img, vec_orig, **extra_kwargs) # (N, T, patch_size ** 2 * out_channels)
return img

def process_img(self, x, index=0, h_offset=0, w_offset=0, transformer_options={}):
Expand Down Expand Up @@ -312,13 +356,16 @@ def _forward(self, x, timestep, context, y=None, guidance=None, ref_latents=None
w_len = ((w_orig + (patch_size // 2)) // patch_size)
img, img_ids = self.process_img(x, transformer_options=transformer_options)
img_tokens = img.shape[1]
timestep_zero_index = None
if ref_latents is not None:
ref_num_tokens = []
h = 0
w = 0
index = 0
ref_latents_method = kwargs.get("ref_latents_method", self.params.default_ref_method)
timestep_zero = ref_latents_method == "index_timestep_zero"
for ref in ref_latents:
if ref_latents_method == "index":
if ref_latents_method in ("index", "index_timestep_zero"):
index += self.params.ref_index_scale
h_offset = 0
w_offset = 0
Expand All @@ -342,13 +389,20 @@ def _forward(self, x, timestep, context, y=None, guidance=None, ref_latents=None
kontext, kontext_ids = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset)
img = torch.cat([img, kontext], dim=1)
img_ids = torch.cat([img_ids, kontext_ids], dim=1)
ref_num_tokens.append(kontext.shape[1])
if timestep_zero:
if index > 0:
timestep = torch.cat([timestep, timestep * 0], dim=0)
timestep_zero_index = [[img_tokens, img_ids.shape[1]]]
transformer_options = transformer_options.copy()
transformer_options["reference_image_num_tokens"] = ref_num_tokens

txt_ids = torch.zeros((bs, context.shape[1], len(self.params.axes_dim)), device=x.device, dtype=torch.float32)

if len(self.params.txt_ids_dims) > 0:
for i in self.params.txt_ids_dims:
txt_ids[:, :, i] = torch.linspace(0, context.shape[1] - 1, steps=context.shape[1], device=x.device, dtype=torch.float32)

out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, timestep_zero_index=timestep_zero_index, transformer_options=transformer_options, attn_mask=kwargs.get("attention_mask", None))
out = out[:, :img_tokens]
return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=self.patch_size, pw=self.patch_size)[:,:,:h_orig,:w_orig]
64 changes: 64 additions & 0 deletions comfy_extras/nodes_flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import torch
import math
import nodes
import comfy.ldm.flux.math

class CLIPTextEncodeFlux(io.ComfyNode):
@classmethod
Expand Down Expand Up @@ -231,6 +232,68 @@ def execute(cls, steps, width, height) -> io.NodeOutput:
sigmas = get_schedule(steps, round(seq_len))
return io.NodeOutput(sigmas)

class KV_Attn_Input:
    """Attention (attn1) patch that caches K/V for reference-image tokens.

    On the first pass through each transformer block, the trailing
    reference-image K/V slices are stored. On subsequent passes — where
    the reference tokens have been trimmed from the input upstream — the
    cached slices are concatenated back onto K and V, so attention still
    sees the reference image without recomputing its projections.
    """

    def __init__(self):
        # Maps "<block_type>_<block_index>" -> (k_slice, v_slice).
        self.cache = {}
        # True right after a call that populated the cache, False after a
        # call that consumed it. Initialized here so it always exists
        # (previously only assigned inside __call__).
        self.set_cache = False

    def __call__(self, q, k, v, extra_options, **kwargs):
        """Return replacement {"q", "k", "v"}, or {} to leave them untouched."""
        reference_image_num_tokens = extra_options.get("reference_image_num_tokens", [])
        if not reference_image_num_tokens:
            return {}

        ref_toks = sum(reference_image_num_tokens)
        if ref_toks == 0:
            # Guard: with ref_toks == 0 the slice k[:, :, -0:] would cache the
            # ENTIRE tensor and corrupt later passes by doubling the sequence.
            return {}

        cache_key = "{}_{}".format(extra_options["block_type"], extra_options["block_index"])
        if cache_key in self.cache:
            # Cache hit: reference tokens were trimmed upstream, so append the
            # stored K/V slices back onto the sequence axis (dim=2).
            kk, vv = self.cache[cache_key]
            self.set_cache = False
            return {"q": q, "k": torch.cat((k, kk), dim=2), "v": torch.cat((v, vv), dim=2)}

        # First pass: remember the trailing reference-image K/V slices.
        self.cache[cache_key] = (k[:, :, -ref_toks:], v[:, :, -ref_toks:])
        self.set_cache = True
        return {"q": q, "k": k, "v": v}

    def cleanup(self):
        """Discard all cached K/V entries."""
        self.cache = {}


class FluxKVCache(io.ComfyNode):
    """Node that enables KV-cache reuse of reference-image tokens on Flux models."""

    @classmethod
    def define_schema(cls) -> io.Schema:
        return io.Schema(
            node_id="FluxKVCache",
            display_name="Flux KV Cache",
            description="Enables KV Cache optimization for reference images on Flux family models.",
            category="",
            is_experimental=True,
            inputs=[io.Model.Input("model", tooltip="The model to use KV Cache on.")],
            outputs=[io.Model.Output(tooltip="The patched model with KV Cache enabled.")],
        )

    @classmethod
    def execute(cls, model: io.Model.Type) -> io.NodeOutput:
        patched = model.clone()
        attn_patch = KV_Attn_Input()

        def trim_cached_ref_tokens(inputs):
            # Once the attention cache holds entries, drop the trailing
            # reference-image tokens from the img sequence so their K/V
            # projections are not recomputed.
            if attn_patch.cache:
                n_ref = sum(inputs["transformer_options"].get("reference_image_num_tokens", []))
                if n_ref > 0:
                    inputs["img"] = inputs["img"][:, :-n_ref]
            return inputs

        patched.set_model_attn1_patch(attn_patch)
        patched.set_model_post_input_patch(trim_cached_ref_tokens)
        # Flux exposes its config under .params; other family members keep
        # the attribute on the diffusion model itself.
        if hasattr(model.model.diffusion_model, "params"):
            target = "diffusion_model.params.default_ref_method"
        else:
            target = "diffusion_model.default_ref_method"
        patched.add_object_patch(target, "index_timestep_zero")

        return io.NodeOutput(patched)

class FluxExtension(ComfyExtension):
@override
Expand All @@ -243,6 +306,7 @@ async def get_node_list(self) -> list[type[io.ComfyNode]]:
FluxKontextMultiReferenceLatentMethod,
EmptyFlux2LatentImage,
Flux2Scheduler,
FluxKVCache,
]


Expand Down
Loading