Skip to content

Commit c012400

Browse files
Initial support for qwen image model. (Comfy-Org#9179)
1 parent 03895de commit c012400

File tree

8 files changed

+557
-4
lines changed

8 files changed

+557
-4
lines changed

comfy/ldm/qwen_image/model.py

Lines changed: 399 additions & 0 deletions
Large diffs are not rendered by default.

comfy/model_base.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
import comfy.ldm.chroma.model
4343
import comfy.ldm.ace.model
4444
import comfy.ldm.omnigen.omnigen2
45+
import comfy.ldm.qwen_image.model
4546

4647
import comfy.model_management
4748
import comfy.patcher_extension
@@ -1308,3 +1309,14 @@ def extra_conds_shapes(self, **kwargs):
13081309
if ref_latents is not None:
13091310
out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16])
13101311
return out
1312+
1313+
class QwenImage(BaseModel):
1314+
def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
1315+
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.qwen_image.model.QwenImageTransformer2DModel)
1316+
1317+
def extra_conds(self, **kwargs):
1318+
out = super().extra_conds(**kwargs)
1319+
cross_attn = kwargs.get("cross_attn", None)
1320+
if cross_attn is not None:
1321+
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
1322+
return out

comfy/model_detection.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -481,6 +481,11 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
481481
dit_config["timestep_scale"] = 1000.0
482482
return dit_config
483483

484+
if '{}txt_norm.weight'.format(key_prefix) in state_dict_keys: # Qwen Image
485+
dit_config = {}
486+
dit_config["image_model"] = "qwen_image"
487+
return dit_config
488+
484489
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
485490
return None
486491

@@ -867,7 +872,7 @@ def convert_diffusers_mmdit(state_dict, output_prefix=""):
867872
depth_single_blocks = count_blocks(state_dict, 'single_transformer_blocks.{}.')
868873
hidden_size = state_dict["x_embedder.bias"].shape[0]
869874
sd_map = comfy.utils.flux_to_diffusers({"depth": depth, "depth_single_blocks": depth_single_blocks, "hidden_size": hidden_size}, output_prefix=output_prefix)
870-
elif 'transformer_blocks.0.attn.add_q_proj.weight' in state_dict: #SD3
875+
elif 'transformer_blocks.0.attn.add_q_proj.weight' in state_dict and 'pos_embed.proj.weight' in state_dict: #SD3
871876
num_blocks = count_blocks(state_dict, 'transformer_blocks.{}.')
872877
depth = state_dict["pos_embed.proj.weight"].shape[0] // 64
873878
sd_map = comfy.utils.mmdit_to_diffusers({"depth": depth, "num_blocks": num_blocks}, output_prefix=output_prefix)

comfy/sd.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
import comfy.text_encoders.hidream
4848
import comfy.text_encoders.ace
4949
import comfy.text_encoders.omnigen2
50+
import comfy.text_encoders.qwen_image
5051

5152
import comfy.model_patcher
5253
import comfy.lora
@@ -771,6 +772,7 @@ class CLIPType(Enum):
771772
CHROMA = 15
772773
ACE = 16
773774
OMNIGEN2 = 17
775+
QWEN_IMAGE = 18
774776

775777

776778
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
@@ -791,6 +793,7 @@ class TEModel(Enum):
791793
T5_XXL_OLD = 8
792794
GEMMA_2_2B = 9
793795
QWEN25_3B = 10
796+
QWEN25_7B = 11
794797

