Skip to content

Commit 5ef874c

Browse files
committed
Improve sampling and grammar lifecycle management, fix memory growth issues
- Validate grammar sampler initialization and inputs
- Replace the unbounded prev token list with a bounded deque sized by the LlamaSamplingParams n_prev param
- Reuse the logits NumPy view to avoid repeated allocations
- Reuse single-token buffers for grammar rejection sampling
- Minor cleanups and consistency improvements in the sampling flow
1 parent af9d925 commit 5ef874c

1 file changed

Lines changed: 110 additions & 61 deletions

File tree

llama_cpp/_internals.py

Lines changed: 110 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
)
1818

1919
from dataclasses import dataclass, field
20+
from collections import deque
2021
from contextlib import ExitStack
2122

2223
import numpy as np
@@ -871,15 +872,24 @@ class GrammarSampler:
871872

872873
def __init__(self, model, grammar_str, lazy=False, triggers=None):
    """Create a native grammar sampler for *model* from a grammar string.

    Raises ValueError for a missing model or empty grammar string, and
    RuntimeError if the native sampler could not be created.

    NOTE(review): ``lazy`` and ``triggers`` are accepted but not used in
    this constructor — confirm whether lazy/trigger-based grammar setup
    is handled elsewhere before relying on them.
    """
    if model is None:
        raise ValueError("model must not be None")
    if not grammar_str:
        raise ValueError("grammar_str must not be empty")

    self.model = model
    self.vocab = model.vocab

    # "root" is the conventional start symbol for GBNF grammars.
    self.grammar = llama_cpp.llama_sampler_init_grammar(
        self.vocab,
        grammar_str.encode("utf-8"),
        b"root",
    )
    if not self.grammar:
        raise RuntimeError("Failed to initialize grammar sampler")
892+
883893
def apply(self, token_data):
    """Apply the grammar constraints to the given candidate token data."""
    llama_cpp.llama_sampler_apply(self.grammar, token_data)
885895

@@ -889,8 +899,22 @@ def accept(self, token):
889899
def reset(self):
    """Reset the native grammar sampler back to its initial parse state."""
    llama_cpp.llama_sampler_reset(self.grammar)
891901

892-
def free(self):
893-
llama_cpp.llama_sampler_free(self.grammar)
902+
def close(self):
    """Free the native grammar sampler (best-effort) and drop all references.

    Safe to call more than once; after the first call the handle is None.
    """
    grammar = self.grammar
    if grammar:
        try:
            llama_cpp.llama_sampler_free(grammar)
        except Exception:
            # Best-effort teardown: failures freeing the native handle
            # are deliberately ignored.
            pass
    self.grammar = None
    self.vocab = None
    self.model = None
912+
913+
def __del__(self):
    # A finalizer must never raise; close() is best-effort here.
    try:
        self.close()
    except Exception:
        pass
894918

895919
@dataclass
896920
class LlamaSamplingContext:
@@ -904,35 +928,50 @@ def __init__(
904928
model: Optional[LlamaModel] = None,
905929
_existing_sampler: Optional[LlamaSampler] = None, # Internal use for cloning
906930
):
907-
self.params = params
931+
if model is None:
932+
raise RuntimeError("model must not be None")
908933
self.model = model
934+
935+
self.params = params
909936
self.vocab = llama_cpp.llama_model_get_vocab(model.model)
910937
self.n_vocab = model.n_vocab()
911938

912939
lparams = llama_cpp.llama_sampler_chain_default_params()
913940
lparams.no_perf = params.no_perf
914941

915-
# Keep track of generated tokens for Python-side debugging/decoding
916-
self.prev: List[int] = []
942+
# history (bounded)
943+
# params.sampling.n_prev = std::max(params.sampling.n_prev, params.sampling.penalty_last_n);
944+
self.prev = deque(maxlen=max(params.n_prev, params.penalty_last_n))
945+
# reusable token data array
917946
self._cur_p = LlamaTokenDataArray(n_vocab=self.n_vocab)
947+
# reusable numpy logits view
948+
self._logits_view = None
918949

950+
self._single_token = llama_cpp.llama_token_data()
951+
self._single_array = llama_cpp.llama_token_data_array(
952+
data=ctypes.pointer(self._single_token),
953+
size=1,
954+
selected=-1,
955+
sorted=False,
956+
)
957+
958+
# sampler chain
919959
if _existing_sampler:
920-
# Use the provided sampler (already configured/cloned)
921960
self.sampler_chain = _existing_sampler
922961
else:
923-
# Build a new chain from scratch
924-
self.grammar_sampler = None
925962
self.sampler_chain = LlamaSampler()
926-
927-
if params.grammar is not None:
928-
self.grammar_sampler = GrammarSampler(
929-
model,
930-
params.grammar,
931-
params.grammar_lazy,
932-
params.grammar_triggers
933-
)
934963
self._build_sampler_chain()
935964

