Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
316 changes: 190 additions & 126 deletions src/server/prompt_cache/trie.rs
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ impl RadixTrie {
/// present (callers must therefore remember the insertion tokens).
/// Returns `true` if the digest was found and removed.
pub(super) fn remove(&mut self, tokens: &[i32], digest: PromptCacheKeyDigest) -> bool {
remove_from(&mut self.root, tokens, digest, /*is_root=*/ true)
remove_from(&mut self.root, tokens, digest)
}

/// Find the longest token-prefix of `tokens` that reaches a position
Expand Down Expand Up @@ -404,12 +404,20 @@ impl<'a> TrieMatch<'a> {
}
}

// Iterative subtree walk. A recursive DFS would overflow Tokio's ~2 MiB worker
// stacks on an adversarially deep prompt trie (one node per token; see the
// `pop_prefixes` and `Drop` notes), and this runs on the per-request lookup hot
// path via `for_each_candidate`. Use an explicit heap stack instead. Traversal
// order is unspecified, matching the documented contract on `for_each_candidate`.
fn dfs<F: FnMut(DigestAndLen)>(node: &TrieNode, visit: &mut F) {
for d in &node.entries {
visit(*d);
}
for child in node.children.values() {
dfs(child, visit);
let mut stack: Vec<&TrieNode> = vec![node];
while let Some(n) = stack.pop() {
for d in &n.entries {
visit(*d);
}
for child in n.children.values() {
stack.push(child);
}
}
}

Expand Down Expand Up @@ -470,143 +478,199 @@ fn walk_longest<'a>(root: &'a TrieNode, tokens: &[i32]) -> Option<TrieMatch<'a>>
})
}

