1717)
1818
1919from dataclasses import dataclass , field
20+ from collections import deque
2021from contextlib import ExitStack
2122
2223import numpy as np
@@ -871,15 +872,24 @@ class GrammarSampler:
871872
    def __init__(self, model, grammar_str, lazy=False, triggers=None):
        """Wrap a native llama.cpp grammar sampler for constrained decoding.

        :param model: loaded model object; its ``vocab`` handle is used to
            initialize the native sampler. Must not be ``None``.
        :param grammar_str: GBNF grammar text; must be non-empty.
        :param lazy: accepted but not used in this constructor —
            TODO(review): confirm whether lazy-grammar wiring was intended.
        :param triggers: accepted but not used in this constructor —
            TODO(review): confirm intended use (presumably lazy-grammar triggers).
        :raises ValueError: if ``model`` is ``None`` or ``grammar_str`` is empty.
        :raises RuntimeError: if the native sampler could not be created.
        """
        if model is None:
            raise ValueError("model must not be None")

        self.model = model
        self.vocab = model.vocab

        if not grammar_str:
            raise ValueError("grammar_str must not be empty")

        # b"root" is the conventional start symbol for GBNF grammars.
        self.grammar = llama_cpp.llama_sampler_init_grammar(
            self.vocab,
            grammar_str.encode("utf-8"),
            b"root"
        )

        # The native call returns NULL on a malformed grammar; surface it
        # eagerly instead of failing later inside apply()/accept().
        if not self.grammar:
            raise RuntimeError("Failed to initialize grammar sampler")
    def apply(self, token_data):
        """Apply the grammar constraint to *token_data*.

        :param token_data: presumably a pointer/byref to a
            ``llama_token_data_array`` — passed straight through to the
            native sampler; verify against callers.
        """
        llama_cpp.llama_sampler_apply(self.grammar, token_data)
@@ -889,8 +899,22 @@ def accept(self, token):
    def reset(self):
        """Reset the native grammar sampler to its initial parse state."""
        llama_cpp.llama_sampler_reset(self.grammar)
891901
892- def free (self ):
893- llama_cpp .llama_sampler_free (self .grammar )
902+ def close (self ):
903+ if self .grammar :
904+ try :
905+ llama_cpp .llama_sampler_free (self .grammar )
906+ except Exception :
907+ pass
908+
909+ self .model = None
910+ self .vocab = None
911+ self .grammar = None
912+
    def __del__(self):
        # Finalizers must never propagate exceptions; swallow anything
        # close() raises during interpreter shutdown or GC.
        try:
            self.close()
        except Exception:
            pass
