Skip to content

Commit f9f8669

Browse files
committed
optimize(memory): reduce scores buffer size and optimize state saving
- Updated save_state and load_state API usage. - Refactored self.scores to allocate only a single row (1, n_vocab) when logits_all=False, significantly reducing memory usage for large-vocabulary models. - Optimized save_state to eliminate redundant memory allocations and copies by using ctypes.string_at. - Updated load_state, eval, and sampler adapters to correctly handle the dynamic shape of self.scores.
1 parent bb8437a commit f9f8669

1 file changed

Lines changed: 81 additions & 26 deletions

File tree

llama_cpp/llama.py

Lines changed: 81 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -525,9 +525,8 @@ def free_lora_adapter():
525525

526526
self.n_tokens = 0
527527
self.input_ids: npt.NDArray[np.intc] = np.ndarray((n_ctx,), dtype=np.intc)
528-
self.scores: npt.NDArray[np.single] = np.ndarray(
529-
(n_ctx if logits_all == True else n_batch, self._n_vocab), dtype=np.single
530-
)
528+
self.scores: npt.NDArray[np.single] = np.ndarray((n_ctx if self._logits_all else 1, self._n_vocab), dtype=np.single)
529+
531530

532531
self._mirostat_mu = ctypes.c_float(
533532
2.0 * 5.0
@@ -638,7 +637,10 @@ def _input_ids(self) -> npt.NDArray[np.intc]:
638637

639638
@property
640639
def _scores(self) -> npt.NDArray[np.single]:
641-
return self.scores[: self.n_tokens, :]
640+
if self._logits_all:
641+
return self.scores[: self.n_tokens, :]
642+
else:
643+
return self.scores
642644

643645
@property
644646
def eval_tokens(self) -> Deque[int]:
@@ -747,14 +749,17 @@ def eval(self, tokens: Sequence[int]):
747749
) from e
748750
# Save tokens
749751
self.input_ids[n_past : n_past + n_batch_tokens] = batch
752+
750753
# Save logits
754+
logits_ptr = self._ctx.get_logits()
751755
if self._logits_all:
752756
rows = n_batch_tokens
753757
cols = self._n_vocab
754-
logits = np.ctypeslib.as_array(
755-
self._ctx.get_logits(), shape=(rows * cols,)
756-
)
757-
self.scores[n_past : n_past + n_batch_tokens, :].reshape(-1)[::] = logits
758+
logits_view = np.ctypeslib.as_array(logits_ptr, shape=(rows * cols,))
759+
self.scores[n_past : n_past + n_batch_tokens, :].reshape(-1)[:] = logits_view
760+
else:
761+
logits_view = np.ctypeslib.as_array(logits_ptr, shape=(self._n_vocab,))
762+
self.scores[0, :] = logits_view
758763