fn insert_into(node: &mut TrieNode, tokens: &[i32], record: DigestAndLen) {
if tokens.is_empty() {
if !node.entries.iter().any(|e| e.digest == record.digest) {
node.entries.push(record);
node.subtree_count += 1;
} else {
// Replace in place to keep token_len accurate after token
// updates (not typically expected, but safe).
for e in &mut node.entries {
if e.digest == record.digest {
*e = record;
}
// Iterative insert. The former recursion descended one frame per trie node and
// overflowed Tokio's ~2 MiB worker stacks on an adversarially deep prompt trie
// (one node per token), and this runs on the per-request cache-store hot path
// via `PromptCacheStore::insert`. Instead, descend iteratively while recording
// the visited path, perform the single structural mutation (append / split /
// fresh leaf) at the terminal node, then repair `subtree_count` bottom-up over
// the recorded path — exactly mirroring the post-order unwinding of the old
// recursion. Stack usage is O(1); the path lives on the heap.
fn insert_into(root: &mut TrieNode, tokens: &[i32], record: DigestAndLen) {
let mut path: Vec<*mut TrieNode> = Vec::new();
let mut node: *mut TrieNode = root;
let mut tokens: &[i32] = tokens;

loop {
path.push(node);
// SAFETY: `node` points to a valid TrieNode owned by `root`'s tree,
// reachable via the path we just walked. Only one mutable reference is
// live at a time; we never alias `node_ref` with another live borrow.
let node_ref = unsafe { &mut *node };

if tokens.is_empty() {
// Terminal: store the digest here, replacing in place to keep
// token_len accurate if the same digest is re-inserted.
if let Some(existing) = node_ref
.entries
.iter_mut()
.find(|e| e.digest == record.digest)
{
*existing = record;
} else {
node_ref.entries.push(record);
}
break;
}
return;
}

let first_tok = tokens[0];
if let Some(child) = node.children.get_mut(&first_tok) {
let edge_clone = child.edge.clone();
let common = common_prefix_len(&edge_clone, tokens);
if common == edge_clone.len() {
// Full edge consumed: recurse into child.
insert_into(child, &tokens[common..], record);
node.subtree_count = recompute_subtree_count(node);
return;
}
// Edge diverges mid-way: split.
// Create a new intermediate node that keeps the first `common` tokens
// of the old edge as its own incoming edge. Move the old child under
// the new intermediate with its edge truncated to what's left.
let shared: Vec<i32> = edge_clone[..common].to_vec();
let old_remainder: Vec<i32> = edge_clone[common..].to_vec();

// Rebuild the existing child with its shorter edge.
let mut old_child = node.children.remove(&first_tok).expect("child just seen");
let old_first_after_split = old_remainder[0];
old_child.edge = old_remainder;

// New intermediate holds no entries yet.
let mut intermediate = TrieNode {
edge: shared,
entries: Vec::new(),
subtree_count: 0,
children: HashMap::new(),
};
intermediate
.children
.insert(old_first_after_split, old_child);

if tokens.len() == common {
// New record ends exactly at the split point.
intermediate.entries.push(record);
} else {
// New record continues past the split; add a fresh child for
// the remaining tokens.
let new_remainder: Vec<i32> = tokens[common..].to_vec();
let new_first = new_remainder[0];
let new_child = TrieNode {
edge: new_remainder,
entries: vec![record],
subtree_count: 1,
let first_tok = tokens[0];
if let Some(child) = node_ref.children.get_mut(&first_tok) {
let edge_clone = child.edge.clone();
let common = common_prefix_len(&edge_clone, tokens);
if common == edge_clone.len() {
// Full edge consumed: descend into the child.
tokens = &tokens[common..];
node = child.as_mut() as *mut TrieNode;
continue;
}
// Edge diverges mid-way: split (terminal — no further descent).
// Create a new intermediate node that keeps the first `common`
// tokens of the old edge as its own incoming edge. Move the old
// child under the new intermediate with its edge truncated to
// what's left.
let shared: Vec<i32> = edge_clone[..common].to_vec();
let old_remainder: Vec<i32> = edge_clone[common..].to_vec();

// Rebuild the existing child with its shorter edge.
let mut old_child = node_ref
.children
.remove(&first_tok)
.expect("child just seen");
let old_first_after_split = old_remainder[0];
old_child.edge = old_remainder;

// New intermediate holds no entries yet.
let mut intermediate = TrieNode {
edge: shared,
entries: Vec::new(),
subtree_count: 0,
children: HashMap::new(),
};
intermediate.children.insert(new_first, Box::new(new_child));
intermediate
.children
.insert(old_first_after_split, old_child);

if tokens.len() == common {
// New record ends exactly at the split point.
intermediate.entries.push(record);
} else {
// New record continues past the split; add a fresh child for
// the remaining tokens.
let new_remainder: Vec<i32> = tokens[common..].to_vec();
let new_first = new_remainder[0];
let new_child = TrieNode {
edge: new_remainder,
entries: vec![record],
subtree_count: 1,
children: HashMap::new(),
};
intermediate.children.insert(new_first, Box::new(new_child));
}
// The intermediate (and any fresh child) get correct counts here;
// ancestors on `path` are repaired in the bottom-up pass below.
intermediate.subtree_count = recompute_subtree_count(&intermediate);
node_ref.children.insert(first_tok, Box::new(intermediate));
break;
}
intermediate.subtree_count = recompute_subtree_count(&intermediate);
node.children.insert(first_tok, Box::new(intermediate));
node.subtree_count = recompute_subtree_count(node);
return;

// No child for this first token yet: create a fresh leaf with the full
// remaining tokens as its edge.
let leaf = TrieNode {
edge: tokens.to_vec(),
entries: vec![record],
subtree_count: 1,
children: HashMap::new(),
};
node_ref.children.insert(first_tok, Box::new(leaf));
break;
}

// No child for this first token yet: create a fresh leaf with the full
// remaining tokens as its edge.
let edge: Vec<i32> = tokens.to_vec();
let leaf = TrieNode {
edge,
entries: vec![record],
subtree_count: 1,
children: HashMap::new(),
};
node.children.insert(first_tok, Box::new(leaf));
node.subtree_count = recompute_subtree_count(node);
// Repair `subtree_count` bottom-up along the recorded path. Each node's
// children counts are already correct by the time we reach it (deepest
// first), so a local recompute restores the invariant up to the root.
for &n in path.iter().rev() {
// SAFETY: each recorded path node is still live and owned by `root`'s
// tree; the pointers are distinct and visited one at a time.
let n_ref = unsafe { &mut *n };
n_ref.subtree_count = recompute_subtree_count(n_ref);
}
}

fn remove_from(
node: &mut TrieNode,
tokens: &[i32],
digest: PromptCacheKeyDigest,
is_root: bool,
) -> bool {
if tokens.is_empty() {
let before = node.entries.len();
node.entries.retain(|e| e.digest != digest);
let removed = node.entries.len() != before;
if removed {
node.subtree_count = recompute_subtree_count(node);
}
return removed;
}
// Iterative remove. Like `insert_into`, the former recursion descended one frame
// per trie node and could overflow a ~2 MiB Tokio worker stack on an
// adversarially deep prompt trie (this runs on the cache-eviction path via
// `PromptCacheStore`). Descend iteratively, recording `(node, descend-key)` for
// each ancestor; on the way back up prune any child that became an empty leaf,
// merge a now-empty single-child non-root node, and repair `subtree_count` —
// mirroring the post-order unwinding of the old recursion. Stack usage is O(1).
fn remove_from(root: &mut TrieNode, tokens: &[i32], digest: PromptCacheKeyDigest) -> bool {
// Capture the root identity, then navigate exclusively through raw pointers
// below; the `root` reference itself is not touched again, so nothing aliases
// the `&mut` we reborrow from each pointer as we walk.
let root_ptr: *const TrieNode = root;
// `path[i] = (ancestor, key)` where `key` is the child token we descended
// through from `ancestor`. The terminal node (where the digest lives) is
// not recorded here; its `entries`/`subtree_count` are handled inline below.
let mut path: Vec<(*mut TrieNode, i32)> = Vec::new();
let mut node: *mut TrieNode = root;
let mut tokens: &[i32] = tokens;

let first_tok = tokens[0];
let mut take_child = false;
let removed = if let Some(child) = node.children.get_mut(&first_tok) {
let edge_len = child.edge.len();
if tokens.len() < edge_len || tokens[..edge_len] != child.edge[..] {
false
} else {
let rest = &tokens[edge_len..];
let removed = remove_from(child, rest, digest, false);
if removed && child.entries.is_empty() && child.children.is_empty() {
// Leaf with no entries — drop it.
take_child = true;
loop {
// SAFETY: `node` points to a live TrieNode owned by `root`'s tree along
// the validated path; only one mutable reference is live at a time.
let node_ref = unsafe { &mut *node };
if tokens.is_empty() {
let before = node_ref.entries.len();
node_ref.entries.retain(|e| e.digest != digest);
if node_ref.entries.len() == before {
// Digest absent — nothing changed anywhere, so no path repair.
return false;
}
removed
node_ref.subtree_count = recompute_subtree_count(node_ref);
break;
}
} else {
false
};

if take_child {
node.children.remove(&first_tok);
}

// After removing, compress the chain: if `node` is non-root, has no
// entries, and has exactly one child, we can merge. But the caller owns
// the current node, so we can only merge children of this node into
// their grandchildren from here. We handle this by opportunistically
// merging each single-child chain reachable from here.
if !is_root && removed {
merge_only_child_if_empty(node);
let first_tok = tokens[0];
let edge_len = match node_ref.children.get(&first_tok) {
Some(child) => {
let elen = child.edge.len();
if tokens.len() < elen || tokens[..elen] != child.edge[..] {
// Edge label doesn't match — digest can't be present.
return false;
}
elen
}
None => return false,
};
path.push((node, first_tok));
tokens = &tokens[edge_len..];
node = node_ref
.children
.get_mut(&first_tok)
.expect("edge validated above")
.as_mut() as *mut TrieNode;
}

if removed {
node.subtree_count = recompute_subtree_count(node);
// Walk back up the recorded ancestors (deepest first). For each: drop the
// child we descended through if it became an empty leaf, then compress this
// node if it is now an empty single-child non-root node, then repair count.
for &(pnode, key) in path.iter().rev() {
// SAFETY: each ancestor is still live and owned by `root`'s tree; the
// pointers are distinct and visited one at a time.
let pref = unsafe { &mut *pnode };
let child_is_empty_leaf = match pref.children.get(&key) {
Some(child) => child.entries.is_empty() && child.children.is_empty(),
None => false,
};
if child_is_empty_leaf {
pref.children.remove(&key);
}
if !std::ptr::eq(pnode as *const TrieNode, root_ptr) {
merge_only_child_if_empty(pref);
}
pref.subtree_count = recompute_subtree_count(pref);
}
removed
true
}

fn merge_only_child_if_empty(node: &mut TrieNode) {
Expand Down
Loading