Commit 5364cf9

feat(LlamaContext): add safety checks and docstrings to logits retrieval
- Add explicit null pointer validation to `get_logits` and `get_logits_ith`. These methods now raise a `RuntimeError` instead of silently returning invalid pointers when logits are unavailable or the index is out of bounds.
- Add comprehensive docstrings to both methods, detailing the underlying buffer shape and memory layout.
- Include a performance warning in `get_logits_ith` about the internal synchronization/reordering overhead to discourage its use on the hot path.

Signed-off-by: JamePeng <jame_peng@sina.com>
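The null check described above relies on a NULL `ctypes` pointer being falsy. The sketch below illustrates that behavior in isolation, assuming the bindings return a `ctypes.POINTER(ctypes.c_float)`; `require_non_null` is a hypothetical helper written for this illustration and is not part of llama-cpp-python.

```python
import ctypes

# Assumption: the C bindings hand back a float pointer of this type.
FloatPtr = ctypes.POINTER(ctypes.c_float)


def require_non_null(ptr, message):
    """Hypothetical helper: raise instead of returning a NULL pointer."""
    if not ptr:  # a NULL ctypes pointer evaluates to False
        raise RuntimeError(message)
    return ptr


null_ptr = FloatPtr()  # default-constructed pointer is NULL
backing = (ctypes.c_float * 2)(0.5, 1.5)
valid_ptr = ctypes.cast(backing, FloatPtr)

# A valid pointer passes through unchanged.
assert require_non_null(valid_ptr, "unused") is valid_ptr

# A NULL pointer is turned into an explicit error.
try:
    require_non_null(null_ptr, "get_logits: failed to get logits")
except RuntimeError as exc:
    print(exc)
```

Raising at the boundary means a caller sees the failure where it happens, rather than later when the NULL pointer is dereferenced.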
1 parent d90895d commit 5364cf9

1 file changed

Lines changed: 26 additions & 2 deletions

llama_cpp/_internals.py

@@ -755,12 +755,36 @@ def synchronize(self):
         llama_cpp.llama_synchronize(self.ctx)
 
     def get_logits(self):
+        """
+        Token logits obtained from the last call to llama_decode()
+        The logits for which llama_batch.logits[i] != 0 are stored contiguously
+        in the order they have appeared in the batch.
+        Rows: number of tokens for which llama_batch.logits[i] != 0
+        Cols: n_vocab
+
+        Returns:
+            Pointer to the logits buffer of shape (n_tokens, n_vocab)
+        """
         self._assert_ctx()
-        return llama_cpp.llama_get_logits(self.ctx)
+        logits = llama_cpp.llama_get_logits(self.ctx)
+        if not logits:
+            raise RuntimeError(f"LlamaContext.get_logits: failed to get logits")
+        return logits
 
     def get_logits_ith(self, i: int):
+        """
+        Return logits for the ith output row from the last llama_decode call.
+
+        Note:
+            This calls llama_get_logits_ith(), which may reorder/synchronize
+            the output buffer internally. Avoid calling it on the hot path unless
+            Python-side logits are required.
+        """
         self._assert_ctx()
-        return llama_cpp.llama_get_logits_ith(self.ctx, i)
+        logits = llama_cpp.llama_get_logits_ith(self.ctx, i)
+        if not logits:
+            raise RuntimeError(f"LlamaContext.get_logits_ith: invalid logits index {i}")
+        return logits
 
     def set_embeddings(self, embeddings: bool):
         self._assert_ctx()
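The `get_logits` docstring describes a contiguous buffer with one row of `n_vocab` floats per token that requested logits. As a rough sketch of what that layout implies, the example below reads one row from a stand-in buffer built with `ctypes`. It does not call llama.cpp; `logits_row`, `n_rows`, and the sample values are hypothetical and exist only for illustration.

```python
import ctypes


def logits_row(ptr, row, n_rows, n_vocab):
    """Hypothetical helper: copy one row out of a row-major float buffer."""
    if not 0 <= row < n_rows:
        raise IndexError(f"row {row} out of range for {n_rows} rows")
    start = row * n_vocab  # rows are stored back to back
    return ptr[start:start + n_vocab]  # pointer slicing copies into a list


# Stand-in for the buffer a decode call would fill: 2 rows, 3 vocab entries.
n_rows, n_vocab = 2, 3
backing = (ctypes.c_float * (n_rows * n_vocab))(0.0, 1.0, 2.0, 10.0, 11.0, 12.0)
ptr = ctypes.cast(backing, ctypes.POINTER(ctypes.c_float))

print(logits_row(ptr, 1, n_rows, n_vocab))  # [10.0, 11.0, 12.0]
```

Fetching the pointer once and indexing rows on the Python side is one way to avoid repeated per-row calls, under the assumption that the caller knows how many rows the last batch produced.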
