Skip to content

Commit d49420b

Browse files

LongCat-Image edit (Comfy-Org#13003) — authored
1 parent ebf6b52 commit d49420b

5 files changed

Lines changed: 36 additions & 10 deletions

File tree

comfy/ldm/flux/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -386,7 +386,7 @@ def _forward(self, x, timestep, context, y=None, guidance=None, ref_latents=None
386386
h = max(h, ref.shape[-2] + h_offset)
387387
w = max(w, ref.shape[-1] + w_offset)
388388

389-
kontext, kontext_ids = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset)
389+
kontext, kontext_ids = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset, transformer_options=transformer_options)
390390
img = torch.cat([img, kontext], dim=1)
391391
img_ids = torch.cat([img_ids, kontext_ids], dim=1)
392392
ref_num_tokens.append(kontext.shape[1])

comfy/model_base.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -937,9 +937,10 @@ def _apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, tran
937937
transformer_options = transformer_options.copy()
938938
rope_opts = transformer_options.get("rope_options", {})
939939
rope_opts = dict(rope_opts)
940+
pe_len = float(c_crossattn.shape[1]) if c_crossattn is not None else 512.0
940941
rope_opts.setdefault("shift_t", 1.0)
941-
rope_opts.setdefault("shift_y", 512.0)
942-
rope_opts.setdefault("shift_x", 512.0)
942+
rope_opts.setdefault("shift_y", pe_len)
943+
rope_opts.setdefault("shift_x", pe_len)
943944
transformer_options["rope_options"] = rope_opts
944945
return super()._apply_model(x, t, c_concat, c_crossattn, control, transformer_options, **kwargs)
945946

comfy/text_encoders/llama.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1028,12 +1028,19 @@ def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermed
10281028
grid = e.get("extra", None)
10291029
start = e.get("index")
10301030
if position_ids is None:
1031-
position_ids = torch.zeros((3, embeds.shape[1]), device=embeds.device)
1031+
position_ids = torch.ones((3, embeds.shape[1]), device=embeds.device, dtype=torch.long)
10321032
position_ids[:, :start] = torch.arange(0, start, device=embeds.device)
10331033
end = e.get("size") + start
10341034
len_max = int(grid.max()) // 2
10351035
start_next = len_max + start
1036-
position_ids[:, end:] = torch.arange(start_next + offset, start_next + (embeds.shape[1] - end) + offset, device=embeds.device)
1036+
if attention_mask is not None:
1037+
# Assign compact sequential positions to attended tokens only,
1038+
# skipping over padding so post-padding tokens aren't inflated.
1039+
after_mask = attention_mask[0, end:]
1040+
text_positions = after_mask.cumsum(0) - 1 + start_next + offset
1041+
position_ids[:, end:] = torch.where(after_mask.bool(), text_positions, position_ids[0, end:])
1042+
else:
1043+
position_ids[:, end:] = torch.arange(start_next + offset, start_next + (embeds.shape[1] - end) + offset, device=embeds.device)
10371044
position_ids[0, start:end] = start + offset
10381045
max_d = int(grid[0][1]) // 2
10391046
position_ids[1, start:end] = torch.arange(start + offset, start + max_d + offset, device=embeds.device).unsqueeze(1).repeat(1, math.ceil((end - start) / max_d)).flatten(0)[:end - start]

comfy/text_encoders/longcat_image.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -64,18 +64,22 @@ def tokenize_with_weights(self, text, return_word_ids=False, **kwargs):
6464
return [output]
6565

6666

67+
IMAGE_PAD_TOKEN_ID = 151655
68+
6769
class LongCatImageTokenizer(sd1_clip.SD1Tokenizer):
70+
T2I_PREFIX = "<|im_start|>system\nAs an image captioning expert, generate a descriptive text prompt based on an image content, suitable for input to a text-to-image model.<|im_end|>\n<|im_start|>user\n"
71+
EDIT_PREFIX = "<|im_start|>system\nAs an image editing expert, first analyze the content and attributes of the input image(s). Then, based on the user's editing instructions, clearly and precisely determine how to modify the given image(s), ensuring that only the specified parts are altered and all other aspects remain consistent with the original(s).<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>"
72+
SUFFIX = "<|im_end|>\n<|im_start|>assistant\n"
73+
6874
def __init__(self, embedding_directory=None, tokenizer_data={}):
6975
super().__init__(
7076
embedding_directory=embedding_directory,
7177
tokenizer_data=tokenizer_data,
7278
name="qwen25_7b",
7379
tokenizer=LongCatImageBaseTokenizer,
7480
)
75-
self.longcat_template_prefix = "<|im_start|>system\nAs an image captioning expert, generate a descriptive text prompt based on an image content, suitable for input to a text-to-image model.<|im_end|>\n<|im_start|>user\n"
76-
self.longcat_template_suffix = "<|im_end|>\n<|im_start|>assistant\n"
7781

78-
def tokenize_with_weights(self, text, return_word_ids=False, **kwargs):
82+
def tokenize_with_weights(self, text, return_word_ids=False, images=None, **kwargs):
7983
skip_template = False
8084
if text.startswith("<|im_start|>"):
8185
skip_template = True
@@ -90,11 +94,14 @@ def tokenize_with_weights(self, text, return_word_ids=False, **kwargs):
9094
text, return_word_ids=return_word_ids, disable_weights=True, **kwargs
9195
)
9296
else:
97+
has_images = images is not None and len(images) > 0
98+
template_prefix = self.EDIT_PREFIX if has_images else self.T2I_PREFIX
99+
93100
prefix_ids = base_tok.tokenizer(
94-
self.longcat_template_prefix, add_special_tokens=False
101+
template_prefix, add_special_tokens=False
95102
)["input_ids"]
96103
suffix_ids = base_tok.tokenizer(
97-
self.longcat_template_suffix, add_special_tokens=False
104+
self.SUFFIX, add_special_tokens=False
98105
)["input_ids"]
99106

100107
prompt_tokens = base_tok.tokenize_with_weights(
@@ -106,6 +113,14 @@ def tokenize_with_weights(self, text, return_word_ids=False, **kwargs):
106113
suffix_pairs = [(t, 1.0) for t in suffix_ids]
107114

108115
combined = prefix_pairs + prompt_pairs + suffix_pairs
116+
117+
if has_images:
118+
embed_count = 0
119+
for i in range(len(combined)):
120+
if combined[i][0] == IMAGE_PAD_TOKEN_ID and embed_count < len(images):
121+
combined[i] = ({"type": "image", "data": images[embed_count], "original_type": "image"}, combined[i][1])
122+
embed_count += 1
123+
109124
tokens = {"qwen25_7b": [combined]}
110125

111126
return tokens

comfy/text_encoders/qwen_vl.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -425,4 +425,7 @@ def forward(
425425
hidden_states = block(hidden_states, position_embeddings, cu_seqlens_now, optimized_attention=optimized_attention)
426426

427427
hidden_states = self.merger(hidden_states)
428+
# Potentially important for spatially precise edits. This is present in the HF implementation.
429+
reverse_indices = torch.argsort(window_index)
430+
hidden_states = hidden_states[reverse_indices, :]
428431
return hidden_states

0 commit comments

Comments (0)