-
Notifications
You must be signed in to change notification settings - Fork 325
refactor(kv-cache): embed KvCacheAllocator in MemoryManager as allocator #1301
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -3,10 +3,11 @@ | |||||||||||||||||||||||||
| import torch | ||||||||||||||||||||||||||
| import torch.distributed as dist | ||||||||||||||||||||||||||
| import torch.multiprocessing as mp | ||||||||||||||||||||||||||
| from typing import List, Union, Tuple, Any | ||||||||||||||||||||||||||
| from typing import List, Tuple, Any, Union | ||||||||||||||||||||||||||
| from lightllm.server.pd_io_struct import KVMoveTask | ||||||||||||||||||||||||||
| from lightllm.utils.log_utils import init_logger | ||||||||||||||||||||||||||
| from lightllm.server.router.dynamic_prompt.shared_arr import SharedInt | ||||||||||||||||||||||||||
| from .allocator import KvCacheAllocator | ||||||||||||||||||||||||||
| from lightllm.utils.profile_max_tokens import get_available_gpu_memory, get_total_gpu_memory | ||||||||||||||||||||||||||
| from lightllm.common.kv_trans_kernel.kv_trans import kv_trans | ||||||||||||||||||||||||||
| from lightllm.utils.dist_utils import get_current_rank_in_node, get_node_world_size | ||||||||||||||||||||||||||
|
|
@@ -38,27 +39,8 @@ def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False | |||||||||||||||||||||||||
| # profile the max total token num if the size is None | ||||||||||||||||||||||||||
| self.profile_size(mem_fraction) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| self.mem_state = torch.arange( | ||||||||||||||||||||||||||
| 0, self.size, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True | ||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||
| self._mem_state_return = torch.arange( | ||||||||||||||||||||||||||
| 0, self.size * 3, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True | ||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||
| self._return_start = 0 | ||||||||||||||||||||||||||
| self.mark_start = 0 | ||||||||||||||||||||||||||
| self.mark_end = self.size | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| self.can_use_mem_size = self.size | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| # 用共享内存进行共享,router 模块读取进行精确的调度估计, nccl port 作为一个单机中单实列的标记。防止冲突。 | ||||||||||||||||||||||||||
| from lightllm.utils.envs_utils import get_unique_server_name | ||||||||||||||||||||||||||
| self.allocator = KvCacheAllocator(self.size) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| rank_in_node = get_current_rank_in_node() | ||||||||||||||||||||||||||
| self.shared_can_use_token_num = SharedInt( | ||||||||||||||||||||||||||
| f"{get_unique_server_name()}_mem_manger_can_use_token_num_{rank_in_node}" | ||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| self.shared_can_use_token_num.set_value(self.can_use_mem_size) | ||||||||||||||||||||||||||
| self._init_buffers( | ||||||||||||||||||||||||||
| self.size, | ||||||||||||||||||||||||||
| dtype, | ||||||||||||||||||||||||||
|
|
@@ -338,57 +320,13 @@ def _free_buffers(self): | |||||||||||||||||||||||||
| self.kv_buffer = None | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| def alloc(self, need_size) -> torch.Tensor: | ||||||||||||||||||||||||||
| if need_size > self.mark_end - self.mark_start: | ||||||||||||||||||||||||||
| logger.error(f"warn no enough cache need_size {need_size} left_size {self.can_use_mem_size}") | ||||||||||||||||||||||||||
| assert False, "error alloc state" | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| start = self.mark_start | ||||||||||||||||||||||||||
| end = self.mark_start + need_size | ||||||||||||||||||||||||||
| self.mark_start += need_size | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| self.can_use_mem_size -= need_size | ||||||||||||||||||||||||||
| self.shared_can_use_token_num.set_value(self.can_use_mem_size) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| # 利用缓冲区返回,避免异步情况下的内存竞争 | ||||||||||||||||||||||||||
| if self._return_start + need_size > self._mem_state_return.shape[0]: | ||||||||||||||||||||||||||
| self._return_start = 0 | ||||||||||||||||||||||||||
| ans = self._mem_state_return[self._return_start : self._return_start + need_size] | ||||||||||||||||||||||||||
| ans.copy_(self.mem_state[start:end]) | ||||||||||||||||||||||||||
| self._return_start += need_size | ||||||||||||||||||||||||||
| return ans | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| def free(self, free_index: Union[torch.Tensor, List[int]]): | ||||||||||||||||||||||||||
| """_summary_ | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| Args: | ||||||||||||||||||||||||||
| free_index (torch.Tensor): _description_ | ||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||
| return self.allocator.alloc(need_size) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| end = self.mark_start | ||||||||||||||||||||||||||
| start = self.mark_start - len(free_index) | ||||||||||||||||||||||||||
| assert start >= 0, f"error free state start: {self.mark_start} free len {len(free_index)}" | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| if isinstance(free_index, list): | ||||||||||||||||||||||||||
| self.mem_state.numpy()[start:end] = free_index | ||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||
| # 从 gpu 到 cpu 的拷贝操作是流内阻塞操作 | ||||||||||||||||||||||||||
| self.mem_state[start:end] = free_index | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| self.mark_start -= len(free_index) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| self.can_use_mem_size += len(free_index) | ||||||||||||||||||||||||||
| self.shared_can_use_token_num.set_value(self.can_use_mem_size) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| if self.can_use_mem_size == len(self.mem_state): | ||||||||||||||||||||||||||
| logger.debug(f"freed all gpu mem size {self.can_use_mem_size}") | ||||||||||||||||||||||||||
| return | ||||||||||||||||||||||||||
| def free(self, free_index: Union[torch.Tensor, List[int]]) -> None: | ||||||||||||||||||||||||||
| self.allocator.free(free_index) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| def free_all(self): | ||||||||||||||||||||||||||
| self.can_use_mem_size = len(self.mem_state) | ||||||||||||||||||||||||||
| self.shared_can_use_token_num.set_value(self.can_use_mem_size) | ||||||||||||||||||||||||||
| self.mem_state.numpy()[:] = list(range(0, len(self.mem_state))) | ||||||||||||||||||||||||||
| self.mark_start = 0 | ||||||||||||||||||||||||||
| self.mark_end = len(self.mem_state) | ||||||||||||||||||||||||||
| self.allocator.free_all() | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| def resize_mem(self, new_size): | ||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||
|
|
@@ -401,13 +339,7 @@ def resize_mem(self, new_size): | |||||||||||||||||||||||||
| layer_num = self.layer_num | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| self.size = new_size | ||||||||||||||||||||||||||
| self.mem_state = torch.arange( | ||||||||||||||||||||||||||
| 0, self.size, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True | ||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||
| self.mark_start = 0 | ||||||||||||||||||||||||||
| self.mark_end = self.size | ||||||||||||||||||||||||||
| self.can_use_mem_size = self.size | ||||||||||||||||||||||||||
| self.shared_can_use_token_num.set_value(self.can_use_mem_size) | ||||||||||||||||||||||||||
| self.allocator.resize(new_size) | ||||||||||||||||||||||||||
|
Comment on lines
341
to
+342
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The
Suggested change
|
||||||||||||||||||||||||||
| self._free_buffers() | ||||||||||||||||||||||||||
| self._init_buffers(size, dtype, head_num, head_dim, layer_num) | ||||||||||||||||||||||||||
| return | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -280,8 +280,8 @@ def _filter(self, finished_request_ids: List[int]): | |||||
| f"free a batch state:\n" | ||||||
| f"radix refed token num {self.radix_cache.get_refed_tokens_num()}\n" | ||||||
| f"radix hold token num {self.radix_cache.get_tree_total_tokens_num()}\n" | ||||||
| f"mem manager can alloc token num {self.req_manager.mem_manager.can_use_mem_size}\n" | ||||||
| f"mem manager total size {self.req_manager.mem_manager.size}" | ||||||
| f"mem manager can alloc token num {self.req_manager.mem_manager.allocator.can_use_mem_size}\n" | ||||||
| f"mem manager total size {self.req_manager.mem_manager.allocator.size}\n" | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The addition of a trailing newline
Suggested change
|
||||||
| ) | ||||||
|
|
||||||
| return | ||||||
|
|
@@ -348,7 +348,7 @@ def get_can_alloc_token_num(self): | |||||
| radix_cache_unref_token_num = ( | ||||||
| self.radix_cache.get_tree_total_tokens_num() - self.radix_cache.get_refed_tokens_num() | ||||||
| ) | ||||||
| return self.req_manager.mem_manager.can_use_mem_size + radix_cache_unref_token_num | ||||||
| return self.req_manager.mem_manager.allocator.can_use_mem_size + radix_cache_unref_token_num | ||||||
|
|
||||||
| def copy_linear_att_state_to_cache_buffer(self, b_req_idx: torch.Tensor, reqs: List["InferReq"]): | ||||||
| """ | ||||||
|
|
||||||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -80,8 +80,8 @@ def _alloc_to_frozen_some_tokens(self, move_task: KVMoveTask): | |||||
| logger.debug( | ||||||
| f"radix refed token num {self.backend.radix_cache.get_refed_tokens_num()}\n" | ||||||
| f"radix hold token num {self.backend.radix_cache.get_tree_total_tokens_num()}\n" | ||||||
| f"mem manager can alloc token num {self.backend.model.mem_manager.can_use_mem_size}\n" | ||||||
| f"mem manager total size {self.backend.model.mem_manager.size}" | ||||||
| f"mem manager can alloc token num {self.backend.model.mem_manager.allocator.can_use_mem_size}\n" | ||||||
| f"mem manager total size {self.backend.model.mem_manager.allocator.size}\n" | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Similar to the change in
Suggested change
|
||||||
| f"frozened token num {frozen_token_num}\n" | ||||||
| f"estimated peak token num {estimated_peak_token_num}\n" | ||||||
| ) | ||||||
|
|
||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For consistency with the
freemethod, please add type hints for the parameters and return values ofalloc,free_all, andresize_mem.