Skip to content

Commit d18aef4

Browse files
author
钮圣虓
committed
feat: refactor cpu cache
1 parent 14ba4a6 commit d18aef4

6 files changed

Lines changed: 186 additions & 197 deletions

File tree

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .tensor_backend import CpuCacheTensorBackend, CpuCacheTensorSpec
2+
3+
__all__ = ["CpuCacheTensorBackend", "CpuCacheTensorSpec"]
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
import ctypes
2+
from dataclasses import dataclass
3+
from typing import Optional, Tuple
4+
5+
import numpy as np
6+
import torch
7+
8+
from lightllm.utils.kv_cache_utils import attach_shm_kv_cache_ptr, create_shm_kv_cache_ptr, register_shm_ptr_to_pin
9+
10+
11+
@dataclass(frozen=True)
class CpuCacheTensorSpec:
    """Immutable description of a shared-memory CPU cache tensor.

    Bundles everything CpuCacheTensorBackend needs to create or attach the
    shm segment and expose it as a torch tensor view.
    """

    # Shared-memory key identifying the segment (passed to create/attach helpers).
    shm_key: int
    # Logical tensor shape the raw byte buffer is viewed as.
    shape: Tuple[int, ...]
    # Element dtype of the tensor view (e.g. torch.float16).
    dtype: torch.dtype
    # Total segment size in bytes; presumably equals prod(shape) * dtype itemsize,
    # otherwise the final .view(shape) in the backend would fail — TODO confirm.
    size_bytes: int
    # Whether the creating process registers/pins the memory; attaching
    # processes always register (see CpuCacheTensorBackend._init_tensor).
    pin_on_create: bool = True
18+
19+
20+
class CpuCacheTensorBackend:
    """Creates or attaches a shared-memory segment and wraps it as a torch tensor."""

    def __init__(self, tensor_spec: CpuCacheTensorSpec):
        # Spec describing the shm key, byte size, dtype and logical shape.
        self.tensor_spec = tensor_spec

    def create_or_attach(
        self, init_shm_data: bool, wait_for_register: bool = True
    ) -> Tuple[torch.Tensor, Optional[object]]:
        """Create (init_shm_data=True) or attach (False) the shm cache tensor.

        Returns (tensor, attach_handle). attach_handle is None when no pin
        registration was performed; otherwise it is whatever
        register_shm_ptr_to_pin returned (already waited on when
        wait_for_register is True).
        """
        return self._init_tensor(create_mode=init_shm_data, wait_for_register=wait_for_register)

    def _init_tensor(self, create_mode: bool, wait_for_register: bool) -> Tuple[torch.Tensor, Optional[object]]:
        # Create a fresh shm segment or attach to an existing one by key.
        if create_mode:
            shm_ptr = create_shm_kv_cache_ptr(key=self.tensor_spec.shm_key, size=self.tensor_spec.size_bytes)
        else:
            shm_ptr = attach_shm_kv_cache_ptr(key=self.tensor_spec.shm_key, size=self.tensor_spec.size_bytes)
        attach_handle = None
        # Attaching processes always register the pointer; the creating process
        # only does so when pin_on_create is set.
        if not create_mode or self.tensor_spec.pin_on_create:
            attach_handle = register_shm_ptr_to_pin(shm_ptr=shm_ptr, size=self.tensor_spec.size_bytes)
            if wait_for_register:
                # Block until registration completes; otherwise the caller is
                # responsible for waiting on the returned handle.
                attach_handle.wait()
        cpu_cache_tensor = self._build_tensor_view(shm_ptr=shm_ptr)
        # The tensor must alias the shm memory exactly (zero-copy view).
        assert shm_ptr == cpu_cache_tensor.data_ptr()
        return cpu_cache_tensor, attach_handle

    def _build_tensor_view(self, shm_ptr: int) -> torch.Tensor:
        # Wrap the raw shm bytes without copying: ctypes array -> memoryview ->
        # uint8 numpy array sharing the same buffer.
        numpy_array = np.frombuffer(
            memoryview((ctypes.c_uint8 * self.tensor_spec.size_bytes).from_address(shm_ptr)),
            dtype=np.uint8,
        )
        # Reinterpret the bytes as the target dtype, then reshape to the
        # logical cache shape.
        return torch.from_numpy(numpy_array).view(dtype=self.tensor_spec.dtype).view(self.tensor_spec.shape)
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
from typing import Optional
2+
3+
from sortedcontainers import SortedSet
4+
5+
6+
class MemoryBlock:
    """A contiguous memory region covering the half-open range [start, end)."""

    def __init__(self, start, end):
        self.start = start
        self.end = end

    def size(self):
        """Number of units spanned by this block."""
        return self.end - self.start

    def __repr__(self):
        return f"Block(start={self.start}, end={self.end})"

    def can_merge(self, block: "MemoryBlock"):
        """Return True when *block* is directly adjacent on either side."""
        touches_left = self.start == block.end
        touches_right = block.start == self.end
        return touches_left or touches_right
