@@ -25,7 +25,7 @@ def ltxv_te(*args, **kwargs):
 class Gemma3_12BTokenizer(sd1_clip.SDTokenizer):
     def __init__(self, embedding_directory=None, tokenizer_data={}):
         tokenizer = tokenizer_data.get("spiece_model", None)
-        super().__init__(tokenizer, pad_with_end=False, embedding_size=3840, embedding_key='gemma3_12b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, disable_weights=True, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data)
+        super().__init__(tokenizer, pad_with_end=False, embedding_size=3840, embedding_key='gemma3_12b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_left=True, disable_weights=True, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data)

     def state_dict(self):
         return {"spiece_model": self.tokenizer.serialize_model()}
@@ -97,6 +97,7 @@ def encode_token_weights(self, token_weight_pairs):
         token_weight_pairs = token_weight_pairs["gemma3_12b"]

         out, pooled, extra = self.gemma3_12b.encode_token_weights(token_weight_pairs)
+        out = out[:, :, -torch.sum(extra["attention_mask"]).item():]
         out_device = out.device
         if comfy.model_management.should_use_bf16(self.execution_device):
             out = out.to(device=self.execution_device, dtype=torch.bfloat16)
@@ -138,6 +139,7 @@ def memory_estimation_function(self, token_weight_pairs, device=None):

         token_weight_pairs = token_weight_pairs.get("gemma3_12b", [])
         num_tokens = sum(map(lambda a: len(a), token_weight_pairs))
+        num_tokens = max(num_tokens, 64)
         return num_tokens * constant * 1024 * 1024

 def ltxav_te(dtype_llama=None, llama_quantization_metadata=None):
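
Taken together, these hunks left-pad the Gemma3 prompt up to a fixed minimum length (min_length=512 with pad_left=True) and then trim the encoder output back down using the attention mask, while the memory estimate is floored at 64 tokens. Below is a minimal, illustrative sketch of that pad-then-trim pattern; the helper names (left_pad_tokens, trim_to_mask), the [batch, seq, dim] output layout, and the toy sizes are assumptions for illustration only, not part of the ComfyUI API.

import torch

def left_pad_tokens(tokens, min_length, pad_id=0):
    # Hypothetical helper: left-pad a 1-D tensor of token ids to at least
    # `min_length`, returning the padded ids plus a matching attention mask
    # (0 for padding positions, 1 for real tokens).
    pad = max(min_length - tokens.shape[0], 0)
    padded = torch.cat([torch.full((pad,), pad_id, dtype=tokens.dtype), tokens])
    mask = torch.cat([torch.zeros(pad, dtype=torch.long),
                      torch.ones(tokens.shape[0], dtype=torch.long)])
    return padded, mask

def trim_to_mask(out, mask):
    # Hypothetical helper: keep only the positions covered by the mask.
    # Assumes `out` is [batch, seq, dim] and the padding sits on the left,
    # so the real tokens occupy the last `mask.sum()` positions.
    n = int(mask.sum().item())
    return out[:, -n:, :]

# Usage sketch: 5 real tokens padded up to a minimum of 8, run through a
# stand-in "encoder" output, then trimmed back to the 5 real positions.
ids = torch.tensor([3, 14, 15, 92, 6])
padded, mask = left_pad_tokens(ids, min_length=8)
hidden = torch.randn(1, padded.shape[0], 16)   # stand-in for encoder output
trimmed = trim_to_mask(hidden, mask)
print(padded.shape, trimmed.shape)             # torch.Size([8]) torch.Size([1, 5, 16])

The point of the pattern is that padding to a fixed minimum keeps the encoder input shape stable, while trimming with the attention mask afterwards ensures downstream consumers only see embeddings for real tokens.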