894918
895919@dataclass
896920class LlamaSamplingContext :
@@ -904,35 +928,50 @@ def __init__(
904928 model : Optional [LlamaModel ] = None ,
905929 _existing_sampler : Optional [LlamaSampler ] = None , # Internal use for cloning
906930 ):
907- self .params = params
931+ if model is None :
932+ raise RuntimeError ("model must not be None" )
908933 self .model = model
934+
935+ self .params = params
909936 self .vocab = llama_cpp .llama_model_get_vocab (model .model )
910937 self .n_vocab = model .n_vocab ()
911938
912939 lparams = llama_cpp .llama_sampler_chain_default_params ()
913940 lparams .no_perf = params .no_perf
914941
915- # Keep track of generated tokens for Python-side debugging/decoding
916- self .prev : List [int ] = []
942+ # history (bounded)
943+ # params.sampling.n_prev = std::max(params.sampling.n_prev, params.sampling.penalty_last_n);
944+ self .prev = deque (maxlen = max (params .n_prev , params .penalty_last_n ))
945+ # reusable token data array
917946 self ._cur_p = LlamaTokenDataArray (n_vocab = self .n_vocab )
947+ # reusable numpy logits view
948+ self ._logits_view = None
918949
950+ self ._single_token = llama_cpp .llama_token_data ()
951+ self ._single_array = llama_cpp .llama_token_data_array (
952+ data = ctypes .pointer (self ._single_token ),
953+ size = 1 ,
954+ selected = - 1 ,
955+ sorted = False ,
956+ )
957+
958+ # sampler chain
919959 if _existing_sampler :
920- # Use the provided sampler (already configured/cloned)
921960 self .sampler_chain = _existing_sampler
922961 else :
923- # Build a new chain from scratch
924- self .grammar_sampler = None
925962 self .sampler_chain = LlamaSampler ()
926-
927- if params .grammar is not None :
928- self .grammar_sampler = GrammarSampler (
929- model ,
930- params .grammar ,
931- params .grammar_lazy ,
932- params .grammar_triggers
933- )
934963 self ._build_sampler_chain ()
935964
965+ # grammar sampler
966+ self .grammar_sampler = None
967+ if params .grammar :
968+ self .grammar_sampler = GrammarSampler (
969+ model ,
970+ params .grammar ,
971+ params .grammar_lazy ,
972+ params .grammar_triggers ,
973+ )
974+
936975 def _build_sampler_chain (self ):
937976 """
938977 Build sampler chain aligned with llama.cpp common_sampler_init
@@ -1029,9 +1068,13 @@ def reset(self):
10291068 """
10301069 Resets the internal state of all samplers in the chain.
10311070 """
1032- self .grammar_sampler .reset ()
1033- self .sampler_chain .reset ()
1034- self .prev = []
1071+ self .prev .clear ()
1072+
1073+ if self .grammar_sampler :
1074+ self .grammar_sampler .reset ()
1075+
1076+ if self .sampler_chain :
1077+ self .sampler_chain .reset ()
10351078
10361079 def cp (self ) -> 'LlamaSamplingContext' :
10371080 """
@@ -1084,14 +1127,17 @@ def sample(
10841127 return int (sampled )
10851128
10861129 # 3. build cur_p
1087- logits = llama_cpp .llama_get_logits_ith (ctx .ctx , idx )
1130+ logits_ptr = llama_cpp .llama_get_logits_ith (ctx .ctx , idx )
10881131
1089- logits_array = np .ctypeslib .as_array (
1090- logits ,
1091- shape = (self .n_vocab ,)
1092- )
1132+ if self ._logits_view is None :
1133+ self ._logits_view = np .ctypeslib .as_array (
1134+ logits_ptr ,
1135+ shape = (self .n_vocab ,),
1136+ )
10931137
1138+ logits_array = self ._logits_view
10941139 cur_p = self ._cur_p
1140+
10951141 cur_p .copy_logits (logits_array )
10961142
10971143 # logit bias
@@ -1107,6 +1153,15 @@ def sample(
11071153 ctypes .byref (cur_p .candidates )
11081154 )
11091155
1156+ llama_cpp .llama_sampler_apply (
1157+ self .sampler_chain .sampler ,
1158+ ctypes .byref (cur_p .candidates )
1159+ )
1160+ # grammar-first return directly
1161+ selected = cur_p .candidates .selected
1162+ return int (cur_p .candidates_data .id [selected ])
1163+
1164+
11101165 # 5. sampling chain
11111166 llama_cpp .llama_sampler_apply (
11121167 self .sampler_chain .sampler ,
@@ -1116,45 +1171,24 @@ def sample(
11161171 selected = cur_p .candidates .selected
11171172 token = int (cur_p .candidates_data .id [selected ])
11181173
1119- # 6. grammar-first return directly
1120- if self .grammar_sampler and grammar_first :
1121- return token
1122-
1123- # 7. grammar rejection sampling
1174+ # 6. grammar rejection sampling
11241175 if self .grammar_sampler :
11251176
1126- single = llama_cpp .llama_token_data (
1127- id = token ,
1128- logit = 1.0 ,
1129- p = 0.0
1130- )
1131-
1132- single_arr = llama_cpp .llama_token_data_array (
1133- data = ctypes .pointer (single ),
1134- size = 1 ,
1135- selected = - 1 ,
1136- sorted = False
1137- )
1177+ self ._single_token .id = token
1178+ self ._single_token .logit = 1.0
1179+ self ._single_token .p = 0.0
1180+ self ._single_array .selected = - 1
1181+ self ._single_array .sorted = False
11381182
11391183 llama_cpp .llama_sampler_apply (
11401184 self .grammar_sampler .grammar ,
1141- ctypes .byref (single_arr )
1185+ ctypes .byref (self . _single_array )
11421186 )
11431187
1144- valid = not np .isneginf (single .logit )
1145-
1146- if valid :
1188+ if not np .isneginf (self ._single_token .logit ):
11471189 return token
11481190
1149-
1150- # 8. resample
1151- logits = llama_cpp .llama_get_logits_ith (ctx .ctx , idx )
1152-
1153- logits_array = np .ctypeslib .as_array (
1154- logits ,
1155- shape = (self .n_vocab ,)
1156- )
1157-
1191+ # 7. resample
11581192 cur_p .copy_logits (logits_array )
11591193
11601194 llama_cpp .llama_sampler_apply (
@@ -1172,14 +1206,29 @@ def sample(
11721206
11731207 return token
11741208
1209+ def close (self ):
1210+ """
1211+ Clear samplers cache
1212+ """
1213+ if self .grammar_sampler :
1214+ self .grammar_sampler .close ()
1215+ self .grammar_sampler = None
1216+
1217+ if self .sampler_chain :
1218+ self .sampler_chain .close ()
1219+ self .sampler_chain = None
1220+
    def __del__(self):
        # Finalizers must never propagate exceptions; swallow anything
        # close() raises during interpreter shutdown or GC.
        try:
            self.close()
        except Exception:
            pass
1226+
11751227 # --- Utilities ---
11761228
    def last(self) -> Optional[int]:
        """Returns the last sampled token, or None if nothing was sampled yet."""
        # self.prev is a deque; integer indexing ([-1]) is supported.
        return self.prev[-1] if self.prev else None
11831232
11841233 def prev_str (self , ctx_main : LlamaContext , n : int ) -> str :
11851234 """
@@ -1189,9 +1238,9 @@ def prev_str(self, ctx_main: LlamaContext, n: int) -> str:
11891238 if not self .prev :
11901239 return ""
11911240 # Get the last n tokens
1192- last_tokens = self .prev [- n :]
1241+ last_n_tokens = self .prev [- n :]
11931242 # Use the model linked to the context to detokenize
1194- return ctx_main .model .detokenize (last_tokens ).decode ("utf-8" , errors = "replace" )
1243+ return ctx_main .model .detokenize (last_n_tokens ).decode ("utf-8" , errors = "replace" )
11951244
11961245
11971246class CustomSampler :
0 commit comments