21+
22+
23+
class MemoryManager:
    """Allocator over a single [0, total_size) range of abstract units.

    Free blocks are kept in two sorted views of the same population:
    ``mem_set_by_start`` (ordered by start, used to find adjacent blocks on
    release) and ``mem_set_by_size`` (ordered by size, used to find the
    smallest sufficient block on alloc).
    """

    def __init__(self, total_size):
        """
        Initialize the manager with one free block spanning the whole range.

        :param total_size: total managed memory size
        """
        self.total_size = total_size
        self.mem_set_by_start = SortedSet(key=lambda x: (x.start, x.size()))
        self.mem_set_by_size = SortedSet(key=lambda x: (x.size(), x.start))
        total = MemoryBlock(0, total_size)
        self.__add(total)

    def alloc(self, need_size: int) -> Optional[MemoryBlock]:
        """Allocate a block of exactly *need_size* units, or None if none fits."""
        assert need_size > 0

        if len(self.mem_set_by_size) == 0:
            return None

        # Probe key: start=-1 sorts before every real block of equal size, so
        # bisect_left lands on the smallest free block with size >= need_size.
        key = MemoryBlock(start=-1, end=-1 + need_size)
        find_index = self.mem_set_by_size.bisect_left(key)
        if find_index >= len(self.mem_set_by_size):
            # No free block is large enough.
            return None

        finded_mem_block: MemoryBlock = self.mem_set_by_size[find_index]
        self.__del(finded_mem_block)
        ret_mem_block = MemoryBlock(
            start=finded_mem_block.start,
            end=finded_mem_block.start + need_size,
        )
        # Return the unused tail of the chosen free block to the pool.
        left_block = MemoryBlock(
            start=finded_mem_block.start + need_size,
            end=finded_mem_block.end,
        )
        if left_block.size() > 0:
            self.__add(left_block)

        return ret_mem_block

    def release(self, block: MemoryBlock):
        """Return *block* to the free pool, coalescing with adjacent free blocks."""
        if block is None:
            return
        if len(self.mem_set_by_size) == 0:
            self.__add(block)
            return

        # Only the insertion position's neighbours can be adjacent to *block*.
        finded_index = self.mem_set_by_start.bisect_left(block)
        for index in [finded_index - 1, finded_index, finded_index + 1]:
            # Fix: also require index >= 0 — the previous `index < len(...)`
            # check let index == -1 wrap around to the LAST block under
            # Python's negative indexing, probing an unrelated block.
            if 0 <= index < len(self.mem_set_by_start):
                sub_block: MemoryBlock = self.mem_set_by_start[index]
                # merge
                if block.can_merge(sub_block):
                    self.__del(sub_block)
                    merge_block = MemoryBlock(
                        start=min(block.start, sub_block.start),
                        end=max(block.end, sub_block.end),
                    )
                    # Recurse so the merged block can coalesce further.
                    self.release(merge_block)
                    return
        # No adjacent free block: insert as-is.
        self.__add(block)
        return

    def __add(self, block):
        # Keep both sorted views in sync.
        self.mem_set_by_start.add(block)
        self.mem_set_by_size.add(block)

    def __del(self, block):
        self.mem_set_by_start.remove(block)
        self.mem_set_by_size.remove(block)

lightllm/server/embed_cache/embed_cache_client.py

Lines changed: 20 additions & 140 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
1-
import ctypes
21
import torch
3-
import numpy as np
4-
from sortedcontainers import SortedSet
5-
from typing import Optional, List
2+
from typing import Optional
63
from lightllm.utils.envs_utils import get_env_start_args
74
from lightllm.utils.log_utils import init_logger
85
from lightllm.utils.embed_utils import calcu_embed_cache_meta
9-
from lightllm.utils.kv_cache_utils import create_shm_kv_cache_ptr, attach_shm_kv_cache_ptr, register_shm_ptr_to_pin
6+
from lightllm.common.cpu_cache import CpuCacheTensorBackend, CpuCacheTensorSpec
7+
from .allocator import MemoryBlock, MemoryManager
8+
from .copy_to_cache import offload_embed_tensor_to_cache
109

