Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 76 additions & 6 deletions fastdeploy/cache_manager/prefix_cache_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -816,8 +816,8 @@ def request_match_blocks(self, task: Request, block_size, *args):
# 2. prepare cpu cache: allocate gpu cache for matched cpu blocks, wait for data transfer to complete
gpu_recv_block_ids = []
match_cpu_blocks_num = len(match_cpu_block_ids)
if self.can_allocate_gpu_blocks(num_blocks=match_cpu_blocks_num):
if match_cpu_blocks_num > 0:
if match_cpu_blocks_num > 0:
if self.can_allocate_gpu_blocks(num_blocks=match_cpu_blocks_num):
logger.debug(
f"request_match_blocks: req_id {req_id}, allocate {match_cpu_blocks_num} block to receive cpu cache"
)
Expand All @@ -833,10 +833,10 @@ def request_match_blocks(self, task: Request, block_size, *args):
)
cost_time = time.time() - start_time
metrics["cpu_cache_prepare_time"] = cost_time
else:
raise Exception(
"request_match_blocks: Not enough GPU memory to allocate cache for matched CPU Cache"
)
else:
raise Exception(
"request_match_blocks: Not enough GPU memory to allocate cache for matched CPU Cache"
)

# 3. match and prefetch cache from storage
match_token_num = gpu_match_token_num + cpu_match_token_num
Expand Down Expand Up @@ -1749,6 +1749,76 @@ def mm_match_block(self, request, block_size):
cpu_match_token_num,
)

def pre_match_block_on_gpu(self, request, block_size=None):
    """
    Pre-match request tokens against cached GPU blocks in the radix tree.

    This method performs a prefix matching operation to find the longest sequence
    of tokens that already exist in GPU cache blocks. It traverses the radix tree
    from the root, computing hash values for each block-sized chunk of tokens
    and checking if corresponding nodes exist with GPU-resident data.

    Args:
        request: The inference request object containing prompt_token_ids and
            output_token_ids to be matched against the cache.
        block_size (int, optional): Block size used to chunk tokens for hashing.
            Defaults to ``self.config.cache_config.block_size``. Callers that
            compute block counts from their own block_size (e.g. the scheduler's
            need_block_num estimate) should pass it explicitly so matching and
            estimation cannot diverge.

    Returns:
        tuple: A tuple containing:
            - match_token_num (int): The total number of tokens that were
              successfully matched in GPU-resident blocks.
            - last_node (BlockNode): The last matched node in the radix tree,
              which represents the deepest point of prefix cache hit.

    Note:
        - Only blocks with ``has_in_gpu=True`` are considered as valid matches.
        - The matching stops at the first mismatch or when a block is not in GPU.
        - This is a read-only operation that does not modify the radix tree
          or LRU data structures.
    """
    # Normalize prompt ids to a plain list so concatenation with
    # output_token_ids (a list) below works for ndarray inputs too.
    if isinstance(request.prompt_token_ids, np.ndarray):
        prompt_token_ids = request.prompt_token_ids.tolist()
    else:
        prompt_token_ids = request.prompt_token_ids
    input_ids = prompt_token_ids + request.output_token_ids
    total_token_num = len(input_ids)

    last_node = self.radix_tree_root
    match_token_num = 0
    mm_idx = 0
    prefix_block_key = []
    # Fall back to the config value for backward compatibility when the
    # caller does not supply an explicit block_size.
    if block_size is None:
        block_size = self.config.cache_config.block_size

    with self.cache_status_lock:
        while match_token_num < total_token_num:
            token_block = input_ids[match_token_num : match_token_num + block_size]
            # Only complete blocks participate in prefix matching; a trailing
            # partial block ends the walk.
            if len(token_block) != block_size:
                break

            mm_idx, extra_keys = self.get_block_hash_extra_keys(
                request=request,
                start_idx=match_token_num,
                end_idx=match_token_num + block_size,
                mm_idx=mm_idx,
            )
            prefix_block_key.extend(extra_keys)
            hash_value = get_hash_str(token_block, prefix_block_key)
            # Chain the hashes: each block's hash seeds the next block's
            # prefix key, so a node's key encodes its whole prefix.
            prefix_block_key = [hash_value]

            if hash_value not in last_node.children:
                break

            child = last_node.children[hash_value]
            # Stop at the first node whose data is not GPU-resident.
            if not child.has_in_gpu:
                break
            match_token_num += block_size
            last_node = child

    logger.info(f"pre_match_block_on_gpu: req_id {request.request_id}, match_token_num {match_token_num}")
    return (
        match_token_num,
        last_node,
    )
