Skip to content

Commit 78fa55b

Browse files
committed
feat(speculative): upgrade ngram map decoder with k/k4v modes
Enhance `LlamaNGramMapDecoding` to align with the upstream llama.cpp ngram-map algorithm, offering better memory management and draft quality.

- Introduce `mode` selection ("k" and "k4v"): "k" stores only historical positions for memory efficiency, while "k4v" caches continuation values directly for faster lookups.
- Add `min_hits` threshold to filter out low-confidence drafts.
- Implement `max_entries_per_key` to cap dictionary growth and prevent memory bloat during long-context generations.
- Improve state synchronization (`_sync_and_index`) using `sync_check_tokens` to safely verify incremental history appends.
- Add explicit lifecycle management methods (`clear`, `close`, `accept`) for better API symmetry and resource cleanup.

Signed-off-by: JamePeng <jame_peng@sina.com>
1 parent b2f09bb commit 78fa55b

1 file changed

Lines changed: 252 additions & 61 deletions

File tree

llama_cpp/llama_speculative.py

Lines changed: 252 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import abc
22
import collections
33

4-
from typing import Any, Dict, List, Tuple
4+
from typing import Any, DefaultDict, Dict, List, Literal, Optional, Tuple
55

66
import numpy as np
77
import numpy.typing as npt
@@ -17,102 +17,293 @@ def __call__(
1717

1818
class LlamaNGramMapDecoding(LlamaDraftModel):
    """
    Fast model-free speculative decoder based on prompt n-gram lookup.

    It supports two modes:

    - "k":
        Key-only mode. Stores n-gram key -> history positions.
        This is memory-efficient and similar to llama.cpp's ngram-map-k behavior.

    - "k4v":
        Key-to-value mode. Stores n-gram key -> continuation tokens.
        This uses more memory, but can return cached continuations directly.

    This class does not use a draft model. It only speculates from already verified
    token history. Therefore, rejected tokens are handled naturally when the next
    `input_ids` is passed in.

    Aligned with llama.cpp's underlying ngram-map k/k4v algorithm.
    """

    def __init__(
        self,
        ngram_size: int = 3,
        num_pred_tokens: int = 10,
        mode: Literal["k", "k4v"] = "k",
        min_hits: int = 2,
        max_entries_per_key: Optional[int] = None,
        sync_check_tokens: int = 16,
    ) -> None:
        """
        Args:
            ngram_size:
                Number of tokens used as the lookup key.

            num_pred_tokens:
                Maximum number of draft tokens to return.

            mode:
                "k" stores only matched positions.
                "k4v" stores matched continuation values directly.
                Matching is case-insensitive.

            min_hits:
                Minimum number of historical matches required before returning a draft.
                Use 1 for maximum recall. Use >1 to reduce low-confidence drafts.
                Because a key never holds more than `max_entries_per_key` entries,
                a `min_hits` larger than that cap means no draft is ever returned.

            max_entries_per_key:
                Optional memory cap per n-gram key.
                When set, only the most recent entries are kept.
                In "k4v" mode a value of None is replaced by a default cap of 8.

            sync_check_tokens:
                Number of trailing tokens used to verify whether the new input is an
                incremental append of the previous input. This avoids expensive full
                prefix comparison while still detecting most rollback/prompt-switch cases.

        Raises:
            ValueError: If any numeric argument is not greater than 0 after integer
                conversion, or if `mode` is not "k" or "k4v".
        """
        # Coerce first, validate second: validating the raw value would let
        # e.g. 0.5 pass the "> 0" check and then be truncated to 0.
        ngram_size = int(ngram_size)
        num_pred_tokens = int(num_pred_tokens)
        min_hits = int(min_hits)
        sync_check_tokens = int(sync_check_tokens)
        if max_entries_per_key is not None:
            # The cap is used in slice arithmetic later, so it must be an int.
            max_entries_per_key = int(max_entries_per_key)

        if ngram_size <= 0:
            raise ValueError("ngram_size must be greater than 0")
        if num_pred_tokens <= 0:
            raise ValueError("num_pred_tokens must be greater than 0")
        if min_hits <= 0:
            raise ValueError("min_hits must be greater than 0")
        if max_entries_per_key is not None and max_entries_per_key <= 0:
            raise ValueError("max_entries_per_key must be None or greater than 0")
        if sync_check_tokens <= 0:
            raise ValueError("sync_check_tokens must be greater than 0")

        # A non-string mode is reported as the same ValueError as an unknown
        # mode instead of leaking an AttributeError from `.lower()`.
        if not isinstance(mode, str):
            raise ValueError("mode must be either 'k' or 'k4v'")
        mode = mode.lower()
        if mode not in ("k", "k4v"):
            raise ValueError("mode must be either 'k' or 'k4v'")

        self.ngram_size = ngram_size
        self.num_pred_tokens = num_pred_tokens
        self.mode = mode
        self.min_hits = min_hits
        self.sync_check_tokens = sync_check_tokens

        # k4v stores a continuation tuple per entry, so it is always capped.
        if mode == "k4v" and max_entries_per_key is None:
            max_entries_per_key = 8
        self.max_entries_per_key = max_entries_per_key

        # Verified token history, mirrored from the most recent `input_ids`.
        self._history: List[int] = []

        # In "k" mode:
        #   key -> [position, position, ...]
        self._map_k: DefaultDict[Tuple[int, ...], List[int]] = collections.defaultdict(list)

        # In "k4v" mode:
        #   key -> {position: continuation}
        #
        # A dict is used so that recent entries can be refreshed when more continuation
        # tokens become available.
        self._map_k4v: DefaultDict[
            Tuple[int, ...], Dict[int, Tuple[int, ...]]
        ] = collections.defaultdict(dict)

        self._closed = False
        self._last_draft_len = 0

    def clear(self) -> None:
        """
        Clear token history and indexes.

        Use this when starting a completely unrelated generation while keeping the
        decoder instance reusable. This does not reopen an instance that was
        already closed with `close()`.
        """
        self._history.clear()
        self._map_k.clear()
        self._map_k4v.clear()
        self._last_draft_len = 0

    def close(self) -> None:
        """
        Release internal memory and mark the decoder as closed.

        This class does not own native memory, but clearing large Python containers
        explicitly is still useful for long-running applications. After closing,
        calling the decoder raises RuntimeError.
        """
        self.clear()
        self._closed = True

    def __del__(self) -> None:
        # Best-effort cleanup. Program correctness must not depend on __del__.
        # The broad except is deliberate: if __init__ raised before the
        # containers were created, clear() fails with AttributeError here.
        try:
            self.close()
        except Exception:
            pass

    def accept(self, n_accepted: int) -> None:
        """
        Notify how many draft tokens were accepted by the target model.

        This implementation does not need to update internal state here, because the
        next call receives the verified token history through `input_ids`.

        The method is kept for API symmetry and future extensions, such as acceptance
        statistics, adaptive reset, or low-acceptance fallback.
        """
        return

    def _sync_and_index(self, input_ids: npt.NDArray[np.intc]) -> None:
        """
        Synchronize internal history with input_ids and update the n-gram index.

        The index intentionally stores only n-grams that have at least one continuation
        token. This prevents the current tail n-gram from matching itself and returning
        an empty draft.

        Raises:
            RuntimeError: If the decoder has been closed.
        """
        if self._closed:
            raise RuntimeError("LlamaNGramMapDecoding is closed")

        tokens = np.asarray(input_ids, dtype=np.intc).reshape(-1).tolist()

        old_len = len(self._history)
        new_len = len(tokens)

        if new_len == 0:
            self.clear()
            return

        # Fast path: identical input, no update needed.
        if new_len == old_len and self._history == tokens:
            return

        # Incremental append path: only the trailing window of the old history
        # is compared, not the whole prefix.
        is_append = False
        if old_len > 0 and new_len > old_len:
            check_len = min(old_len, max(self.ngram_size, self.sync_check_tokens))
            window = slice(old_len - check_len, old_len)
            is_append = self._history[window] == tokens[window]

        if is_append:
            # Append only new tokens.
            self._history.extend(tokens[old_len:])

            if self.mode == "k":
                # Only newly-valid keys need to be added.
                start = max(0, old_len - self.ngram_size)
            else:
                # K4V must also refresh recent keys because their continuation values
                # can grow as new tokens are appended.
                start = max(0, old_len - self.ngram_size - self.num_pred_tokens + 1)
        else:
            # Rollback, prompt switch, truncation, or unsafe mutation.
            self.clear()
            self._history.extend(tokens)
            start = 0

        history = self._history
        ngram_size = self.ngram_size
        cap = self.max_entries_per_key
        total = len(history)

        # Only index keys that have at least one token after the key.
        # Valid pos satisfies:
        #   pos + ngram_size < len(history)
        end = max(0, total - ngram_size)

        if start >= end:
            return

        if self.mode == "k":
            for pos in range(start, end):
                key = tuple(history[pos : pos + ngram_size])
                bucket = self._map_k[key]

                # Guard against re-adding a position that is already the newest.
                if not bucket or bucket[-1] != pos:
                    bucket.append(pos)

                if cap is not None and len(bucket) > cap:
                    # Keep the most recent positions.
                    del bucket[: len(bucket) - cap]
            return

        num_pred_tokens = self.num_pred_tokens
        for pos in range(start, end):
            value_start = pos + ngram_size
            value_end = min(value_start + num_pred_tokens, total)

            if value_start >= value_end:
                continue

            key = tuple(history[pos:value_start])
            entries = self._map_k4v[key]
            # Overwrites any shorter continuation stored earlier for this position.
            entries[pos] = tuple(history[value_start:value_end])

            if cap is not None and len(entries) > cap:
                # Keep the most recent positions.
                for old_pos in sorted(entries)[: len(entries) - cap]:
                    del entries[old_pos]

    def __call__(
        self, input_ids: npt.NDArray[np.intc], /, **kwargs: Any
    ) -> npt.NDArray[np.intc]:
        """
        Generate draft tokens from verified token history.

        Args:
            input_ids:
                Complete verified token sequence so far.
            **kwargs:
                Accepted for interface compatibility and ignored.

        Returns:
            np.ndarray[np.intc]:
                Predicted draft tokens. Empty array means no reliable match was found.

        Raises:
            RuntimeError: If the decoder has been closed.
        """
        self._sync_and_index(input_ids)
        self._last_draft_len = 0

        history = self._history
        ngram_size = self.ngram_size

        if len(history) < ngram_size:
            return np.array([], dtype=np.intc)

        search_key = tuple(history[-ngram_size:])

        draft: List[int] = []
        if self.mode == "k":
            # `.get` is used so that a miss does not create an empty bucket.
            positions = self._map_k.get(search_key)
            if not positions or len(positions) < self.min_hits:
                return np.array([], dtype=np.intc)

            # Use the latest valid match with an available continuation.
            for pos in reversed(positions):
                start = pos + ngram_size
                if start < len(history):
                    end = min(start + self.num_pred_tokens, len(history))
                    draft = history[start:end]
                    break
        else:
            values = self._map_k4v.get(search_key)
            if not values or len(values) < self.min_hits:
                return np.array([], dtype=np.intc)

            # Use the continuation from the latest historical position.
            draft = list(values[max(values)])

        self._last_draft_len = len(draft)
        return np.asarray(draft, dtype=np.intc)
113302

114303

115304
# Legacy Numpy sliding window implementation
305+
# Fast in some cases, but may degrade output quality.
306+
# Not recommended for production.
116307
class LlamaPromptLookupDecoding(LlamaDraftModel):
117308
"""
118309
Stateless speculative decoding based on Numpy sliding window

0 commit comments

Comments
 (0)