1- import ctypes
21import torch
3- import numpy as np
4- from sortedcontainers import SortedSet
5- from typing import Optional , List
2+ from typing import Optional
63from lightllm .utils .envs_utils import get_env_start_args
74from lightllm .utils .log_utils import init_logger
85from lightllm .utils .embed_utils import calcu_embed_cache_meta
9- from lightllm .utils .kv_cache_utils import create_shm_kv_cache_ptr , attach_shm_kv_cache_ptr , register_shm_ptr_to_pin
6+ from lightllm .common .cpu_cache import CpuCacheTensorBackend , CpuCacheTensorSpec
7+ from .allocator import MemoryBlock , MemoryManager
8+ from .copy_to_cache import offload_embed_tensor_to_cache
109
1110logger = init_logger (__name__ )
1211
@@ -25,10 +24,11 @@ def __init__(self, create_meta_data: bool, init_shm_data: bool):
2524 if create_meta_data :
2625 self .token_index_manager = MemoryManager (total_size = self .token_num )
2726 else :
28- if init_shm_data :
29- self ._create_shm_embed_kv_cache ()
30- else :
31- self ._attach_shm_cpu_embed_cache ()
27+ tensor_backend = CpuCacheTensorBackend (tensor_spec = self ._build_tensor_spec ())
28+ self .cpu_embed_cache_tensor , _ = tensor_backend .create_or_attach (
29+ init_shm_data = init_shm_data ,
30+ wait_for_register = True ,
31+ )
3232 return
3333
3434 def alloc_indexes (self , token_num : int ) -> Optional ["MemoryBlock" ]:
@@ -39,17 +39,14 @@ def release_indexes(self, block: "MemoryBlock"):
3939 return
4040
4141 def copy_to_cache (self , embed_tensor : torch .Tensor , start_index_in_cache : int ):
42- from .copy_to_cache import offload_embed_tensor_to_cache
43-
4442 offload_embed_tensor_to_cache (
4543 embed_tensor = embed_tensor ,
4644 cache_tensor = self .cpu_embed_cache_tensor ,
4745 start_index_in_cache = start_index_in_cache ,
4846 )
47+ return
4948
5049 def copy_vision_to_cache (self , embed_tensor : torch .Tensor , start_index_in_cache : int ):
51- from .copy_to_cache import offload_embed_tensor_to_cache
52-
5350 if embed_tensor .ndim == 3 :
5451 # check for qwen3 vision embed tensor shape, use apply deepstack
5552 assert embed_tensor .shape [1 ] == self .cpu_embed_cache_tensor .shape [1 ]
@@ -59,136 +56,19 @@ def copy_vision_to_cache(self, embed_tensor: torch.Tensor, start_index_in_cache:
5956 cache_tensor = self .cpu_embed_cache_tensor ,
6057 start_index_in_cache = start_index_in_cache ,
6158 )
62-
63- def _create_shm_embed_kv_cache (self ):
64- shm_ptr = create_shm_kv_cache_ptr (
65- key = self .args .multi_modal_cache_shm_id , size = self .embed_cache_tensor_meta .calcu_size ()
66- )
67- handle = register_shm_ptr_to_pin (shm_ptr = shm_ptr , size = self .embed_cache_tensor_meta .calcu_size ())
68- handle .wait ()
69- numpy_array = np .frombuffer (
70- memoryview ((ctypes .c_uint8 * self .embed_cache_tensor_meta .calcu_size ()).from_address (shm_ptr )),
71- dtype = np .uint8 ,
72- )
73- # 将 NumPy 数组转换为 PyTorch 张量
74- shape = (
75- self .embed_cache_tensor_meta .token_num ,
76- self .embed_cache_tensor_meta .layer_num ,
77- self .embed_cache_tensor_meta .hidden_size ,
78- )
79- self .cpu_embed_cache_tensor = (
80- torch .from_numpy (numpy_array ).view (dtype = self .embed_cache_tensor_meta .data_type ).view (shape )
81- )
8259 return
8360
84- def _attach_shm_cpu_embed_cache (self ):
85- shm_ptr = attach_shm_kv_cache_ptr (
86- key = self .args .multi_modal_cache_shm_id , size = self .embed_cache_tensor_meta .calcu_size ()
87- )
88- handle = register_shm_ptr_to_pin (shm_ptr = shm_ptr , size = self .embed_cache_tensor_meta .calcu_size ())
89- handle .wait ()
90- numpy_array = np .frombuffer (
91- memoryview ((ctypes .c_uint8 * self .embed_cache_tensor_meta .calcu_size ()).from_address (shm_ptr )),
92- dtype = np .uint8 ,
93- )
94- shape = (
95- self .embed_cache_tensor_meta .token_num ,
96- self .embed_cache_tensor_meta .layer_num ,
97- self .embed_cache_tensor_meta .hidden_size ,
98- )
99- self .cpu_embed_cache_tensor = (
100- torch .from_numpy (numpy_array ).view (dtype = self .embed_cache_tensor_meta .data_type ).view (shape )
61+ def _build_tensor_spec (self ) -> CpuCacheTensorSpec :
62+ return CpuCacheTensorSpec (
63+ shm_key = self .args .multi_modal_cache_shm_id ,
64+ shape = (
65+ self .embed_cache_tensor_meta .token_num ,
66+ self .embed_cache_tensor_meta .layer_num ,
67+ self .embed_cache_tensor_meta .hidden_size ,
68+ ),
69+ dtype = self .embed_cache_tensor_meta .data_type ,
70+ size_bytes = self .embed_cache_tensor_meta .calcu_size (),
10171 )
102- assert shm_ptr == self .cpu_embed_cache_tensor .data_ptr ()
103- return None
104-
105-
class MemoryBlock:
    """A contiguous region of cache indexes, the half-open interval [start, end)."""

    def __init__(self, start, end):
        self.start = start
        self.end = end

    def size(self):
        """Number of indexes covered by this block."""
        return self.end - self.start

    def __repr__(self):
        return f"Block(start={self.start}, end={self.end})"

    def can_merge(self, block: "MemoryBlock"):
        """Whether `block` touches this block exactly at either boundary."""
        return self.end == block.start or self.start == block.end
121-
122-
class MemoryManager:
    """Best-fit free-list allocator over a contiguous index space.

    Free blocks are kept in two sorted views over the same set of
    `MemoryBlock`s: by start offset (for neighbor lookup on release) and by
    (size, start) (for best-fit allocation).
    """

    def __init__(self, total_size):
        """
        Initialize the manager with one free block covering the whole space.

        :param total_size: total number of allocatable indexes
        """
        self.total_size = total_size
        self.mem_set_by_start = SortedSet(key=lambda x: (x.start, x.size()))
        self.mem_set_by_size = SortedSet(key=lambda x: (x.size(), x.start))
        self.__add(MemoryBlock(0, total_size))

    def alloc(self, need_size: int) -> Optional[MemoryBlock]:
        """Carve `need_size` indexes out of the smallest free block that fits.

        :return: the allocated block, or None when no free block is big enough.
        """
        assert need_size > 0

        if not self.mem_set_by_size:
            return None

        # start=-1 sorts before any real block of the same size, so
        # bisect_left lands on the first free block with size >= need_size.
        key = MemoryBlock(start=-1, end=-1 + need_size)
        find_index = self.mem_set_by_size.bisect_left(key)
        if find_index >= len(self.mem_set_by_size):
            return None

        finded_mem_block: MemoryBlock = self.mem_set_by_size[find_index]
        self.__del(finded_mem_block)
        ret_mem_block = MemoryBlock(
            start=finded_mem_block.start,
            end=finded_mem_block.start + need_size,
        )
        # Return the unused tail of the chosen free block to the pool.
        left_block = MemoryBlock(
            start=finded_mem_block.start + need_size,
            end=finded_mem_block.end,
        )
        if left_block.size() > 0:
            self.__add(left_block)
        return ret_mem_block

    def release(self, block: MemoryBlock):
        """Return `block` to the free pool, coalescing with adjacent free blocks."""
        if block is None:
            return
        if not self.mem_set_by_size:
            self.__add(block)
            return

        finded_index = self.mem_set_by_start.bisect_left(block)
        # Only the immediate neighbors by start offset can be adjacent.
        for index in (finded_index - 1, finded_index, finded_index + 1):
            # Guard the lower bound too: index == -1 would otherwise wrap to
            # the LAST element through Python's negative indexing and probe an
            # arbitrary non-neighbor block.
            if 0 <= index < len(self.mem_set_by_start):
                sub_block: MemoryBlock = self.mem_set_by_start[index]
                if block.can_merge(sub_block):
                    self.__del(sub_block)
                    merge_block = MemoryBlock(
                        start=min(block.start, sub_block.start),
                        end=max(block.end, sub_block.end),
                    )
                    # Recurse: the merged block may also be adjacent to a free
                    # block on its other side.
                    self.release(merge_block)
                    return
        # No adjacent free block: insert as-is.
        self.__add(block)
        return

    def __add(self, block):
        # Keep both sorted views in sync.
        self.mem_set_by_start.add(block)
        self.mem_set_by_size.add(block)

    def __del(self, block):
        self.mem_set_by_start.remove(block)
        self.mem_set_by_size.remove(block)
19272
19373
19474if __name__ == "__main__" :
0 commit comments