Comment on lines +1752 to +1820
Copy link

Copilot AI Mar 25, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pre_match_block_on_gpu 是新引入的缓存匹配入口,但目前在仓库测试中没有直接覆盖它的行为(例如:仅 GPU-resident 节点可匹配、遇到首个非 GPU 节点即停止、以及多模态 extra_keys 参与 hash 的一致性)。考虑到该方法会影响调度决策,建议在 tests/cache_manager/test_prefix_cache_manager.pytests/v1/cache_manager/test_prefix_cache.py 增加针对该方法的单测,确保和现有 mm_match_block 的 hash/遍历语义一致并避免后续重构引入偏差。

Copilot uses AI. Check for mistakes.

def match_block(self, req_id, input_ids, block_size):
"""
Args:
Expand Down
1 change: 1 addition & 0 deletions fastdeploy/engine/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,7 @@ def __init__(
self.status = RequestStatus.WAITING
self.task_type = RequestType.PREFILL
self.has_been_preempted_before = False
self.has_update_mm_hashes = False
self.idx = None
self.need_prefill_tokens = self.prompt_token_ids_len
self.audio_output_token_ids = []
Expand Down
49 changes: 33 additions & 16 deletions fastdeploy/engine/sched/resource_manager_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,7 +398,7 @@ def _trigger_preempt(self, request, num_new_blocks, preempted_reqs, scheduled_re
self.can_relax_prefill_strategy = False
return can_schedule

def _get_can_schedule_prefill_threshold_block(self, request, num_chunk_new_block):
def _get_can_schedule_prefill_threshold_block(self, num_chunk_new_block):
if self.can_relax_prefill_strategy:
can_schedule_block_num_threshold = num_chunk_new_block
else:
Expand All @@ -412,7 +412,7 @@ def _get_can_schedule_prefill_threshold_block(self, request, num_chunk_new_block
return can_schedule_block_num_threshold

def _update_mm_hashes(self, request):
if request.multimodal_inputs is None:
if request.multimodal_inputs is None or request.has_update_mm_hashes:
return

inputs = request.multimodal_inputs
Expand Down Expand Up @@ -450,6 +450,8 @@ def _update_mm_hashes(self, request):
inputs["mm_positions"] = []
inputs["mm_hashes"] = []

request.has_update_mm_hashes = True

def _is_mm_request(self, request):
inputs = request.multimodal_inputs
if inputs is None or len(inputs) == 0:
Expand Down Expand Up @@ -762,6 +764,7 @@ def get_enough_request(request, scheduled_reqs):
preempted_reqs: list[Request] = []
error_reqs: list[tuple[str, str]] = []
token_budget = self.config.scheduler_config.max_num_batched_tokens
block_size = self.config.cache_config.block_size
need_abort_requests = [] # users trigger abortion

# First, schedule the RUNNING requests.
Expand Down Expand Up @@ -940,6 +943,10 @@ def _allocate_decode_and_extend():
request = self.waiting[0]
if get_enough_request(request, scheduled_reqs):
break
running_req_reserved_block_num = self._get_can_schedule_prefill_threshold_block(0)
if not self.cache_manager.can_allocate_gpu_blocks(running_req_reserved_block_num):
break

if request.status == RequestStatus.WAITING:
result = self.waiting_async_process(request)
if result is None:
Expand All @@ -959,10 +966,13 @@ def _allocate_decode_and_extend():
self.cache_manager.num_cpu_blocks > 0
or self.config.cache_config.kvcache_storage_backend
):
match_token_num, _ = self.cache_manager.pre_match_block_on_gpu(request)
need_prefill_tokens = request.need_prefill_tokens - match_token_num
need_block_num = (need_prefill_tokens + block_size - 1) // block_size
if not self.cache_manager.can_allocate_gpu_blocks(
(request.need_prefill_tokens + self.config.cache_config.block_size - 1)
// self.config.cache_config.block_size
): # to prevent block allocation for matching in hierarchical cache and cause dead lock
need_block_num + running_req_reserved_block_num
):
# to prevent block allocation for matching in hierarchical cache and cause dead lock
Comment on lines +969 to +975
Copy link

Copilot AI Mar 25, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里引入了基于 pre_match_block_on_gpu 的新调度门槛计算(match_token_num / need_block_num + running_req_reserved_block_num),会改变 WAITING 请求在层级缓存开启时的可调度性与 break 条件。仓库里已有 ResourceManagerV1.schedule() 的单测(如 tests/v1/test_schedule_output.py),建议补充覆盖:GPU 前缀命中/未命中、running 保留块不足、以及层级缓存开启时是否能按预期调度/跳过请求的场景,避免回归。

Copilot uses AI. Check for mistakes.
break
success = self.get_prefix_cached_blocks(request)
if not success:
Expand All @@ -986,7 +996,7 @@ def _allocate_decode_and_extend():
continue
num_new_block = self.get_new_block_nums(request, num_new_tokens)
can_schedule_block_num_threshold = self._get_can_schedule_prefill_threshold_block(
request, num_new_block
num_new_block
)
# Allocate blocks to prefill
if self.cache_manager.can_allocate_gpu_blocks(can_schedule_block_num_threshold):
Expand Down Expand Up @@ -1031,10 +1041,13 @@ def _allocate_decode_and_extend():
self.cache_manager.num_cpu_blocks > 0
or self.config.cache_config.kvcache_storage_backend
):
match_token_num, _ = self.cache_manager.pre_match_block_on_gpu(request)
need_prefill_tokens = request.need_prefill_tokens - match_token_num
need_block_num = (need_prefill_tokens + block_size - 1) // block_size
if not self.cache_manager.can_allocate_gpu_blocks(
(request.need_prefill_tokens + self.config.cache_config.block_size - 1)
// self.config.cache_config.block_size
): # to prevent block allocation for matching in hierarchical cache and cause dead lock
need_block_num + running_req_reserved_block_num
):
# to prevent block allocation for matching in hierarchical cache and cause dead lock
break
success = self.get_prefix_cached_blocks(request)
if not success:
Expand All @@ -1051,7 +1064,7 @@ def _allocate_decode_and_extend():
continue
num_new_block = self.get_new_block_nums(request, num_new_tokens)
can_schedule_block_num_threshold = self._get_can_schedule_prefill_threshold_block(
request, num_new_block
num_new_block
)
# Allocate blocks to prefill
if self.cache_manager.can_allocate_gpu_blocks(can_schedule_block_num_threshold):
Expand Down Expand Up @@ -1321,17 +1334,21 @@ def preallocate_resource_in_p(self, request: Request):
with self.lock:
if self.available_batch() == 0:
return False

block_size = self.config.cache_config.block_size
request.need_prefill_tokens = len(request.prompt_token_ids)
need_prealloc_prefill_blocks = (
request.need_prefill_tokens + self.config.cache_config.block_size - 1
) // self.config.cache_config.block_size + self.config.cache_config.enc_dec_block_num # consider for mtp, plus enc_dec_block_num
request.need_prefill_tokens + block_size - 1
) // block_size + self.config.cache_config.enc_dec_block_num # consider for mtp, plus enc_dec_block_num

if self.config.cache_config.enable_prefix_caching:
# Enable prefix caching
if self.cache_manager.num_cpu_blocks > 0 or self.config.cache_config.kvcache_storage_backend:
if not self.cache_manager.can_allocate_gpu_blocks(
(request.need_prefill_tokens + self.config.cache_config.block_size - 1)
// self.config.cache_config.block_size
): # to prevent block allocation for matching in hierarchical cache and cause dead lock
match_token_num, _ = self.cache_manager.pre_match_block_on_gpu(request)
need_prefill_tokens = request.need_prefill_tokens - match_token_num
need_block_num = (need_prefill_tokens + block_size - 1) // block_size
if not self.cache_manager.can_allocate_gpu_blocks(need_block_num):
# to prevent block allocation for matching in hierarchical cache and cause dead lock
return False
success = self.get_prefix_cached_blocks(request)
if not success:
Expand Down
Loading