1110
logger = init_logger(__name__)
1211

@@ -25,10 +24,11 @@ def __init__(self, create_meta_data: bool, init_shm_data: bool):
2524
if create_meta_data:
2625
self.token_index_manager = MemoryManager(total_size=self.token_num)
2726
else:
28-
if init_shm_data:
29-
self._create_shm_embed_kv_cache()
30-
else:
31-
self._attach_shm_cpu_embed_cache()
27+
tensor_backend = CpuCacheTensorBackend(tensor_spec=self._build_tensor_spec())
28+
self.cpu_embed_cache_tensor, _ = tensor_backend.create_or_attach(
29+
init_shm_data=init_shm_data,
30+
wait_for_register=True,
31+
)
3232
return
3333

3434
def alloc_indexes(self, token_num: int) -> Optional["MemoryBlock"]:
@@ -39,17 +39,14 @@ def release_indexes(self, block: "MemoryBlock"):
3939
return
4040

4141
def copy_to_cache(self, embed_tensor: torch.Tensor, start_index_in_cache: int):
    """Offload *embed_tensor* into the shared CPU embed cache.

    Delegates to offload_embed_tensor_to_cache, writing into
    self.cpu_embed_cache_tensor starting at token index
    *start_index_in_cache*.
    """
    offload_embed_tensor_to_cache(
        embed_tensor=embed_tensor,
        cache_tensor=self.cpu_embed_cache_tensor,
        start_index_in_cache=start_index_in_cache,
    )
    return
4948

5049
def copy_vision_to_cache(self, embed_tensor: torch.Tensor, start_index_in_cache: int):
51-
from .copy_to_cache import offload_embed_tensor_to_cache
52-
5350
if embed_tensor.ndim == 3:
5451
# check for qwen3 vision embed tensor shape, use apply deepstack
5552
assert embed_tensor.shape[1] == self.cpu_embed_cache_tensor.shape[1]
@@ -59,136 +56,19 @@ def copy_vision_to_cache(self, embed_tensor: torch.Tensor, start_index_in_cache:
5956
cache_tensor=self.cpu_embed_cache_tensor,
6057
start_index_in_cache=start_index_in_cache,
6158
)
62-
63-
def _create_shm_embed_kv_cache(self):
64-
shm_ptr = create_shm_kv_cache_ptr(
65-
key=self.args.multi_modal_cache_shm_id, size=self.embed_cache_tensor_meta.calcu_size()
66-
)
67-
handle = register_shm_ptr_to_pin(shm_ptr=shm_ptr, size=self.embed_cache_tensor_meta.calcu_size())
68-
handle.wait()
69-
numpy_array = np.frombuffer(
70-
memoryview((ctypes.c_uint8 * self.embed_cache_tensor_meta.calcu_size()).from_address(shm_ptr)),
71-
dtype=np.uint8,
72-
)
73-
# 将 NumPy 数组转换为 PyTorch 张量
74-
shape = (
75-
self.embed_cache_tensor_meta.token_num,
76-
self.embed_cache_tensor_meta.layer_num,
77-
self.embed_cache_tensor_meta.hidden_size,
78-
)
79-
self.cpu_embed_cache_tensor = (
80-
torch.from_numpy(numpy_array).view(dtype=self.embed_cache_tensor_meta.data_type).view(shape)
81-
)
8259
return
8360

