-
Notifications
You must be signed in to change notification settings - Fork 729
[Scheduler] Pre match radix tree in schedule #6989
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -816,8 +816,8 @@ def request_match_blocks(self, task: Request, block_size, *args): | |
| # 2. prepare cpu cache: allocate gpu cache for matched cpu blocks, wait for data transfer to complete | ||
| gpu_recv_block_ids = [] | ||
| match_cpu_blocks_num = len(match_cpu_block_ids) | ||
| if self.can_allocate_gpu_blocks(num_blocks=match_cpu_blocks_num): | ||
| if match_cpu_blocks_num > 0: | ||
| if match_cpu_blocks_num > 0: | ||
| if self.can_allocate_gpu_blocks(num_blocks=match_cpu_blocks_num): | ||
| logger.debug( | ||
| f"request_match_blocks: req_id {req_id}, allocate {match_cpu_blocks_num} block to receive cpu cache" | ||
| ) | ||
|
|
@@ -833,10 +833,10 @@ def request_match_blocks(self, task: Request, block_size, *args): | |
| ) | ||
| cost_time = time.time() - start_time | ||
| metrics["cpu_cache_prepare_time"] = cost_time | ||
| else: | ||
| raise Exception( | ||
| "request_match_blocks: Not enough GPU memory to allocate cache for matched CPU Cache" | ||
| ) | ||
| else: | ||
| raise Exception( | ||
| "request_match_blocks: Not enough GPU memory to allocate cache for matched CPU Cache" | ||
| ) | ||
|
|
||
| # 3. match and prefetch cache from storage | ||
| match_token_num = gpu_match_token_num + cpu_match_token_num | ||
|
|
@@ -1749,6 +1749,76 @@ def mm_match_block(self, request, block_size): | |
| cpu_match_token_num, | ||
| ) | ||
|
|
||
| def pre_match_block_on_gpu(self, request): | ||
| """ | ||
| Pre-match request tokens against cached GPU blocks in the radix tree. | ||
|
|
||
| This method performs a prefix matching operation to find the longest sequence | ||
| of tokens that already exist in GPU cache blocks. It traverses the radix tree | ||
| from the root, computing hash values for each block-sized chunk of tokens | ||
| and checking if corresponding nodes exist with GPU-resident data. | ||
|
|
||
| Args: | ||
| request: The inference request object containing prompt_token_ids and | ||
| output_token_ids to be matched against the cache. | ||
|
|
||
| Returns: | ||
| tuple: A tuple containing: | ||
| - match_token_num (int): The total number of tokens that were | ||
| successfully matched in GPU-resident blocks. | ||
| - last_node (BlockNode): The last matched node in the radix tree, | ||
| which represents the deepest point of prefix cache hit. | ||
|
|
||
| Note: | ||
| - Only blocks with `has_in_gpu=True` are considered as valid matches. | ||
| - The matching stops at the first mismatch or when a block is not in GPU. | ||
| - This is a read-only operation that does not modify the radix tree | ||
| or LRU data structures. | ||
| """ | ||
| if isinstance(request.prompt_token_ids, np.ndarray): | ||
| prompt_token_ids = request.prompt_token_ids.tolist() | ||
| else: | ||
| prompt_token_ids = request.prompt_token_ids | ||
| input_ids = prompt_token_ids + request.output_token_ids | ||
| total_token_num = len(input_ids) | ||
|
|
||
| last_node = self.radix_tree_root | ||
| match_token_num = 0 | ||
| mm_idx = 0 | ||
| prefix_block_key = [] | ||
| block_size = self.config.cache_config.block_size | ||
|
|
||
| with self.cache_status_lock: | ||
| while match_token_num < total_token_num: | ||
| token_block = input_ids[match_token_num : match_token_num + block_size] | ||
| if len(token_block) != block_size: | ||
| break | ||
juncaipeng marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| mm_idx, extra_keys = self.get_block_hash_extra_keys( | ||
| request=request, | ||
| start_idx=match_token_num, | ||
| end_idx=match_token_num + block_size, | ||
| mm_idx=mm_idx, | ||
| ) | ||
| prefix_block_key.extend(extra_keys) | ||
| hash_value = get_hash_str(token_block, prefix_block_key) | ||
| prefix_block_key = [hash_value] | ||
|
|
||
| if hash_value not in last_node.children: | ||
| break | ||
|
|
||
| child = last_node.children[hash_value] | ||
| if not child.has_in_gpu: | ||
| break | ||
juncaipeng marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| match_token_num += block_size | ||
| last_node = child | ||
|
|
||
| logger.info(f"pre_match_block_on_gpu: req_id {request.request_id}, match_token_num {match_token_num}") | ||
| return ( | ||
| match_token_num, | ||
| last_node, | ||
| ) | ||
|
Comment on lines
+1752
to
+1820
|
||
|
|
||
| def match_block(self, req_id, input_ids, block_size): | ||
| """ | ||
| Args: | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -398,7 +398,7 @@ def _trigger_preempt(self, request, num_new_blocks, preempted_reqs, scheduled_re | |
| self.can_relax_prefill_strategy = False | ||
| return can_schedule | ||
|
|
||
| def _get_can_schedule_prefill_threshold_block(self, request, num_chunk_new_block): | ||
| def _get_can_schedule_prefill_threshold_block(self, num_chunk_new_block): | ||
| if self.can_relax_prefill_strategy: | ||
| can_schedule_block_num_threshold = num_chunk_new_block | ||
| else: | ||
|
|
@@ -412,7 +412,7 @@ def _get_can_schedule_prefill_threshold_block(self, request, num_chunk_new_block | |
| return can_schedule_block_num_threshold | ||
|
|
||
| def _update_mm_hashes(self, request): | ||
| if request.multimodal_inputs is None: | ||
| if request.multimodal_inputs is None or request.has_update_mm_hashes: | ||
| return | ||
|
|
||
| inputs = request.multimodal_inputs | ||
|
|
@@ -450,6 +450,8 @@ def _update_mm_hashes(self, request): | |
| inputs["mm_positions"] = [] | ||
| inputs["mm_hashes"] = [] | ||
|
|
||
| request.has_update_mm_hashes = True | ||
|
|
||
| def _is_mm_request(self, request): | ||
| inputs = request.multimodal_inputs | ||
| if inputs is None or len(inputs) == 0: | ||
|
|
@@ -762,6 +764,7 @@ def get_enough_request(request, scheduled_reqs): | |
| preempted_reqs: list[Request] = [] | ||
| error_reqs: list[tuple[str, str]] = [] | ||
| token_budget = self.config.scheduler_config.max_num_batched_tokens | ||
| block_size = self.config.cache_config.block_size | ||
| need_abort_requests = [] # users trigger abortion | ||
|
|
||
| # First, schedule the RUNNING requests. | ||
|
|
@@ -940,6 +943,10 @@ def _allocate_decode_and_extend(): | |
| request = self.waiting[0] | ||
| if get_enough_request(request, scheduled_reqs): | ||
| break | ||
| running_req_reserved_block_num = self._get_can_schedule_prefill_threshold_block(0) | ||
| if not self.cache_manager.can_allocate_gpu_blocks(running_req_reserved_block_num): | ||
| break | ||
|
|
||
| if request.status == RequestStatus.WAITING: | ||
| result = self.waiting_async_process(request) | ||
| if result is None: | ||
|
|
@@ -959,10 +966,13 @@ def _allocate_decode_and_extend(): | |
| self.cache_manager.num_cpu_blocks > 0 | ||
| or self.config.cache_config.kvcache_storage_backend | ||
| ): | ||
| match_token_num, _ = self.cache_manager.pre_match_block_on_gpu(request) | ||
| need_prefill_tokens = request.need_prefill_tokens - match_token_num | ||
| need_block_num = (need_prefill_tokens + block_size - 1) // block_size | ||
| if not self.cache_manager.can_allocate_gpu_blocks( | ||
| (request.need_prefill_tokens + self.config.cache_config.block_size - 1) | ||
| // self.config.cache_config.block_size | ||
| ): # to prevent block allocation for matching in hierarchical cache and cause dead lock | ||
| need_block_num + running_req_reserved_block_num | ||
| ): | ||
| # to prevent block allocation for matching in hierarchical cache and cause dead lock | ||
juncaipeng marked this conversation as resolved.
Show resolved
Hide resolved
Comment on lines
+969
to
+975
|
||
| break | ||
| success = self.get_prefix_cached_blocks(request) | ||
| if not success: | ||
|
|
@@ -986,7 +996,7 @@ def _allocate_decode_and_extend(): | |
| continue | ||
| num_new_block = self.get_new_block_nums(request, num_new_tokens) | ||
| can_schedule_block_num_threshold = self._get_can_schedule_prefill_threshold_block( | ||
| request, num_new_block | ||
| num_new_block | ||
| ) | ||
| # Allocate blocks to prefill | ||
| if self.cache_manager.can_allocate_gpu_blocks(can_schedule_block_num_threshold): | ||
|
|
@@ -1031,10 +1041,13 @@ def _allocate_decode_and_extend(): | |
| self.cache_manager.num_cpu_blocks > 0 | ||
| or self.config.cache_config.kvcache_storage_backend | ||
| ): | ||
| match_token_num, _ = self.cache_manager.pre_match_block_on_gpu(request) | ||
| need_prefill_tokens = request.need_prefill_tokens - match_token_num | ||
| need_block_num = (need_prefill_tokens + block_size - 1) // block_size | ||
| if not self.cache_manager.can_allocate_gpu_blocks( | ||
| (request.need_prefill_tokens + self.config.cache_config.block_size - 1) | ||
| // self.config.cache_config.block_size | ||
| ): # to prevent block allocation for matching in hierarchical cache and cause dead lock | ||
| need_block_num + running_req_reserved_block_num | ||
| ): | ||
| # to prevent block allocation for matching in hierarchical cache and cause dead lock | ||
| break | ||
| success = self.get_prefix_cached_blocks(request) | ||
| if not success: | ||
|
|
@@ -1051,7 +1064,7 @@ def _allocate_decode_and_extend(): | |
| continue | ||
| num_new_block = self.get_new_block_nums(request, num_new_tokens) | ||
| can_schedule_block_num_threshold = self._get_can_schedule_prefill_threshold_block( | ||
| request, num_new_block | ||
| num_new_block | ||
| ) | ||
| # Allocate blocks to prefill | ||
| if self.cache_manager.can_allocate_gpu_blocks(can_schedule_block_num_threshold): | ||
|
|
@@ -1321,17 +1334,21 @@ def preallocate_resource_in_p(self, request: Request): | |
| with self.lock: | ||
| if self.available_batch() == 0: | ||
| return False | ||
|
|
||
| block_size = self.config.cache_config.block_size | ||
| request.need_prefill_tokens = len(request.prompt_token_ids) | ||
| need_prealloc_prefill_blocks = ( | ||
| request.need_prefill_tokens + self.config.cache_config.block_size - 1 | ||
| ) // self.config.cache_config.block_size + self.config.cache_config.enc_dec_block_num # consider for mtp, plus enc_dec_block_num | ||
| request.need_prefill_tokens + block_size - 1 | ||
| ) // block_size + self.config.cache_config.enc_dec_block_num # consider for mtp, plus enc_dec_block_num | ||
|
|
||
| if self.config.cache_config.enable_prefix_caching: | ||
| # Enable prefix caching | ||
| if self.cache_manager.num_cpu_blocks > 0 or self.config.cache_config.kvcache_storage_backend: | ||
| if not self.cache_manager.can_allocate_gpu_blocks( | ||
| (request.need_prefill_tokens + self.config.cache_config.block_size - 1) | ||
| // self.config.cache_config.block_size | ||
| ): # to prevent block allocation for matching in hierarchical cache and cause dead lock | ||
| match_token_num, _ = self.cache_manager.pre_match_block_on_gpu(request) | ||
| need_prefill_tokens = request.need_prefill_tokens - match_token_num | ||
| need_block_num = (need_prefill_tokens + block_size - 1) // block_size | ||
| if not self.cache_manager.can_allocate_gpu_blocks(need_block_num): | ||
| # to prevent block allocation for matching in hierarchical cache and cause dead lock | ||
| return False | ||
| success = self.get_prefix_cached_blocks(request) | ||
| if not success: | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
pre_match_block_on_gpu内部直接读取self.config.cache_config.block_size,同时在调度侧已经显式拿到了block_size并据此计算need_block_num。为了避免未来 block_size 来源不一致导致匹配/预估不一致,建议让该方法接收block_size参数(与现有mm_match_block(..., block_size)/request_match_blocks(..., block_size)保持一致),并在内部使用传入值。