@@ -525,9 +525,8 @@ def free_lora_adapter():
525525
526526 self .n_tokens = 0
527527 self .input_ids : npt .NDArray [np .intc ] = np .ndarray ((n_ctx ,), dtype = np .intc )
528- self .scores : npt .NDArray [np .single ] = np .ndarray (
529- (n_ctx if logits_all == True else n_batch , self ._n_vocab ), dtype = np .single
530- )
528+ self .scores : npt .NDArray [np .single ] = np .ndarray ((n_ctx if self ._logits_all else 1 , self ._n_vocab ), dtype = np .single )
529+
531530
532531 self ._mirostat_mu = ctypes .c_float (
533532 2.0 * 5.0
@@ -638,7 +637,10 @@ def _input_ids(self) -> npt.NDArray[np.intc]:
638637
639638 @property
640639 def _scores (self ) -> npt .NDArray [np .single ]:
641- return self .scores [: self .n_tokens , :]
640+ if self ._logits_all :
641+ return self .scores [: self .n_tokens , :]
642+ else :
643+ return self .scores
642644
643645 @property
644646 def eval_tokens (self ) -> Deque [int ]:
@@ -747,14 +749,17 @@ def eval(self, tokens: Sequence[int]):
747749 ) from e
748750 # Save tokens
749751 self .input_ids [n_past : n_past + n_batch_tokens ] = batch
752+
750753 # Save logits
754+ logits_ptr = self ._ctx .get_logits ()
751755 if self ._logits_all :
752756 rows = n_batch_tokens
753757 cols = self ._n_vocab
754- logits = np .ctypeslib .as_array (
755- self ._ctx .get_logits (), shape = (rows * cols ,)
756- )
757- self .scores [n_past : n_past + n_batch_tokens , :].reshape (- 1 )[::] = logits
758+ logits_view = np .ctypeslib .as_array (logits_ptr , shape = (rows * cols ,))
759+ self .scores [n_past : n_past + n_batch_tokens , :].reshape (- 1 )[:] = logits_view
760+ else :
761+ logits_view = np .ctypeslib .as_array (logits_ptr , shape = (self ._n_vocab ,))
762+ self .scores [0 , :] = logits_view
758763
759764 # Update n_tokens
760765 current_pos += n_batch_tokens
@@ -875,7 +880,10 @@ def sample(
875880 # LogitsProcessor Adapter
876881 if logits_processor :
877882 def adapter (token_data_array : llama_cpp .llama_token_data_array ):
878- current_scores = self ._scores [self .n_tokens - 1 , :]
883+ if self ._logits_all :
884+ current_scores = self ._scores [self .n_tokens - 1 , :]
885+ else :
886+ current_scores = self ._scores [0 , :]
879887 new_scores = logits_processor (self ._input_ids , current_scores )
880888 size = token_data_array .size
881889 data_ptr = token_data_array .data
@@ -1003,7 +1011,10 @@ def generate(
10031011
10041012 if logits_processor :
10051013 def adapter (token_data_array : llama_cpp .llama_token_data_array ):
1006- current_scores = self ._scores [self .n_tokens - 1 , :]
1014+ if self ._logits_all :
1015+ current_scores = self ._scores [self .n_tokens - 1 , :]
1016+ else :
1017+ current_scores = self ._scores [0 , :]
10071018 new_scores = logits_processor (self ._input_ids , current_scores )
10081019
10091020 size = token_data_array .size
@@ -1050,10 +1061,22 @@ def adapter(token_data_array: llama_cpp.llama_token_data_array):
10501061 self ._sampling_ctx .accept (token , False if grammar is None else True )
10511062
10521063 sample_idx += 1
1053- if stopping_criteria is not None and stopping_criteria (
1054- self ._input_ids [: sample_idx ], self ._scores [sample_idx - self .n_tokens , :]
1055- ):
1056- return
1064+ if stopping_criteria is not None :
1065+ if self ._logits_all :
1066+ logits_idx = sample_idx - self .n_tokens
1067+ check_stopping = True
1068+ else :
1069+ if sample_idx == self .n_tokens :
1070+ logits_idx = 0
1071+ check_stopping = True
1072+ else :
1073+ check_stopping = False
1074+
1075+ if check_stopping and stopping_criteria (
1076+ self ._input_ids [: sample_idx ],
1077+ self ._scores [logits_idx , :]
1078+ ):
1079+ return
10571080 tokens_or_none = yield token
10581081 tokens .clear ()
10591082 tokens .append (token )
@@ -1556,7 +1579,10 @@ def _create_completion(
15561579 ).decode ("utf-8" , errors = "ignore" )
15571580 )
15581581 token_offset = len (prompt_tokens ) + returned_tokens
1559- logits = self ._scores [token_offset - 1 , :]
1582+ if self ._logits_all :
1583+ logits = self ._scores [token_offset - 1 , :]
1584+ else :
1585+ logits = self ._scores [0 , :]
15601586 current_logprobs = Llama .logits_to_logprobs (logits ).tolist ()
15611587 sorted_logprobs = list (
15621588 sorted (
@@ -1695,7 +1721,10 @@ def _create_completion(
16951721 )
16961722 )
16971723 token_offset = len (prompt_tokens ) + returned_tokens - 1
1698- logits = self ._scores [token_offset , :]
1724+ if self ._logits_all :
1725+ logits = self ._scores [token_offset , :]
1726+ else :
1727+ logits = self ._scores [0 , :]
16991728 current_logprobs = Llama .logits_to_logprobs (logits ).tolist ()
17001729 sorted_logprobs = list (
17011730 sorted (
@@ -2406,46 +2435,72 @@ def __setstate__(self, state):
def save_state(self) -> LlamaState:
    """Capture a restorable snapshot of the llama context.

    Queries the backend for the state size, copies the raw state into a
    ctypes buffer, and packages it together with the token history,
    logits, and RNG seed into a :class:`LlamaState`.

    Raises:
        RuntimeError: if the backend reports more bytes written than the
            buffer it was given (would indicate an overflow).
    """
    # Local helper so every progress message goes through one gate.
    def _log(msg: str) -> None:
        if self.verbose:
            print(msg, file=sys.stderr)

    _log("Llama.save_state: saving llama state")

    # Ask the backend how large a buffer the current state requires.
    required = llama_cpp.llama_state_get_size(self._ctx.ctx)
    _log(f"Llama.save_state: got state size: {required}")

    # Allocate a ctypes uint8 buffer of exactly that size.
    buf = (ctypes.c_uint8 * int(required))()
    _log("Llama.save_state: allocated state")

    # Copy the raw state out of the C context; returns bytes written.
    written = llama_cpp.llama_state_get_data(self._ctx.ctx, buf, required)
    _log(f"Llama.save_state: copied llama state: {written}")

    # Guard against the backend overrunning the buffer it was handed.
    if int(written) > int(required):
        raise RuntimeError("Failed to copy llama state data")

    # Read `written` bytes straight from the buffer's address: avoids an
    # intermediate compacted array allocation.
    raw = ctypes.string_at(ctypes.addressof(buf), int(written))
    _log(f"Llama.save_state: saving {written} bytes of llama state")

    return LlamaState(
        scores=self._scores.copy(),
        input_ids=self.input_ids.copy(),
        n_tokens=self.n_tokens,
        llama_state=raw,
        llama_state_size=written,
        seed=self._seed,
    )
24352477
def load_state(self, state: LlamaState) -> None:
    """Restore a snapshot previously produced by :meth:`save_state`.

    Re-installs the token history, RNG seed, and logits, then pushes the
    raw backend state bytes back into the C context.

    Raises:
        RuntimeError: if the backend does not consume exactly the number
            of state bytes recorded in the snapshot.
    """
    # Token history, token count, and RNG seed come back verbatim.
    self.input_ids = state.input_ids.copy()
    self.n_tokens = state.n_tokens
    self._seed = state.seed

    # Logits restore depends on how the scores buffer is configured.
    saved = state.scores
    if self._logits_all:
        # Full-history buffer: copy as many saved rows as fit/exist and
        # zero everything past the restored region.
        n_valid = min(self.n_tokens, saved.shape[0])
        self.scores[:n_valid, :] = saved[:n_valid, :]
        self.scores[n_valid:, :] = 0.0
    elif saved.shape[0] > 0:
        # Single-row buffer: only the final token's logits are kept.
        self.scores[0, :] = saved[-1, :]

    # Round-trip the opaque backend state through a C-compatible buffer.
    n_state_bytes = state.llama_state_size
    buf_type = ctypes.c_uint8 * n_state_bytes
    state_buf = buf_type.from_buffer_copy(state.llama_state)

    if llama_cpp.llama_state_set_data(self._ctx.ctx, state_buf, n_state_bytes) != n_state_bytes:
        raise RuntimeError("Failed to set llama state data")
24502505
24512506 def n_ctx (self ) -> int :
0 commit comments