795798
def detect_te_model(sd):
796799
if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
@@ -812,7 +815,11 @@ def detect_te_model(sd):
812815
if 'model.layers.0.post_feedforward_layernorm.weight' in sd:
813816
return TEModel.GEMMA_2_2B
814817
if 'model.layers.0.self_attn.k_proj.bias' in sd:
815-
return TEModel.QWEN25_3B
818+
weight = sd['model.layers.0.self_attn.k_proj.bias']
819+
if weight.shape[0] == 256:
820+
return TEModel.QWEN25_3B
821+
if weight.shape[0] == 512:
822+
return TEModel.QWEN25_7B
816823
if "model.layers.0.post_attention_layernorm.weight" in sd:
817824
return TEModel.LLAMA3_8
818825
return None
@@ -917,6 +924,9 @@ class EmptyClass:
917924
elif te_model == TEModel.QWEN25_3B:
918925
clip_target.clip = comfy.text_encoders.omnigen2.te(**llama_detect(clip_data))
919926
clip_target.tokenizer = comfy.text_encoders.omnigen2.Omnigen2Tokenizer
927+
elif te_model == TEModel.QWEN25_7B:
928+
clip_target.clip = comfy.text_encoders.qwen_image.te(**llama_detect(clip_data))
929+
clip_target.tokenizer = comfy.text_encoders.qwen_image.QwenImageTokenizer
920930
else:
921931
# clip_l
922932
if clip_type == CLIPType.SD3:

comfy/supported_models.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import comfy.text_encoders.wan
2020
import comfy.text_encoders.ace
2121
import comfy.text_encoders.omnigen2
22+
import comfy.text_encoders.qwen_image
2223

2324
from . import supported_models_base
2425
from . import latent_formats
@@ -1229,7 +1230,36 @@ def clip_target(self, state_dict={}):
12291230
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_3b.transformer.".format(pref))
12301231
return supported_models_base.ClipTarget(comfy.text_encoders.omnigen2.Omnigen2Tokenizer, comfy.text_encoders.omnigen2.te(**hunyuan_detect))
12311232

1233+
class QwenImage(supported_models_base.BASE):
1234+
unet_config = {
1235+
"image_model": "qwen_image",
1236+
}
1237+
1238+
sampling_settings = {
1239+
"multiplier": 1.0,
1240+
"shift": 2.6,
1241+
}
1242+
1243+
memory_usage_factor = 1.8 #TODO
1244+
1245+
unet_extra_config = {}
1246+
latent_format = latent_formats.Wan21
1247+
1248+
supported_inference_dtypes = [torch.bfloat16, torch.float32]
1249+
1250+
vae_key_prefix = ["vae."]
1251+
text_encoder_key_prefix = ["text_encoders."]
1252+
1253+
def get_model(self, state_dict, prefix="", device=None):
1254+
out = model_base.QwenImage(self, device=device)
1255+
return out
1256+
1257+
def clip_target(self, state_dict={}):
1258+
pref = self.text_encoder_key_prefix[0]
1259+
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
1260+
return supported_models_base.ClipTarget(comfy.text_encoders.qwen_image.QwenImageTokenizer, comfy.text_encoders.qwen_image.te(**hunyuan_detect))
1261+
12321262

1233-
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep, Omnigen2]
1263+
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep, Omnigen2, QwenImage]
12341264

12351265
models += [SVD_img2vid]

comfy/text_encoders/llama.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,23 @@ class Qwen25_3BConfig:
4343
mlp_activation = "silu"
4444
qkv_bias = True
4545

46+
@dataclass
47+
class Qwen25_7BVLI_Config:
48+
vocab_size: int = 152064
49+
hidden_size: int = 3584
50+
intermediate_size: int = 18944
51+
num_hidden_layers: int = 28
52+
num_attention_heads: int = 28
53+
num_key_value_heads: int = 4
54+
max_position_embeddings: int = 128000
55+
rms_norm_eps: float = 1e-6
56+
rope_theta: float = 1000000.0
57+
transformer_type: str = "llama"
58+
head_dim = 128
59+
rms_norm_add = False
60+
mlp_activation = "silu"
61+
qkv_bias = True
62+
4663
@dataclass
4764
class Gemma2_2B_Config:
4865
vocab_size: int = 256000
@@ -348,6 +365,15 @@ def __init__(self, config_dict, dtype, device, operations):
348365
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
349366
self.dtype = dtype
350367

