11import abc
22import collections
33
4- from typing import Any , Dict , List , Tuple
4+ from typing import Any , DefaultDict , Dict , List , Literal , Optional , Tuple
55
66import numpy as np
77import numpy .typing as npt
@@ -17,102 +17,293 @@ def __call__(
1717
1818class LlamaNGramMapDecoding (LlamaDraftModel ):
1919 """
20- Ultra-fast speculative decoder based on hash inverted index and incremental updates.
21- O(1) time complexity, aligned with llama.cpp's underlying ngram-map algorithm.
20+ Fast model-free speculative decoder based on prompt n-gram lookup.
21+
22+ It supports two modes:
23+
24+ - "k":
25+ Key-only mode. Stores n-gram key -> history positions.
26+ This is memory-efficient and similar to llama.cpp's ngram-map-k behavior.
27+
28+ - "k4v":
29+ Key-to-value mode. Stores n-gram key -> continuation tokens.
30+ This uses more memory, but can return cached continuations directly.
31+
32+ This class does not use a draft model. It only speculates from already verified
33+ token history. Therefore, rejected tokens are handled naturally when the next
34+ `input_ids` is passed in.
35+
36+ Aligned with llama.cpp's underlying ngram-map k/k4v algorithm.
2237 """
2338
24- def __init__ (self , ngram_size : int = 3 , num_pred_tokens : int = 10 ):
39+ def __init__ (
40+ self ,
41+ ngram_size : int = 3 ,
42+ num_pred_tokens : int = 10 ,
43+ mode : Literal ["k" , "k4v" ] = "k" ,
44+ min_hits : int = 2 ,
45+ max_entries_per_key : Optional [int ] = None ,
46+ sync_check_tokens : int = 16 ,
47+ ) -> None :
2548 """
26- Initializes the N-Gram Map speculative decoder.
27-
2849 Args:
29- ngram_size (int): The length of the token sequence used as the search key.
30- Larger values provide strictly accurate context matching but may result
31- in fewer cache hits. Defaults to 3.
32- num_pred_tokens (int): The maximum number of future tokens to draft (predict)
33- and return once a match is found in the history. Defaults to 10.
50+ ngram_size:
51+ Number of tokens used as the lookup key.
52+
53+ num_pred_tokens:
54+ Maximum number of draft tokens to return.
55+
56+ mode:
57+ "k" stores only matched positions.
58+ "k4v" stores matched continuation values directly.
59+
60+ min_hits:
61+ Minimum number of historical matches required before returning a draft.
62+ Use 1 for maximum recall. Use >1 to reduce low-confidence drafts.
63+
64+ max_entries_per_key:
65+ Optional memory cap per n-gram key.
66+ When set, only the most recent entries are kept.
67+ For k4v mode, setting max_entries_per_key is strongly recommended.
68+
69+ sync_check_tokens:
70+ Number of trailing tokens used to verify whether the new input is an
71+ incremental append of the previous input. This avoids expensive full
72+ prefix comparison while still detecting most rollback/prompt-switch cases.
3473 """
35- self .ngram_size = ngram_size
36- self .num_pred_tokens = num_pred_tokens
74+ if ngram_size <= 0 :
75+ raise ValueError ("ngram_size must be greater than 0" )
76+ if num_pred_tokens <= 0 :
77+ raise ValueError ("num_pred_tokens must be greater than 0" )
78+ if min_hits <= 0 :
79+ raise ValueError ("min_hits must be greater than 0" )
80+ if max_entries_per_key is not None and max_entries_per_key <= 0 :
81+ raise ValueError ("max_entries_per_key must be None or greater than 0" )
82+ if sync_check_tokens <= 0 :
83+ raise ValueError ("sync_check_tokens must be greater than 0" )
84+
85+ mode = mode .lower ()
86+ if mode not in ("k" , "k4v" ):
87+ raise ValueError ("mode must be either 'k' or 'k4v'" )
88+
89+ self .ngram_size = int (ngram_size )
90+ self .num_pred_tokens = int (num_pred_tokens )
91+ self .mode = mode
92+ self .min_hits = int (min_hits )
93+ self .sync_check_tokens = int (sync_check_tokens )
94+
95+ if mode == "k4v" and max_entries_per_key is None :
96+ max_entries_per_key = 8
97+ self .max_entries_per_key = max_entries_per_key
3798
38- # Core state cache
39- # Mapping format: (token_1, ..., token_N) -> [index_1, index_2, ...]
40- self ._ngram_map : Dict [Tuple [int , ...], List [int ]] = collections .defaultdict (list )
4199 self ._history : List [int ] = []
42100
43- def _update_cache (self , input_ids : npt .NDArray [np .intc ]) -> None :
101+ # In "k" mode:
102+ # key -> [position, position, ...]
103+ self ._map_k : DefaultDict [Tuple [int , ...], List [int ]] = collections .defaultdict (list )
104+
105+ # In "k4v" mode:
106+ # key -> {position: continuation}
107+ #
108+ # A dict is used so that recent entries can be refreshed when more continuation
109+ # tokens become available.
110+ self ._map_k4v : DefaultDict [
111+ Tuple [int , ...], Dict [int , Tuple [int , ...]]
112+ ] = collections .defaultdict (dict )
113+
114+ self ._closed = False
115+ self ._last_draft_len = 0
116+
117+ def clear (self ) -> None :
44118 """
45- Smart state synchronization and incremental build (Extreme O(1) optimization) .
119+ Clear token history and indexes .
46120
47- Args:
48- input_ids (npt.NDArray[np.intc]): The complete sequence of current token IDs
49- generated or processed so far.
121+ Use this when starting a completely unrelated generation while keeping the
122+ decoder instance reusable.
123+ """
124+ self ._history .clear ()
125+ self ._map_k .clear ()
126+ self ._map_k4v .clear ()
127+ self ._last_draft_len = 0
128+
129+ def close (self ) -> None :
130+ """
131+ Release internal memory.
132+
133+ This class does not own native memory, but clearing large Python containers
134+ explicitly is still useful for long-running applications.
135+ """
136+ self .clear ()
137+ self ._closed = True
138+
139+ def __del__ (self ) -> None :
140+ # Best-effort cleanup. Program correctness must not depend on __del__.
141+ try :
142+ self .close ()
143+ except Exception :
144+ pass
145+
146+ def accept (self , n_accepted : int ) -> None :
50147 """
51- new_len = len (input_ids )
148+ Notify how many draft tokens were accepted by the target model.
149+
150+ This implementation does not need to update internal state here, because the
151+ next call receives the verified token history through `input_ids`.
152+
153+ The method is kept for API symmetry and future extensions, such as acceptance
154+ statistics, adaptive reset, or low-acceptance fallback.
155+ """
156+ return
157+
158+ def _sync_and_index (self , input_ids : npt .NDArray [np .intc ]) -> None :
159+ """
160+ Synchronize internal history with input_ids and update the n-gram index.
161+
162+ The index intentionally stores only n-grams that have at least one continuation
163+ token. This prevents the current tail n-gram from matching itself and returning
164+ an empty draft.
165+ """
166+ if self ._closed :
167+ raise RuntimeError ("LlamaNGramMapDecoding is closed" )
168+
169+ tokens = np .asarray (input_ids , dtype = np .intc ).reshape (- 1 ).tolist ()
170+
52171 old_len = len (self ._history )
172+ new_len = len (tokens )
173+
174+ if new_len == 0 :
175+ self .clear ()
176+ return
177+
178+ # Fast path: identical input, no update needed.
179+ if new_len == old_len :
180+ if self ._history == tokens :
181+ return
182+
183+ # Incremental append path.
184+ is_append = False
185+ if old_len > 0 and new_len > old_len :
186+ check_len = min (old_len , max (self .ngram_size , self .sync_check_tokens ))
187+ is_append = self ._history [old_len - check_len : old_len ] == tokens [
188+ old_len - check_len : old_len
189+ ]
190+
191+ if is_append :
192+ # Append only new tokens.
193+ self ._history .extend (tokens [old_len :])
194+
195+ if self .mode == "k" :
196+ # Only newly-valid keys need to be added.
197+ start = max (0 , old_len - self .ngram_size )
198+ else :
199+ # K4V must also refresh recent keys because their continuation values
200+ # can grow as new tokens are appended.
201+ start = max (0 , old_len - self .ngram_size - self .num_pred_tokens + 1 )
202+ else :
203+ # Rollback, prompt switch, truncation, or unsafe mutation.
204+ self .clear ()
205+ self ._history .extend (tokens )
206+ start = 0
207+
208+ # Only index keys that have at least one token after the key.
209+ # Valid pos satisfies:
210+ # pos + ngram_size < len(history)
211+ end = max (0 , len (self ._history ) - self .ngram_size )
212+
213+ if start >= end :
214+ return
215+
216+ if self .mode == "k" :
217+ for pos in range (start , end ):
218+ key = tuple (self ._history [pos : pos + self .ngram_size ])
219+ bucket = self ._map_k [key ]
220+
221+ if not bucket or bucket [- 1 ] != pos :
222+ bucket .append (pos )
223+
224+ if (
225+ self .max_entries_per_key is not None
226+ and len (bucket ) > self .max_entries_per_key
227+ ):
228+ del bucket [: len (bucket ) - self .max_entries_per_key ]
53229
54- # Check if it's a perfect incremental append (verify if the previous token matches)
55- is_incremental = False
56- if new_len > old_len and old_len > 0 :
57- if self ._history [- 1 ] == input_ids [old_len - 1 ]:
58- is_incremental = True
59-
60- if is_incremental :
61- # Only extract, convert, and append new tokens.
62- # Never copy or touch the entire historical array!
63- new_tokens = input_ids [old_len :].tolist ()
64- self ._history .extend (new_tokens )
65- start_idx = max (0 , old_len - self .ngram_size )
66230 else :
67- # Rollback occurred (wrong prediction) or a completely new Prompt. Trigger full rebuild.
68- self ._ngram_map .clear ()
69- self ._history = input_ids .tolist ()
70- start_idx = 0
231+ for pos in range (start , end ):
232+ key_start = pos
233+ value_start = pos + self .ngram_size
234+ value_end = min (value_start + self .num_pred_tokens , len (self ._history ))
235+
236+ if value_start >= value_end :
237+ continue
238+
239+ key = tuple (self ._history [key_start :value_start ])
240+ value = tuple (self ._history [value_start :value_end ])
71241
72- # Build/update the hash inverted index
73- for i in range (start_idx , new_len - self .ngram_size ):
74- key = tuple (self ._history [i : i + self .ngram_size ])
75- self ._ngram_map [key ].append (i )
242+ bucket = self ._map_k4v [key ]
243+ bucket [pos ] = value
244+
245+ if (
246+ self .max_entries_per_key is not None
247+ and len (bucket ) > self .max_entries_per_key
248+ ):
249+ # Keep the most recent positions.
250+ for old_pos in sorted (bucket )[: len (bucket ) - self .max_entries_per_key ]:
251+ del bucket [old_pos ]
76252
77253 def __call__ (
78254 self , input_ids : npt .NDArray [np .intc ], / , ** kwargs : Any
79255 ) -> npt .NDArray [np .intc ]:
80256 """
81- Generates draft tokens based on historical N-Gram frequency .
257+ Generate draft tokens from verified token history .
82258
83259 Args:
84- input_ids (npt.NDArray[np.intc]): The current sequence of token IDs.
85- **kwargs: Additional generation arguments (ignored in this implementation) .
260+ input_ids:
261+ Complete verified token sequence so far .
86262
87263 Returns:
88- npt.NDArray [np.intc]: An array of predicted draft tokens. Returns an empty
89- array if no matching context is found.
264+ np.ndarray [np.intc]:
265+ Predicted draft tokens. Empty array means no reliable match was found.
90266 """
91- # 1. Ultra-fast state synchronization
92- self ._update_cache (input_ids )
267+ _ = kwargs
268+
269+ self ._sync_and_index (input_ids )
270+ self ._last_draft_len = 0
93271
94- # 2. Cannot speculate if the history is too short
95272 if len (self ._history ) < self .ngram_size :
96273 return np .array ([], dtype = np .intc )
97274
98- # 3. Extract the Search Key (the last N tokens)
99- search_key = tuple (self ._history [- self .ngram_size :])
275+ search_key = tuple (self ._history [- self .ngram_size :])
100276
101- # 4. O(1) instant lookup
102- match_indices = self ._ngram_map .get (search_key )
277+ if self .mode == "k" :
278+ positions = self ._map_k .get (search_key )
279+ if not positions or len (positions ) < self .min_hits :
280+ return np .array ([], dtype = np .intc )
103281
104- if not match_indices :
105- return np .array ([], dtype = np .intc )
282+ # Use the latest valid match with an available continuation.
283+ draft : List [int ] = []
284+ for pos in reversed (positions ):
285+ start = pos + self .ngram_size
286+ if start < len (self ._history ):
287+ end = min (start + self .num_pred_tokens , len (self ._history ))
288+ draft = self ._history [start :end ]
289+ break
290+
291+ else :
292+ values = self ._map_k4v .get (search_key )
293+ if not values or len (values ) < self .min_hits :
294+ return np .array ([], dtype = np .intc )
106295
107- # 5. Get the context of the last match and extract draft tokens
108- best_match_idx = match_indices [- 1 ]
109- draft_start = best_match_idx + self .ngram_size
110- draft_end = min (draft_start + self .num_pred_tokens , len (self ._history ))
296+ # Use the continuation from the latest historical position.
297+ latest_pos = max (values )
298+ draft = list (values [latest_pos ])
111299
112- return np .array (self ._history [draft_start :draft_end ], dtype = np .intc )
300+ self ._last_draft_len = len (draft )
301+ return np .asarray (draft , dtype = np .intc )
113302
114303
115304# Legacy Numpy sliding window implementation
305+ # Fast in some cases, but may degrade output quality.
306+ # Not recommended for production.
116307class LlamaPromptLookupDecoding (LlamaDraftModel ):
117308 """
118309 Stateless speculative decoding based on Numpy sliding window
0 commit comments