84-
def _attach_shm_cpu_embed_cache(self):
85-
shm_ptr = attach_shm_kv_cache_ptr(
86-
key=self.args.multi_modal_cache_shm_id, size=self.embed_cache_tensor_meta.calcu_size()
87-
)
88-
handle = register_shm_ptr_to_pin(shm_ptr=shm_ptr, size=self.embed_cache_tensor_meta.calcu_size())
89-
handle.wait()
90-
numpy_array = np.frombuffer(
91-
memoryview((ctypes.c_uint8 * self.embed_cache_tensor_meta.calcu_size()).from_address(shm_ptr)),
92-
dtype=np.uint8,
93-
)
94-
shape = (
95-
self.embed_cache_tensor_meta.token_num,
96-
self.embed_cache_tensor_meta.layer_num,
97-
self.embed_cache_tensor_meta.hidden_size,
98-
)
99-
self.cpu_embed_cache_tensor = (
100-
torch.from_numpy(numpy_array).view(dtype=self.embed_cache_tensor_meta.data_type).view(shape)
61+
def _build_tensor_spec(self) -> CpuCacheTensorSpec:
    """Build the CpuCacheTensorSpec describing the shm embed-cache tensor.

    Shape is (token_num, layer_num, hidden_size) taken from
    self.embed_cache_tensor_meta; the shm key comes from the server args and
    the byte size from the meta's calcu_size().
    """
    return CpuCacheTensorSpec(
        shm_key=self.args.multi_modal_cache_shm_id,
        shape=(
            self.embed_cache_tensor_meta.token_num,
            self.embed_cache_tensor_meta.layer_num,
            self.embed_cache_tensor_meta.hidden_size,
        ),
        dtype=self.embed_cache_tensor_meta.data_type,
        size_bytes=self.embed_cache_tensor_meta.calcu_size(),
    )
102-
assert shm_ptr == self.cpu_embed_cache_tensor.data_ptr()
103-
return None
104-
105-
106-
class MemoryBlock:
107-
"""内存块类,表示一个连续的内存区域"""
108-
109-
def __init__(self, start, end):
110-
self.start = start
111-
self.end = end
112-
113-
def size(self):
114-
return self.end - self.start
115-
116-
def __repr__(self):
117-
return f"Block(start={self.start}, end={self.end})"
118-
119-
def can_merge(self, block: "MemoryBlock"):
120-
return (self.start == block.end) or (block.start == self.end)
121-
122-
123-
class MemoryManager:
124-
def __init__(self, total_size):
125-
"""
126-
初始化内存管理器
127-
:param total_size: 总内存大小
128-
"""
129-
self.total_size = total_size
130-
self.mem_set_by_start = SortedSet(key=lambda x: (x.start, x.size()))
131-
self.mem_set_by_size = SortedSet(key=lambda x: (x.size(), x.start))
132-
total = MemoryBlock(0, total_size)
133-
self.__add(total)
134-
135-
def alloc(self, need_size: int) -> Optional[MemoryBlock]:
136-
assert need_size > 0
137-
138-
if len(self.mem_set_by_size) == 0:
139-
return None
140-
141-
key = MemoryBlock(start=-1, end=-1 + need_size)
142-
find_index = self.mem_set_by_size.bisect_left(key)
143-
if find_index < len(self.mem_set_by_size):
144-
finded_mem_block: MemoryBlock = self.mem_set_by_size[find_index]
145-
self.__del(finded_mem_block)
146-
ret_mem_block = MemoryBlock(
147-
start=finded_mem_block.start,
148-
end=finded_mem_block.start + need_size,
149-
)
150-
left_block = MemoryBlock(
151-
start=finded_mem_block.start + need_size,
152-
end=finded_mem_block.end,
153-
)
154-
if left_block.size() > 0:
155-
self.__add(left_block)
156-
157-
return ret_mem_block
158-
else:
159-
return None
160-
161-
def release(self, block: MemoryBlock):
162-
if block is None:
163-
return
164-
if len(self.mem_set_by_size) == 0:
165-
self.__add(block)
166-
return
167-
168-
finded_index = self.mem_set_by_start.bisect_left(block)
169-
for index in [finded_index - 1, finded_index, finded_index + 1]:
170-
if index < len(self.mem_set_by_start):
171-
sub_block: MemoryBlock = self.mem_set_by_start[index]
172-
# merge
173-
if block.can_merge(sub_block):
174-
self.__del(sub_block)
175-
merge_block = MemoryBlock(
176-
start=min(block.start, sub_block.start),
177-
end=max(block.end, sub_block.end),
178-
)
179-
self.release(merge_block)
180-
return
181-
# 无法merge时,直接add
182-
self.__add(block)
183-
return
184-
185-
def __add(self, block):
186-
self.mem_set_by_start.add(block)
187-
self.mem_set_by_size.add(block)
188-
189-
def __del(self, block):
190-
self.mem_set_by_start.remove(block)
191-
self.mem_set_by_size.remove(block)
19272

19373

19474
if __name__ == "__main__":

lightllm/server/embed_cache/impl/naive_memory_cache.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,10 @@
1010
import multiprocessing.shared_memory as shm
1111
from ..utils import get_shm_name_data, free_shm
1212
from lightllm.utils.log_utils import init_logger
13-
from ..embed_cache_client import CpuEmbedCacheClient, MemoryBlock, SortedSet
13+
from sortedcontainers import SortedSet
14+
15+
from ..allocator import MemoryBlock
16+
from ..embed_cache_client import CpuEmbedCacheClient
1417

1518
logger = init_logger(__name__)
1619

0 commit comments

Comments
 (0)