368+
class Qwen25_7BVLI(BaseLlama, torch.nn.Module):
369+
def __init__(self, config_dict, dtype, device, operations):
370+
super().__init__()
371+
config = Qwen25_7BVLI_Config(**config_dict)
372+
self.num_layers = config.num_hidden_layers
373+
374+
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
375+
self.dtype = dtype
376+
351377
class Gemma2_2B(BaseLlama, torch.nn.Module):
352378
def __init__(self, config_dict, dtype, device, operations):
353379
super().__init__()

comfy/text_encoders/qwen_image.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
from transformers import Qwen2Tokenizer
2+
from comfy import sd1_clip
3+
import comfy.text_encoders.llama
4+
import os
5+
import torch
6+
import numbers
7+
8+
class Qwen25_7BVLITokenizer(sd1_clip.SDTokenizer):
9+
def __init__(self, embedding_directory=None, tokenizer_data={}):
10+
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer")
11+
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=3584, embedding_key='qwen25_7b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=151643, tokenizer_data=tokenizer_data)
12+
13+
14+
class QwenImageTokenizer(sd1_clip.SD1Tokenizer):
15+
def __init__(self, embedding_directory=None, tokenizer_data={}):
16+
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="qwen25_7b", tokenizer=Qwen25_7BVLITokenizer)
17+
self.llama_template = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
18+
19+
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None,**kwargs):
20+
if llama_template is None:
21+
llama_text = self.llama_template.format(text)
22+
else:
23+
llama_text = llama_template.format(text)
24+
return super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, **kwargs)
25+
26+
27+
class Qwen25_7BVLIModel(sd1_clip.SDClipModel):
28+
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
29+
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen25_7BVLI, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
30+
31+
32+
class QwenImageTEModel(sd1_clip.SD1ClipModel):
33+
def __init__(self, device="cpu", dtype=None, model_options={}):
34+
super().__init__(device=device, dtype=dtype, name="qwen25_7b", clip_model=Qwen25_7BVLIModel, model_options=model_options)
35+
36+
def encode_token_weights(self, token_weight_pairs):
37+
out, pooled, extra = super().encode_token_weights(token_weight_pairs)
38+
tok_pairs = token_weight_pairs["qwen25_7b"][0]
39+
count_im_start = 0
40+
for i, v in enumerate(tok_pairs):
41+
elem = v[0]
42+
if not torch.is_tensor(elem):
43+
if isinstance(elem, numbers.Integral):
44+
if elem == 151644 and count_im_start < 2:
45+
template_end = i
46+
count_im_start += 1
47+
48+
if out.shape[1] > (template_end + 3):
49+
if tok_pairs[template_end + 1][0] == 872:
50+
if tok_pairs[template_end + 2][0] == 198:
51+
template_end += 3
52+
53+
out = out[:, template_end:]
54+
55+
extra["attention_mask"] = extra["attention_mask"][:, template_end:]
56+
if extra["attention_mask"].sum() == torch.numel(extra["attention_mask"]):
57+
extra.pop("attention_mask") # attention mask is useless if no masked elements
58+
59+
return out, pooled, extra
60+
61+
62+
def te(dtype_llama=None, llama_scaled_fp8=None):
63+
class QwenImageTEModel_(QwenImageTEModel):
64+
def __init__(self, device="cpu", dtype=None, model_options={}):
65+
if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
66+
model_options = model_options.copy()
67+
model_options["scaled_fp8"] = llama_scaled_fp8
68+
if dtype_llama is not None:
69+
dtype = dtype_llama
70+
super().__init__(device=device, dtype=dtype, model_options=model_options)
71+
return QwenImageTEModel_

nodes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -925,7 +925,7 @@ class CLIPLoader:
925925
@classmethod
926926
def INPUT_TYPES(s):
927927
return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ),
928-
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2"], ),
928+
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image"], ),
929929
},
930930
"optional": {
931931
"device": (["default", "cpu"], {"advanced": True}),

0 commit comments

Comments (0)