app/assets/services/path_utils.py (9 additions, 3 deletions)
@@ -93,12 +93,13 @@ def compute_relative_filename(file_path: str) -> str | None:

def get_asset_category_and_relative_path(
file_path: str,
) -> tuple[Literal["input", "output", "models"], str]:
) -> tuple[Literal["input", "output", "temp", "models"], str]:
"""Determine which root category a file path belongs to.
Categories:
- 'input': under folder_paths.get_input_directory()
- 'output': under folder_paths.get_output_directory()
- 'temp': under folder_paths.get_temp_directory()
- 'models': under any base path from get_comfy_models_folders()
Returns:
@@ -129,7 +130,12 @@ def _compute_relative(child: str, parent: str) -> str:
if _check_is_within(fp_abs, output_base):
return "output", _compute_relative(fp_abs, output_base)

# 3) models (check deepest matching base to avoid ambiguity)
# 3) temp
temp_base = os.path.abspath(folder_paths.get_temp_directory())
if _check_is_within(fp_abs, temp_base):
return "temp", _compute_relative(fp_abs, temp_base)

# 4) models (check deepest matching base to avoid ambiguity)
best: tuple[int, str, str] | None = None # (base_len, bucket, rel_inside_bucket)
for bucket, bases in get_comfy_models_folders():
for b in bases:
@@ -146,7 +152,7 @@ def _compute_relative(child: str, parent: str) -> str:
return "models", os.path.relpath(os.path.join(os.sep, combined), os.sep)

raise ValueError(
f"Path is not within input, output, or configured model bases: {file_path}"
f"Path is not within input, output, temp, or configured model bases: {file_path}"
)


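Note: the change above adds 'temp' as a third recognized root, checked after 'input' and 'output' and before the models bases. A minimal sketch of the intended behavior; the import path mirrors the file location and the example filenames are hypothetical:

    import os
    import folder_paths  # ComfyUI's directory registry
    from app.assets.services.path_utils import get_asset_category_and_relative_path

    # A file under the temp directory now resolves to the new "temp" category
    # instead of falling through to the ValueError at the end of the function.
    path = os.path.join(folder_paths.get_temp_directory(), "previews", "frame_0001.png")
    category, rel = get_asset_category_and_relative_path(path)
    # category == "temp"; rel == os.path.join("previews", "frame_0001.png")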
comfy/ops.py (4 additions, 0 deletions)
@@ -928,6 +928,7 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata,
weight = state_dict.pop(weight_key, None)
if weight is None:
logging.warning(f"Missing weight for layer {layer_name}")
self.weight = None
return

manually_loaded_keys = [weight_key]
@@ -1034,6 +1035,9 @@ def state_dict(self, *args, destination=None, prefix="", **kwargs):
if self.bias is not None:
sd["{}bias".format(prefix)] = self.bias

if self.weight is None:
return sd

if isinstance(self.weight, QuantizedTensor):
sd_out = self.weight.state_dict("{}weight".format(prefix))
for k in sd_out:
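Note: the two ops.py guards work as a pair. _load_from_state_dict now records self.weight = None when the checkpoint lacks the weight key, and state_dict returns early in that case so no bogus weight entry is serialized. A simplified stand-in illustrating the contract; this is not the actual comfy.ops layer class:

    class LayerSketch:
        """Illustrative only: mimics the missing-weight behavior patched above."""
        def __init__(self, bias=None):
            self.weight = None  # what the loader sets when the weight key is absent
            self.bias = bias

        def state_dict(self, prefix=""):
            sd = {}
            if self.bias is not None:
                sd["{}bias".format(prefix)] = self.bias
            if self.weight is None:
                return sd  # nothing was loaded, so emit no weight key
            sd["{}weight".format(prefix)] = self.weight
            return sd

    assert LayerSketch().state_dict("layer.") == {}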
comfy/sd.py (28 additions, 5 deletions)
@@ -61,6 +61,7 @@
import comfy.text_encoders.anima
import comfy.text_encoders.ace15
import comfy.text_encoders.longcat_image
import comfy.text_encoders.qwen35

import comfy.model_patcher
import comfy.lora
@@ -425,13 +426,13 @@ def load_model(self, tokens={}):
def get_key_patches(self):
return self.patcher.get_key_patches()

def generate(self, tokens, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.95, min_p=0.0, repetition_penalty=1.0, seed=None):
def generate(self, tokens, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.95, min_p=0.0, repetition_penalty=1.0, seed=None, presence_penalty=0.0):
self.cond_stage_model.reset_clip_options()

self.load_model(tokens)
self.cond_stage_model.set_clip_options({"layer": None})
self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
return self.cond_stage_model.generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed)
return self.cond_stage_model.generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed, presence_penalty=presence_penalty)

def decode(self, token_ids, skip_special_tokens=True):
return self.tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
@@ -1228,6 +1229,11 @@ class TEModel(Enum):
QWEN3_8B = 20
QWEN3_06B = 21
GEMMA_3_4B_VISION = 22
QWEN35_08B = 23
QWEN35_2B = 24
QWEN35_4B = 25
QWEN35_9B = 26
QWEN35_27B = 27


def detect_te_model(sd):
@@ -1267,6 +1273,17 @@ def detect_te_model(sd):
return TEModel.QWEN25_3B
if weight.shape[0] == 512:
return TEModel.QWEN25_7B
if "model.language_model.layers.0.linear_attn.A_log" in sd and "model.language_model.layers.0.input_layernorm.weight" in sd:
weight = sd['model.language_model.layers.0.input_layernorm.weight']
if weight.shape[0] == 1024:
return TEModel.QWEN35_08B
if weight.shape[0] == 2560:
return TEModel.QWEN35_4B
if weight.shape[0] == 4096:
return TEModel.QWEN35_9B
if weight.shape[0] == 5120:
return TEModel.QWEN35_27B
return TEModel.QWEN35_2B
if "model.layers.0.post_attention_layernorm.weight" in sd:
weight = sd['model.layers.0.post_attention_layernorm.weight']
if 'model.layers.0.self_attn.q_norm.weight' in sd:
@@ -1299,11 +1316,12 @@ def t5xxl_detect(clip_data):
return {}

def llama_detect(clip_data):
weight_name = "model.layers.0.self_attn.k_proj.weight"
weight_names = ["model.layers.0.self_attn.k_proj.weight", "model.layers.0.linear_attn.in_proj_a.weight"]

for sd in clip_data:
if weight_name in sd:
return comfy.text_encoders.hunyuan_video.llama_detect(sd)
for weight_name in weight_names:
if weight_name in sd:
return comfy.text_encoders.hunyuan_video.llama_detect(sd)

return {}

@@ -1431,6 +1449,11 @@ class EmptyClass:
elif te_model == TEModel.JINA_CLIP_2:
clip_target.clip = comfy.text_encoders.jina_clip_2.JinaClip2TextModelWrapper
clip_target.tokenizer = comfy.text_encoders.jina_clip_2.JinaClip2TokenizerWrapper
elif te_model in (TEModel.QWEN35_08B, TEModel.QWEN35_2B, TEModel.QWEN35_4B, TEModel.QWEN35_9B, TEModel.QWEN35_27B):
clip_data[0] = comfy.utils.state_dict_prefix_replace(clip_data[0], {"model.language_model.": "model.", "model.visual.": "visual.", "lm_head.": "model.lm_head."})
qwen35_type = {TEModel.QWEN35_08B: "qwen35_08b", TEModel.QWEN35_2B: "qwen35_2b", TEModel.QWEN35_4B: "qwen35_4b", TEModel.QWEN35_9B: "qwen35_9b", TEModel.QWEN35_27B: "qwen35_27b"}[te_model]
clip_target.clip = comfy.text_encoders.qwen35.te(**llama_detect(clip_data), model_type=qwen35_type)
clip_target.tokenizer = comfy.text_encoders.qwen35.tokenizer(model_type=qwen35_type)
elif te_model == TEModel.QWEN3_06B:
clip_target.clip = comfy.text_encoders.anima.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.anima.AnimaTokenizer
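Note: Qwen3.5 detection keys off "model.language_model.layers.0.linear_attn.A_log" (the linear-attention parameter present in these checkpoints) and then picks the variant from the width of the first input_layernorm. A condensed sketch of that rule, with the widths taken from the diff; 2B is the fallback when no width matches:

    QWEN35_BY_WIDTH = {1024: "QWEN35_08B", 2560: "QWEN35_4B",
                       4096: "QWEN35_9B", 5120: "QWEN35_27B"}

    def detect_qwen35_variant(sd):
        if "model.language_model.layers.0.linear_attn.A_log" not in sd:
            return None  # not a Qwen3.5 state dict
        width = sd["model.language_model.layers.0.input_layernorm.weight"].shape[0]
        return QWEN35_BY_WIDTH.get(width, "QWEN35_2B")

The loader branch then strips the "model.language_model." prefix with state_dict_prefix_replace before handing the weights to the new comfy.text_encoders.qwen35 classes.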
comfy/sd1_clip.py (4 additions, 4 deletions)
@@ -308,14 +308,14 @@ def encode(self, tokens):
def load_sd(self, sd):
return self.transformer.load_state_dict(sd, strict=False, assign=getattr(self, "can_assign_sd", False))

def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed):
def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, presence_penalty=0.0):
if isinstance(tokens, dict):
tokens_only = next(iter(tokens.values())) # todo: get this better?
else:
tokens_only = tokens
tokens_only = [[t[0] for t in b] for b in tokens_only]
embeds = self.process_tokens(tokens_only, device=self.execution_device)[0]
return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed)
return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, presence_penalty=presence_penalty)

def parse_parentheses(string):
result = []
@@ -740,5 +740,5 @@ def encode_token_weights(self, token_weight_pairs):
def load_sd(self, sd):
return getattr(self, self.clip).load_sd(sd)

def generate(self, tokens, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.95, min_p=0.0, repetition_penalty=1.0, seed=None):
return getattr(self, self.clip).generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed)
def generate(self, tokens, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.95, min_p=0.0, repetition_penalty=1.0, seed=None, presence_penalty=0.0):
return getattr(self, self.clip).generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed, presence_penalty=presence_penalty)
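Note: with the default presence_penalty=0.0 threaded through every generate signature above, existing callers are unaffected, while new callers can pass the knob from the top-level comfy.sd.CLIP API. A hypothetical usage; the prompt and values are illustrative, and clip is assumed to be a loaded comfy.sd.CLIP instance wrapping a generation-capable text encoder:

    tokens = clip.tokenize("a short caption for this image")
    token_ids = clip.generate(tokens, max_length=128, temperature=0.8,
                              presence_penalty=0.5)  # new argument, defaults to 0.0
    print(clip.decode(token_ids))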
comfy/text_encoders/llama.py (34 additions, 16 deletions)
@@ -224,7 +224,7 @@ class Qwen3_8BConfig:
k_norm = "gemma3"
rope_scale = None
final_norm: bool = True
lm_head: bool = False
lm_head: bool = True
stop_tokens = [151643, 151645]

@dataclass
@@ -655,6 +655,17 @@ def __init__(self, config, device=None, dtype=None, ops=None):
if config.lm_head:
self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype)

def get_past_len(self, past_key_values):
return past_key_values[0][2]

def compute_freqs_cis(self, position_ids, device):
return precompute_freqs_cis(self.config.head_dim,
position_ids,
self.config.rope_theta,
self.config.rope_scale,
self.config.rope_dims,
device=device)

def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[], past_key_values=None):
if embeds is not None:
x = embeds
@@ -667,17 +678,12 @@ def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermed
seq_len = x.shape[1]
past_len = 0
if past_key_values is not None and len(past_key_values) > 0:
past_len = past_key_values[0][2]
past_len = self.get_past_len(past_key_values)

if position_ids is None:
position_ids = torch.arange(past_len, past_len + seq_len, device=x.device).unsqueeze(0)

freqs_cis = precompute_freqs_cis(self.config.head_dim,
position_ids,
self.config.rope_theta,
self.config.rope_scale,
self.config.rope_dims,
device=x.device)
freqs_cis = self.compute_freqs_cis(position_ids, x.device)

mask = None
if attention_mask is not None:
@@ -812,9 +818,16 @@ def logits(self, x):
comfy.ops.uncast_bias_weight(module, weight, None, offload_stream)
return x

def generate(self, embeds=None, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.9, min_p=0.0, repetition_penalty=1.0, seed=42, stop_tokens=None, initial_tokens=[], execution_dtype=None, min_tokens=0):
device = embeds.device
def init_kv_cache(self, batch, max_cache_len, device, execution_dtype):
model_config = self.model.config
past_key_values = []
for x in range(model_config.num_hidden_layers):
past_key_values.append((torch.empty([batch, model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype),
torch.empty([batch, model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype), 0))
return past_key_values

def generate(self, embeds=None, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.9, min_p=0.0, repetition_penalty=1.0, seed=42, stop_tokens=None, initial_tokens=[], execution_dtype=None, min_tokens=0, presence_penalty=0.0):
device = embeds.device

if stop_tokens is None:
stop_tokens = self.model.config.stop_tokens
@@ -829,11 +842,8 @@ def generate(self, embeds=None, do_sample=True, max_length=256, temperature=1.0,
if embeds.ndim == 2:
embeds = embeds.unsqueeze(0)

past_key_values = [] #kv_cache init
max_cache_len = embeds.shape[1] + max_length
for x in range(model_config.num_hidden_layers):
past_key_values.append((torch.empty([embeds.shape[0], model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype),
torch.empty([embeds.shape[0], model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype), 0))
past_key_values = self.init_kv_cache(embeds.shape[0], max_cache_len, device, execution_dtype)

generator = torch.Generator(device=device).manual_seed(seed) if do_sample else None

@@ -844,7 +854,7 @@ def generate(self, embeds=None, do_sample=True, max_length=256, temperature=1.0,
for step in tqdm(range(max_length), desc="Generating tokens"):
x, _, past_key_values = self.model.forward(None, embeds=embeds, attention_mask=None, past_key_values=past_key_values)
logits = self.logits(x)[:, -1]
next_token = self.sample_token(logits, temperature, top_k, top_p, min_p, repetition_penalty, initial_tokens + generated_token_ids, generator, do_sample=do_sample)
next_token = self.sample_token(logits, temperature, top_k, top_p, min_p, repetition_penalty, initial_tokens + generated_token_ids, generator, do_sample=do_sample, presence_penalty=presence_penalty)
token_id = next_token[0].item()
generated_token_ids.append(token_id)

@@ -856,7 +866,7 @@ def generate(self, embeds=None, do_sample=True, max_length=256, temperature=1.0,

return generated_token_ids

def sample_token(self, logits, temperature, top_k, top_p, min_p, repetition_penalty, token_history, generator, do_sample=True):
def sample_token(self, logits, temperature, top_k, top_p, min_p, repetition_penalty, token_history, generator, do_sample=True, presence_penalty=0.0):

if not do_sample or temperature == 0.0:
return torch.argmax(logits, dim=-1, keepdim=True)
@@ -867,6 +877,11 @@ def sample_token(self, logits, temperature, top_k, top_p, min_p, repetition_pena
for token_id in set(token_history):
logits[i, token_id] *= repetition_penalty if logits[i, token_id] < 0 else 1/repetition_penalty

if presence_penalty is not None and presence_penalty != 0.0:
for i in range(logits.shape[0]):
for token_id in set(token_history):
logits[i, token_id] -= presence_penalty

if temperature != 1.0:
logits = logits / temperature

@@ -897,6 +912,9 @@ def sample_token(self, logits, temperature, top_k, top_p, min_p, repetition_pena
class BaseQwen3:
def logits(self, x):
input = x[:, -1:]
if self.model.config.lm_head:
return self.model.lm_head(input)

module = self.model.embed_tokens

offload_stream = None
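Note: unlike repetition_penalty, which rescales a repeated token's logit multiplicatively, the new presence penalty subtracts a flat amount from every token that has appeared at least once, regardless of how often. A vectorized sketch equivalent to the loop added in sample_token, assuming logits of shape [batch, vocab] and one shared token history:

    import torch

    def apply_presence_penalty(logits, token_history, presence_penalty=0.0):
        # One indexing op instead of the nested Python loops; a no-op when the
        # penalty is 0.0/None or the history is empty, matching the diff's guard.
        if presence_penalty:
            seen = torch.tensor(sorted(set(token_history)),
                                dtype=torch.long, device=logits.device)
            logits[:, seen] -= presence_penalty
        return logits

The init_kv_cache, get_past_len, and compute_freqs_cis extractions in this file appear intended to let subclasses (such as the new Qwen3.5 hybrid linear-attention models) override cache layout and RoPE computation without duplicating the forward pass.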