Skip to content

Commit 712efb4

Browse files
Add left padding to LTXAV text encoder. (Comfy-Org#12456)
1 parent 726af73 commit 712efb4

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

comfy/text_encoders/lt.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def ltxv_te(*args, **kwargs):
2525
class Gemma3_12BTokenizer(sd1_clip.SDTokenizer):
2626
def __init__(self, embedding_directory=None, tokenizer_data={}):
2727
tokenizer = tokenizer_data.get("spiece_model", None)
28-
super().__init__(tokenizer, pad_with_end=False, embedding_size=3840, embedding_key='gemma3_12b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, disable_weights=True, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data)
28+
super().__init__(tokenizer, pad_with_end=False, embedding_size=3840, embedding_key='gemma3_12b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_left=True, disable_weights=True, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data)
2929

3030
def state_dict(self):
3131
return {"spiece_model": self.tokenizer.serialize_model()}
@@ -97,6 +97,7 @@ def encode_token_weights(self, token_weight_pairs):
9797
token_weight_pairs = token_weight_pairs["gemma3_12b"]
9898

9999
out, pooled, extra = self.gemma3_12b.encode_token_weights(token_weight_pairs)
100+
out = out[:, :, -torch.sum(extra["attention_mask"]).item():]
100101
out_device = out.device
101102
if comfy.model_management.should_use_bf16(self.execution_device):
102103
out = out.to(device=self.execution_device, dtype=torch.bfloat16)
@@ -138,6 +139,7 @@ def memory_estimation_function(self, token_weight_pairs, device=None):
138139

139140
token_weight_pairs = token_weight_pairs.get("gemma3_12b", [])
140141
num_tokens = sum(map(lambda a: len(a), token_weight_pairs))
142+
num_tokens = max(num_tokens, 64)
141143
return num_tokens * constant * 1024 * 1024
142144

143145
def ltxav_te(dtype_llama=None, llama_quantization_metadata=None):

0 commit comments

Comments (0)