759764
# Update n_tokens
760765
current_pos += n_batch_tokens
@@ -875,7 +880,10 @@ def sample(
875880
# LogitsProcessor Adapter
876881
if logits_processor:
877882
def adapter(token_data_array: llama_cpp.llama_token_data_array):
878-
current_scores = self._scores[self.n_tokens - 1, :]
883+
if self._logits_all:
884+
current_scores = self._scores[self.n_tokens - 1, :]
885+
else:
886+
current_scores = self._scores[0, :]
879887
new_scores = logits_processor(self._input_ids, current_scores)
880888
size = token_data_array.size
881889
data_ptr = token_data_array.data
@@ -1003,7 +1011,10 @@ def generate(
10031011

10041012
if logits_processor:
10051013
def adapter(token_data_array: llama_cpp.llama_token_data_array):
1006-
current_scores = self._scores[self.n_tokens - 1, :]
1014+
if self._logits_all:
1015+
current_scores = self._scores[self.n_tokens - 1, :]
1016+
else:
1017+
current_scores = self._scores[0, :]
10071018
new_scores = logits_processor(self._input_ids, current_scores)
10081019

10091020
size = token_data_array.size
@@ -1050,10 +1061,22 @@ def adapter(token_data_array: llama_cpp.llama_token_data_array):
10501061
self._sampling_ctx.accept(token, False if grammar is None else True)
10511062

10521063
sample_idx += 1
1053-
if stopping_criteria is not None and stopping_criteria(
1054-
self._input_ids[: sample_idx], self._scores[sample_idx - self.n_tokens, :]
1055-
):
1056-
return
1064+
if stopping_criteria is not None:
1065+
if self._logits_all:
1066+
logits_idx = sample_idx - self.n_tokens
1067+
check_stopping = True
1068+
else:
1069+
if sample_idx == self.n_tokens:
1070+
logits_idx = 0
1071+
check_stopping = True
1072+
else:
1073+
check_stopping = False
1074+
1075+
if check_stopping and stopping_criteria(
1076+
self._input_ids[: sample_idx],
1077+
self._scores[logits_idx, :]
1078+
):
1079+
return
10571080
tokens_or_none = yield token
10581081
tokens.clear()
10591082
tokens.append(token)
@@ -1556,7 +1579,10 @@ def _create_completion(
15561579
).decode("utf-8", errors="ignore")
15571580
)
15581581
token_offset = len(prompt_tokens) + returned_tokens
1559-
logits = self._scores[token_offset - 1, :]
1582+
if self._logits_all:
1583+
logits = self._scores[token_offset - 1, :]
1584+
else:
1585+
logits = self._scores[0, :]
15601586
current_logprobs = Llama.logits_to_logprobs(logits).tolist()
15611587
sorted_logprobs = list(
15621588
sorted(
@@ -1695,7 +1721,10 @@ def _create_completion(
16951721
)
16961722
)
16971723
token_offset = len(prompt_tokens) + returned_tokens - 1
1698-
logits = self._scores[token_offset, :]
1724+
if self._logits_all:
1725+
logits = self._scores[token_offset, :]
1726+
else:
1727+
logits = self._scores[0, :]
16991728
current_logprobs = Llama.logits_to_logprobs(logits).tolist()
17001729
sorted_logprobs = list(
17011730
sorted(
@@ -2406,46 +2435,72 @@ def __setstate__(self, state):
24062435
def save_state(self) -> LlamaState:
24072436
if self.verbose:
24082437
print("Llama.save_state: saving llama state", file=sys.stderr)
2409-
state_size = llama_cpp.llama_get_state_size(self._ctx.ctx)
2438+
2439+
# Query the backend for the required buffer size to store the current state.
2440+
state_size = llama_cpp.llama_state_get_size(self._ctx.ctx)
24102441
if self.verbose:
24112442
print(f"Llama.save_state: got state size: {state_size}", file=sys.stderr)
2443+
2444+
# Allocate a ctypes uint8 array (buffer) of the required size.
24122445
llama_state = (ctypes.c_uint8 * int(state_size))()
24132446
if self.verbose:
24142447
print("Llama.save_state: allocated state", file=sys.stderr)
2415-
n_bytes = llama_cpp.llama_copy_state_data(self._ctx.ctx, llama_state)
2448+
2449+
# Copy the raw state data from the internal C context into our Python-managed buffer.
2450+
# Returns the actual number of bytes written (n_bytes).
2451+
n_bytes = llama_cpp.llama_state_get_data(self._ctx.ctx, llama_state, state_size)
24162452
if self.verbose:
24172453
print(f"Llama.save_state: copied llama state: {n_bytes}", file=sys.stderr)
2454+
2455+
# Safety check to prevent buffer overflow issues.
24182456
if int(n_bytes) > int(state_size):
24192457
raise RuntimeError("Failed to copy llama state data")
2420-
llama_state_compact = (ctypes.c_uint8 * int(n_bytes))()
2421-
llama_cpp.ctypes.memmove(llama_state_compact, llama_state, int(n_bytes))
2458+
2459+
# Directly read 'n_bytes' from the buffer's memory address to create the Python bytes object.
2460+
# This significantly reduces memory overhead by avoiding an intermediate array allocation.
2461+
llama_state_bytes = ctypes.string_at(ctypes.addressof(llama_state), int(n_bytes))
24222462
if self.verbose:
24232463
print(
24242464
f"Llama.save_state: saving {n_bytes} bytes of llama state",
24252465
file=sys.stderr,
24262466
)
2467+
2468+
# Create and return the snapshot object.
24272469
return LlamaState(
24282470
scores=self._scores.copy(),
24292471
input_ids=self.input_ids.copy(),
24302472
n_tokens=self.n_tokens,
2431-
llama_state=bytes(llama_state_compact),
2473+
llama_state=llama_state_bytes,
24322474
llama_state_size=n_bytes,
24332475
seed=self._seed,
24342476
)
24352477

24362478
def load_state(self, state: LlamaState) -> None:
2437-
# Only filling in up to `n_tokens` and then zero-ing out the rest
2438-
self.scores[: state.n_tokens, :] = state.scores.copy()
2439-
rest = self.scores[state.n_tokens :, :]
2440-
rest[rest > 0] = 0.0
2479+
# Restore metadata: input tokens, token count, and RNG seed.
24412480
self.input_ids = state.input_ids.copy()
24422481
self.n_tokens = state.n_tokens
24432482
self._seed = state.seed
2483+
# Restore logits (scores), handling the different memory configurations.
2484+
if self._logits_all:
2485+
# Case A: Full history mode. Restore as many rows as possible.
2486+
available_rows = state.scores.shape[0]
2487+
# Prevent index out of bounds by taking the minimum valid length.
2488+
limit = min(self.n_tokens, available_rows)
2489+
# Restore valid history and clear any remaining "future" slots.
2490+
self.scores[:limit, :] = state.scores[:limit, :]
2491+
self.scores[limit:, :] = 0.0
2492+
else:
2493+
# Case B: Optimized mode (1-row buffer).
2494+
# Only restore the last token's logits if available.
2495+
if state.scores.shape[0] > 0:
2496+
self.scores[0, :] = state.scores[-1, :]
2497+
24442498
state_size = state.llama_state_size
24452499
LLamaStateArrayType = ctypes.c_uint8 * state_size
2500+
# Copy the raw bytes from the Python object into a C-compatible buffer.
24462501
llama_state = LLamaStateArrayType.from_buffer_copy(state.llama_state)
24472502

2448-
if llama_cpp.llama_set_state_data(self._ctx.ctx, llama_state) != state_size:
2503+
if llama_cpp.llama_state_set_data(self._ctx.ctx, llama_state, state_size) != state_size:
24492504
raise RuntimeError("Failed to set llama state data")
24502505

24512506
def n_ctx(self) -> int:

0 commit comments

Comments
 (0)