965+
# grammar sampler
966+
self.grammar_sampler = None
967+
if params.grammar:
968+
self.grammar_sampler = GrammarSampler(
969+
model,
970+
params.grammar,
971+
params.grammar_lazy,
972+
params.grammar_triggers,
973+
)
974+
936975
def _build_sampler_chain(self):
937976
"""
938977
Build sampler chain aligned with llama.cpp common_sampler_init
def reset(self):
    """Reset sampling state: clear the token history, then reset the
    grammar sampler and the sampler chain (each only if present)."""
    self.prev.clear()

    grammar = self.grammar_sampler
    if grammar:
        grammar.reset()

    chain = self.sampler_chain
    if chain:
        chain.reset()
10351078

10361079
def cp(self) -> 'LlamaSamplingContext':
10371080
"""
@@ -1084,14 +1127,17 @@ def sample(
10841127
return int(sampled)
10851128

10861129
# 3. build cur_p
1087-
logits = llama_cpp.llama_get_logits_ith(ctx.ctx, idx)
1130+
logits_ptr = llama_cpp.llama_get_logits_ith(ctx.ctx, idx)
10881131

1089-
logits_array = np.ctypeslib.as_array(
1090-
logits,
1091-
shape=(self.n_vocab,)
1092-
)
1132+
if self._logits_view is None:
1133+
self._logits_view = np.ctypeslib.as_array(
1134+
logits_ptr,
1135+
shape=(self.n_vocab,),
1136+
)
10931137

1138+
logits_array = self._logits_view
10941139
cur_p = self._cur_p
1140+
10951141
cur_p.copy_logits(logits_array)
10961142

10971143
# logit bias
@@ -1107,6 +1153,15 @@ def sample(
11071153
ctypes.byref(cur_p.candidates)
11081154
)
11091155

1156+
llama_cpp.llama_sampler_apply(
1157+
self.sampler_chain.sampler,
1158+
ctypes.byref(cur_p.candidates)
1159+
)
1160+
# grammar-first return directly
1161+
selected = cur_p.candidates.selected
1162+
return int(cur_p.candidates_data.id[selected])
1163+
1164+
11101165
# 5. sampling chain
11111166
llama_cpp.llama_sampler_apply(
11121167
self.sampler_chain.sampler,
@@ -1116,45 +1171,24 @@ def sample(
11161171
selected = cur_p.candidates.selected
11171172
token = int(cur_p.candidates_data.id[selected])
11181173

1119-
# 6. grammar-first return directly
1120-
if self.grammar_sampler and grammar_first:
1121-
return token
1122-
1123-
# 7. grammar rejection sampling
1174+
# 6. grammar rejection sampling
11241175
if self.grammar_sampler:
11251176

1126-
single = llama_cpp.llama_token_data(
1127-
id=token,
1128-
logit=1.0,
1129-
p=0.0
1130-
)
1131-
1132-
single_arr = llama_cpp.llama_token_data_array(
1133-
data=ctypes.pointer(single),
1134-
size=1,
1135-
selected=-1,
1136-
sorted=False
1137-
)
1177+
self._single_token.id = token
1178+
self._single_token.logit = 1.0
1179+
self._single_token.p = 0.0
1180+
self._single_array.selected = -1
1181+
self._single_array.sorted = False
11381182

11391183
llama_cpp.llama_sampler_apply(
11401184
self.grammar_sampler.grammar,
1141-
ctypes.byref(single_arr)
1185+
ctypes.byref(self._single_array)
11421186
)
11431187

1144-
valid = not np.isneginf(single.logit)
1145-
1146-
if valid:
1188+
if not np.isneginf(self._single_token.logit):
11471189
return token
11481190

1149-
1150-
# 8. resample
1151-
logits = llama_cpp.llama_get_logits_ith(ctx.ctx, idx)
1152-
1153-
logits_array = np.ctypeslib.as_array(
1154-
logits,
1155-
shape=(self.n_vocab,)
1156-
)
1157-
1191+
# 7. resample
11581192
cur_p.copy_logits(logits_array)
11591193

11601194
llama_cpp.llama_sampler_apply(
@@ -1172,14 +1206,29 @@ def sample(
11721206

11731207
return token
11741208

1209+
def close(self):
    """
    Release the grammar sampler and the sampler chain, if still attached.

    Idempotent: released members are set to None so a second call is a no-op.
    """
    grammar = self.grammar_sampler
    if grammar:
        grammar.close()
        self.grammar_sampler = None

    chain = self.sampler_chain
    if chain:
        chain.close()
        self.sampler_chain = None
1220+
1221+
def __del__(self):
    # Never propagate exceptions out of a finalizer.
    try:
        self.close()
    except Exception:
        pass
1226+
11751227
# --- Utilities ---
11761228

11771229
def last(self) -> Optional[int]:
    """Return the most recently sampled token, or None if nothing was sampled."""
    if not self.prev:
        return None
    return self.prev[-1]
11831232

11841233
def prev_str(self, ctx_main: LlamaContext, n: int) -> str:
    """
    Detokenize the last *n* tokens of the sampling history into a string.

    Returns "" when no tokens have been sampled yet. Undecodable bytes are
    replaced rather than raising.
    """
    if not self.prev:
        return ""
    # self.prev is a collections.deque, which does not support slicing —
    # materialize it as a list before taking the last n tokens.
    last_n_tokens = list(self.prev)[-n:]
    # Use the model linked to the context to detokenize
    return ctx_main.model.detokenize(last_n_tokens).decode("utf-8", errors="replace")
11951244

11961245

11971246
class CustomSampler:

0 commit comments

Comments
 (0)