Skip to content

Commit 4a6c311

Browse files
committed
refactor(internals): align model metadata wrappers with llama.cpp API
- Use `llama_vocab_n_tokens()` instead of the old vocab size helper. - Add Python wrappers for model description, size, chat template, and trained RoPE frequency scaling. - Clarify model capability helpers with docstrings matching llama.cpp semantics. - Rename `desc()` and `size()` to `model_desc()` and `model_size()` to make their scope explicit. - Drop the unused `get_tensor()` stub since llama.cpp does not expose it. - Route rerank template lookup through `LlamaModel.model_chat_template()` for consistency with the internal model abstraction. Signed-off-by: JamePeng <jame_peng@sina.com>
1 parent 40f0536 commit 4a6c311

3 files changed

Lines changed: 59 additions & 19 deletions

File tree

llama_cpp/_internals.py

Lines changed: 50 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ def vocab_type(self) -> int:
102102
return llama_cpp.llama_vocab_type(self.model)
103103

104104
def n_vocab(self) -> int:
105-
return llama_cpp.llama_n_vocab(self.vocab)
105+
return llama_cpp.llama_vocab_n_tokens(self.vocab)
106106

107107
def n_ctx_train(self) -> int:
108108
return llama_cpp.llama_model_n_ctx_train(self.model)
@@ -131,41 +131,76 @@ def n_head_kv(self) -> int:
131131
def n_swa(self) -> int:
132132
return llama_cpp.llama_model_n_swa(self.model)
133133

134+
def rope_freq_scale_train(self) -> float:
135+
"""
136+
Get the model's RoPE frequency scaling factor
137+
"""
138+
return llama_cpp.llama_model_rope_freq_scale_train(self.model)
139+
140+
def model_desc(self) -> str:
141+
"""
142+
Get a string describing the model type
143+
"""
144+
buf = ctypes.create_string_buffer(256)
145+
llama_cpp.llama_model_desc(self.model, buf, 256)
146+
return buf.value.decode("utf-8")
147+
148+
def model_size(self) -> int:
149+
"""
150+
Returns the total size of all the tensors in the model in bytes
151+
"""
152+
return llama_cpp.llama_model_size(self.model)
153+
154+
def model_chat_template(self, name: bytes) -> str:
155+
"""
156+
Get the default chat template. Returns nullptr if not available
157+
If name is NULL, returns the default chat template
158+
"""
159+
return llama_cpp.llama_model_chat_template(self.model, name).decode("utf-8")
160+
134161
def n_params(self) -> int:
162+
"""
163+
Returns the total number of parameters in the model
164+
"""
135165
return llama_cpp.llama_model_n_params(self.model)
136166

137167
def has_encoder(self) -> bool:
168+
"""
169+
Returns true if the model contains an encoder that requires llama_encode() call
170+
"""
138171
return llama_cpp.llama_model_has_encoder(self.model)
139172

140173
def has_decoder(self) -> bool:
174+
"""
175+
Returns true if the model contains a decoder that requires llama_decode() call
176+
"""
141177
return llama_cpp.llama_model_has_decoder(self.model)
142178

143179
def decoder_start_token(self) -> int:
180+
"""
181+
For encoder-decoder models, this function returns id of the token that must be provided
182+
to the decoder to start generating output sequence. For other models, it returns -1.
183+
"""
144184
return llama_cpp.llama_model_decoder_start_token(self.model)
145185

146186
def is_recurrent(self) -> bool:
187+
"""
188+
Returns true if the model is recurrent (like Mamba, RWKV, etc.)
189+
"""
147190
return llama_cpp.llama_model_is_recurrent(self.model)
148191

149192
def is_hybrid(self) -> bool:
193+
"""
194+
Returns true if the model is hybrid (like Jamba, Granite, etc.)
195+
"""
150196
return llama_cpp.llama_model_is_hybrid(self.model)
151197

152198
def is_diffusion(self) -> bool:
199+
"""
200+
Returns true if the model is diffusion-based (like LLaDA, Dream, etc.)
201+
"""
153202
return llama_cpp.llama_model_is_diffusion(self.model)
154203

155-
def rope_freq_scale_train(self) -> float:
156-
return llama_cpp.llama_model_rope_freq_scale_train(self.model)
157-
158-
def desc(self) -> str:
159-
buf = ctypes.create_string_buffer(1024)
160-
llama_cpp.llama_model_desc(self.model, buf, 1024)
161-
return buf.value.decode("utf-8")
162-
163-
def size(self) -> int:
164-
return llama_cpp.llama_model_size(self.model)
165-
166-
def get_tensor(self, name: str) -> ctypes.c_void_p:
167-
raise NotImplementedError("get_tensor is not implemented in llama.cpp")
168-
169204
# Vocab
170205

171206
def token_get_text(self, token: int) -> str:

llama_cpp/llama.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -696,13 +696,20 @@ def __init__(
696696

697697
try:
698698
self.metadata = self._model.metadata()
699+
self.model_desc = self._model.model_desc()
700+
# The total size of all the tensors in the model in bytes
701+
self.model_size = self._model.model_size()
702+
699703
except Exception as e:
700704
self.metadata = {}
701705
if self.verbose:
702706
print(f"Failed to load metadata: {e}", file=sys.stderr)
703707

704708
if self.verbose:
705-
print(f"Model metadata: {self.metadata}", file=sys.stderr)
709+
print(f"Model desc: {self.model_desc}, "
710+
f"Model size: {self.model_size / (1024 * 1024):.2f} MB, "
711+
f"Model metadata: {self.metadata}",
712+
file=sys.stderr)
706713

707714
eos_token_id = self.token_eos()
708715
bos_token_id = self.token_bos()

llama_cpp/llama_embedding.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -303,9 +303,7 @@ def rank(self, query: str, documents: List[str]) -> List[float]:
303303

304304
# 1. Attempt to retrieve the built-in 'rerank' chat template from model metadata.
305305
# Modern GGUF models often include a template for formatting query/document pairs.
306-
rerank_template = llama_cpp.llama_model_chat_template(self._model.model, b"rerank")
307-
if rerank_template:
308-
rerank_template = rerank_template.decode("utf-8")
306+
rerank_template = self._model.model_chat_template(b"rerank")
309307

310308
batch_inputs: List[List[int]] = []
311309

0 commit comments

Comments
 (0)