-
Notifications
You must be signed in to change notification settings - Fork 309
feat: refactor cpu cache #1236
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
feat: refactor cpu cache #1236
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,3 @@ | ||
| from .tensor_backend import CpuCacheTensorBackend, CpuCacheTensorSpec | ||
|
|
||
| __all__ = ["CpuCacheTensorBackend", "CpuCacheTensorSpec"] |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,48 @@ | ||
| import ctypes | ||
| from dataclasses import dataclass | ||
| from typing import Optional, Tuple | ||
|
|
||
| import numpy as np | ||
| import torch | ||
|
|
||
| from lightllm.utils.kv_cache_utils import attach_shm_kv_cache_ptr, create_shm_kv_cache_ptr, register_shm_ptr_to_pin | ||
|
|
||
|
|
||
| @dataclass(frozen=True) | ||
| class CpuCacheTensorSpec: | ||
| shm_key: int | ||
| shape: Tuple[int, ...] | ||
| dtype: torch.dtype | ||
| size_bytes: int | ||
| pin_on_create: bool = True | ||
|
|
||
|
|
||
| class CpuCacheTensorBackend: | ||
| def __init__(self, tensor_spec: CpuCacheTensorSpec): | ||
| self.tensor_spec = tensor_spec | ||
|
|
||
| def create_or_attach( | ||
| self, init_shm_data: bool, wait_for_register: bool = True | ||
| ) -> Tuple[torch.Tensor, Optional[object]]: | ||
| return self._init_tensor(create_mode=init_shm_data, wait_for_register=wait_for_register) | ||
|
|
||
| def _init_tensor(self, create_mode: bool, wait_for_register: bool) -> Tuple[torch.Tensor, Optional[object]]: | ||
| if create_mode: | ||
| shm_ptr = create_shm_kv_cache_ptr(key=self.tensor_spec.shm_key, size=self.tensor_spec.size_bytes) | ||
| else: | ||
| shm_ptr = attach_shm_kv_cache_ptr(key=self.tensor_spec.shm_key, size=self.tensor_spec.size_bytes) | ||
| attach_handle = None | ||
| if not create_mode or self.tensor_spec.pin_on_create: | ||
| attach_handle = register_shm_ptr_to_pin(shm_ptr=shm_ptr, size=self.tensor_spec.size_bytes) | ||
| if wait_for_register: | ||
| attach_handle.wait() | ||
| cpu_cache_tensor = self._build_tensor_view(shm_ptr=shm_ptr) | ||
| assert shm_ptr == cpu_cache_tensor.data_ptr() | ||
| return cpu_cache_tensor, attach_handle | ||
|
|
||
| def _build_tensor_view(self, shm_ptr: int) -> torch.Tensor: | ||
| numpy_array = np.frombuffer( | ||
| memoryview((ctypes.c_uint8 * self.tensor_spec.size_bytes).from_address(shm_ptr)), | ||
| dtype=np.uint8, | ||
| ) | ||
| return torch.from_numpy(numpy_array).view(dtype=self.tensor_spec.dtype).view(self.tensor_spec.shape) |
| Original file line number | Diff line number | Diff line change | ||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,91 @@ | ||||||||||||||||||
| from typing import Optional | ||||||||||||||||||
|
|
||||||||||||||||||
| from sortedcontainers import SortedSet | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
||||||||||||||||||
| class MemoryBlock: | ||||||||||||||||||
| """内存块类,表示一个连续的内存区域""" | ||||||||||||||||||
|
|
||||||||||||||||||
| def __init__(self, start, end): | ||||||||||||||||||
| self.start = start | ||||||||||||||||||
| self.end = end | ||||||||||||||||||
|
|
||||||||||||||||||
| def size(self): | ||||||||||||||||||
| return self.end - self.start | ||||||||||||||||||
|
|
||||||||||||||||||
| def __repr__(self): | ||||||||||||||||||
| return f"Block(start={self.start}, end={self.end})" | ||||||||||||||||||
|
|
||||||||||||||||||
| def can_merge(self, block: "MemoryBlock"): | ||||||||||||||||||
| return (self.start == block.end) or (block.start == self.end) | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
||||||||||||||||||
| class MemoryManager: | ||||||||||||||||||
| def __init__(self, total_size): | ||||||||||||||||||
| """ | ||||||||||||||||||
| 初始化内存管理器 | ||||||||||||||||||
| :param total_size: 总内存大小 | ||||||||||||||||||
| """ | ||||||||||||||||||
|
Comment on lines
+25
to
+28
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For consistency with the rest of the codebase which is in English, it would be better to write docstrings and comments in English. This improves maintainability for a wider audience.
Suggested change
|
||||||||||||||||||
| self.total_size = total_size | ||||||||||||||||||
| self.mem_set_by_start = SortedSet(key=lambda x: (x.start, x.size())) | ||||||||||||||||||
| self.mem_set_by_size = SortedSet(key=lambda x: (x.size(), x.start)) | ||||||||||||||||||
| total = MemoryBlock(0, total_size) | ||||||||||||||||||
| self.__add(total) | ||||||||||||||||||
|
|
||||||||||||||||||
| def alloc(self, need_size: int) -> Optional[MemoryBlock]: | ||||||||||||||||||
| assert need_size > 0 | ||||||||||||||||||
|
|
||||||||||||||||||
| if len(self.mem_set_by_size) == 0: | ||||||||||||||||||
| return None | ||||||||||||||||||
|
|
||||||||||||||||||
| key = MemoryBlock(start=-1, end=-1 + need_size) | ||||||||||||||||||
| find_index = self.mem_set_by_size.bisect_left(key) | ||||||||||||||||||
| if find_index < len(self.mem_set_by_size): | ||||||||||||||||||
| finded_mem_block: MemoryBlock = self.mem_set_by_size[find_index] | ||||||||||||||||||
| self.__del(finded_mem_block) | ||||||||||||||||||
| ret_mem_block = MemoryBlock( | ||||||||||||||||||
| start=finded_mem_block.start, | ||||||||||||||||||
| end=finded_mem_block.start + need_size, | ||||||||||||||||||
| ) | ||||||||||||||||||
| left_block = MemoryBlock( | ||||||||||||||||||
| start=finded_mem_block.start + need_size, | ||||||||||||||||||
| end=finded_mem_block.end, | ||||||||||||||||||
| ) | ||||||||||||||||||
| if left_block.size() > 0: | ||||||||||||||||||
| self.__add(left_block) | ||||||||||||||||||
|
|
||||||||||||||||||
| return ret_mem_block | ||||||||||||||||||
| else: | ||||||||||||||||||
| return None | ||||||||||||||||||
|
|
||||||||||||||||||
| def release(self, block: MemoryBlock): | ||||||||||||||||||
| if block is None: | ||||||||||||||||||
| return | ||||||||||||||||||
| if len(self.mem_set_by_size) == 0: | ||||||||||||||||||
| self.__add(block) | ||||||||||||||||||
| return | ||||||||||||||||||
|
|
||||||||||||||||||
| finded_index = self.mem_set_by_start.bisect_left(block) | ||||||||||||||||||
| for index in [finded_index - 1, finded_index, finded_index + 1]: | ||||||||||||||||||
| if index < len(self.mem_set_by_start): | ||||||||||||||||||
|
Comment on lines
+69
to
+70
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The logic for finding adjacent blocks to merge has a bug and can be optimized.
Suggested change
|
||||||||||||||||||
| sub_block: MemoryBlock = self.mem_set_by_start[index] | ||||||||||||||||||
| # merge | ||||||||||||||||||
| if block.can_merge(sub_block): | ||||||||||||||||||
| self.__del(sub_block) | ||||||||||||||||||
| merge_block = MemoryBlock( | ||||||||||||||||||
| start=min(block.start, sub_block.start), | ||||||||||||||||||
| end=max(block.end, sub_block.end), | ||||||||||||||||||
| ) | ||||||||||||||||||
| self.release(merge_block) | ||||||||||||||||||
| return | ||||||||||||||||||
| # 无法merge时,直接add | ||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||||||||||||||||||
| self.__add(block) | ||||||||||||||||||
| return | ||||||||||||||||||
|
|
||||||||||||||||||
| def __add(self, block): | ||||||||||||||||||
| self.mem_set_by_start.add(block) | ||||||||||||||||||
| self.mem_set_by_size.add(block) | ||||||||||||||||||
|
|
||||||||||||||||||
| def __del(self, block): | ||||||||||||||||||
| self.mem_set_by_start.remove(block) | ||||||||||||||||||
| self.mem_set_by_size.remove(block) | ||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For consistency with the rest of the codebase which is in English, it would be better to write docstrings and comments in English. This improves maintainability for a wider audience.