Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions lightllm/common/cpu_cache/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .tensor_backend import CpuCacheTensorBackend, CpuCacheTensorSpec

__all__ = ["CpuCacheTensorBackend", "CpuCacheTensorSpec"]
48 changes: 48 additions & 0 deletions lightllm/common/cpu_cache/tensor_backend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import ctypes
from dataclasses import dataclass
from typing import Optional, Tuple

import numpy as np
import torch

from lightllm.utils.kv_cache_utils import attach_shm_kv_cache_ptr, create_shm_kv_cache_ptr, register_shm_ptr_to_pin


@dataclass(frozen=True)
class CpuCacheTensorSpec:
    """Immutable description of a shared-memory CPU cache tensor."""

    shm_key: int  # shared-memory key identifying the segment
    shape: Tuple[int, ...]  # logical tensor shape the raw bytes are viewed as
    dtype: torch.dtype  # element dtype of the cache tensor
    size_bytes: int  # total byte size of the shm segment
    pin_on_create: bool = True  # whether the creating process also pins the memory


class CpuCacheTensorBackend:
    """Create or attach a shared-memory segment and expose it as a torch tensor view."""

    def __init__(self, tensor_spec: CpuCacheTensorSpec):
        self.tensor_spec = tensor_spec

    def create_or_attach(
        self, init_shm_data: bool, wait_for_register: bool = True
    ) -> Tuple[torch.Tensor, Optional[object]]:
        """Return ``(cache_tensor, pin_handle_or_None)``.

        When ``init_shm_data`` is True the shm segment is created, otherwise an
        existing segment with the same key is attached.
        """
        return self._init_tensor(create_mode=init_shm_data, wait_for_register=wait_for_register)

    def _init_tensor(self, create_mode: bool, wait_for_register: bool) -> Tuple[torch.Tensor, Optional[object]]:
        spec = self.tensor_spec
        if create_mode:
            shm_ptr = create_shm_kv_cache_ptr(key=spec.shm_key, size=spec.size_bytes)
        else:
            shm_ptr = attach_shm_kv_cache_ptr(key=spec.shm_key, size=spec.size_bytes)
        # Attachers always register the pointer for pinning; creators do so only
        # when the spec requests it.
        should_pin = spec.pin_on_create if create_mode else True
        attach_handle = register_shm_ptr_to_pin(shm_ptr=shm_ptr, size=spec.size_bytes) if should_pin else None
        if attach_handle is not None and wait_for_register:
            attach_handle.wait()
        cpu_cache_tensor = self._build_tensor_view(shm_ptr=shm_ptr)
        # The view must alias the shm segment exactly (zero-copy invariant).
        assert shm_ptr == cpu_cache_tensor.data_ptr()
        return cpu_cache_tensor, attach_handle

    def _build_tensor_view(self, shm_ptr: int) -> torch.Tensor:
        """Wrap the raw bytes at ``shm_ptr`` as a typed, shaped tensor without copying."""
        raw_bytes = (ctypes.c_uint8 * self.tensor_spec.size_bytes).from_address(shm_ptr)
        byte_array = np.frombuffer(memoryview(raw_bytes), dtype=np.uint8)
        return torch.from_numpy(byte_array).view(dtype=self.tensor_spec.dtype).view(self.tensor_spec.shape)
91 changes: 91 additions & 0 deletions lightllm/server/embed_cache/allocator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
from typing import Optional

from sortedcontainers import SortedSet


class MemoryBlock:
    """A contiguous memory region ``[start, end)`` measured in cache-index units."""

    def __init__(self, start, end):
        self.start = start  # inclusive start index
        self.end = end  # exclusive end index

    def size(self):
        """Number of units covered by this block."""
        return self.end - self.start

    def __repr__(self):
        return f"Block(start={self.start}, end={self.end})"

    def can_merge(self, block: "MemoryBlock"):
        """Return True when *block* touches this block on either side (no gap)."""
        return (self.start == block.end) or (block.start == self.end)


class MemoryManager:
    """Best-fit free-list allocator over a contiguous index range ``[0, total_size)``.

    Free blocks are tracked in two sorted views: one ordered by start offset
    (for locating neighbors on release) and one ordered by size (for best-fit
    allocation).
    """

    def __init__(self, total_size):
        """
        Initialize the memory manager.
        :param total_size: total size of the managed region.
        """
        self.total_size = total_size
        # Ordered by (start, size): used to find blocks adjacent to a released one.
        self.mem_set_by_start = SortedSet(key=lambda x: (x.start, x.size()))
        # Ordered by (size, start): used for best-fit allocation.
        self.mem_set_by_size = SortedSet(key=lambda x: (x.size(), x.start))
        total = MemoryBlock(0, total_size)
        self.__add(total)

    def alloc(self, need_size: int) -> Optional[MemoryBlock]:
        """Allocate a block of exactly ``need_size`` units; return None when nothing fits."""
        assert need_size > 0

        if len(self.mem_set_by_size) == 0:
            return None

        # Probe key: the smallest free block whose size >= need_size.
        # start=-1 sorts the probe before any real block of equal size.
        key = MemoryBlock(start=-1, end=-1 + need_size)
        find_index = self.mem_set_by_size.bisect_left(key)
        if find_index < len(self.mem_set_by_size):
            finded_mem_block: MemoryBlock = self.mem_set_by_size[find_index]
            self.__del(finded_mem_block)
            ret_mem_block = MemoryBlock(
                start=finded_mem_block.start,
                end=finded_mem_block.start + need_size,
            )
            # Return the unused tail of the chosen block to the free pool.
            left_block = MemoryBlock(
                start=finded_mem_block.start + need_size,
                end=finded_mem_block.end,
            )
            if left_block.size() > 0:
                self.__add(left_block)

            return ret_mem_block
        else:
            return None

    def release(self, block: MemoryBlock):
        """Return *block* to the free pool, coalescing with adjacent free blocks."""
        if block is None:
            return
        if len(self.mem_set_by_size) == 0:
            self.__add(block)
            return

        finded_index = self.mem_set_by_start.bisect_left(block)
        # Only the free blocks immediately before and at the insertion point can
        # be adjacent to the released block. The lower bound guard is required:
        # without it, index -1 would wrap around to the LAST element of the
        # sorted set and could trigger a bogus merge.
        for index in [finded_index - 1, finded_index]:
            if 0 <= index < len(self.mem_set_by_start):
                sub_block: MemoryBlock = self.mem_set_by_start[index]
                # merge
                if block.can_merge(sub_block):
                    self.__del(sub_block)
                    merge_block = MemoryBlock(
                        start=min(block.start, sub_block.start),
                        end=max(block.end, sub_block.end),
                    )
                    # Recurse so the merged block can coalesce further.
                    self.release(merge_block)
                    return
        # No adjacent free block: insert as-is.
        self.__add(block)
        return

    def __add(self, block):
        # Register a free block in both sorted views.
        self.mem_set_by_start.add(block)
        self.mem_set_by_size.add(block)

    def __del(self, block):
        # Remove a free block from both sorted views.
        self.mem_set_by_start.remove(block)
        self.mem_set_by_size.remove(block)
160 changes: 20 additions & 140 deletions lightllm/server/embed_cache/embed_cache_client.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import ctypes
import torch
import numpy as np
from sortedcontainers import SortedSet
from typing import Optional, List
from typing import Optional
from lightllm.utils.envs_utils import get_env_start_args
from lightllm.utils.log_utils import init_logger
from lightllm.utils.embed_utils import calcu_embed_cache_meta
from lightllm.utils.kv_cache_utils import create_shm_kv_cache_ptr, attach_shm_kv_cache_ptr, register_shm_ptr_to_pin
from lightllm.common.cpu_cache import CpuCacheTensorBackend, CpuCacheTensorSpec
from .allocator import MemoryBlock, MemoryManager
from .copy_to_cache import offload_embed_tensor_to_cache

logger = init_logger(__name__)

Expand All @@ -25,10 +24,11 @@ def __init__(self, create_meta_data: bool, init_shm_data: bool):
if create_meta_data:
self.token_index_manager = MemoryManager(total_size=self.token_num)
else:
if init_shm_data:
self._create_shm_embed_kv_cache()
else:
self._attach_shm_cpu_embed_cache()
tensor_backend = CpuCacheTensorBackend(tensor_spec=self._build_tensor_spec())
self.cpu_embed_cache_tensor, _ = tensor_backend.create_or_attach(
init_shm_data=init_shm_data,
wait_for_register=True,
)
return

def alloc_indexes(self, token_num: int) -> Optional["MemoryBlock"]:
Expand All @@ -39,17 +39,14 @@ def release_indexes(self, block: "MemoryBlock"):
return

def copy_to_cache(self, embed_tensor: torch.Tensor, start_index_in_cache: int):
from .copy_to_cache import offload_embed_tensor_to_cache

offload_embed_tensor_to_cache(
embed_tensor=embed_tensor,
cache_tensor=self.cpu_embed_cache_tensor,
start_index_in_cache=start_index_in_cache,
)
return

def copy_vision_to_cache(self, embed_tensor: torch.Tensor, start_index_in_cache: int):
from .copy_to_cache import offload_embed_tensor_to_cache

if embed_tensor.ndim == 3:
# check for qwen3 vision embed tensor shape, use apply deepstack
assert embed_tensor.shape[1] == self.cpu_embed_cache_tensor.shape[1]
Expand All @@ -59,136 +56,19 @@ def copy_vision_to_cache(self, embed_tensor: torch.Tensor, start_index_in_cache:
cache_tensor=self.cpu_embed_cache_tensor,
start_index_in_cache=start_index_in_cache,
)

def _create_shm_embed_kv_cache(self):
shm_ptr = create_shm_kv_cache_ptr(
key=self.args.multi_modal_cache_shm_id, size=self.embed_cache_tensor_meta.calcu_size()
)
handle = register_shm_ptr_to_pin(shm_ptr=shm_ptr, size=self.embed_cache_tensor_meta.calcu_size())
handle.wait()
numpy_array = np.frombuffer(
memoryview((ctypes.c_uint8 * self.embed_cache_tensor_meta.calcu_size()).from_address(shm_ptr)),
dtype=np.uint8,
)
# 将 NumPy 数组转换为 PyTorch 张量
shape = (
self.embed_cache_tensor_meta.token_num,
self.embed_cache_tensor_meta.layer_num,
self.embed_cache_tensor_meta.hidden_size,
)
self.cpu_embed_cache_tensor = (
torch.from_numpy(numpy_array).view(dtype=self.embed_cache_tensor_meta.data_type).view(shape)
)
return

def _attach_shm_cpu_embed_cache(self):
shm_ptr = attach_shm_kv_cache_ptr(
key=self.args.multi_modal_cache_shm_id, size=self.embed_cache_tensor_meta.calcu_size()
)
handle = register_shm_ptr_to_pin(shm_ptr=shm_ptr, size=self.embed_cache_tensor_meta.calcu_size())
handle.wait()
numpy_array = np.frombuffer(
memoryview((ctypes.c_uint8 * self.embed_cache_tensor_meta.calcu_size()).from_address(shm_ptr)),
dtype=np.uint8,
)
shape = (
self.embed_cache_tensor_meta.token_num,
self.embed_cache_tensor_meta.layer_num,
self.embed_cache_tensor_meta.hidden_size,
)
self.cpu_embed_cache_tensor = (
torch.from_numpy(numpy_array).view(dtype=self.embed_cache_tensor_meta.data_type).view(shape)
def _build_tensor_spec(self) -> CpuCacheTensorSpec:
return CpuCacheTensorSpec(
shm_key=self.args.multi_modal_cache_shm_id,
shape=(
self.embed_cache_tensor_meta.token_num,
self.embed_cache_tensor_meta.layer_num,
self.embed_cache_tensor_meta.hidden_size,
),
dtype=self.embed_cache_tensor_meta.data_type,
size_bytes=self.embed_cache_tensor_meta.calcu_size(),
)
assert shm_ptr == self.cpu_embed_cache_tensor.data_ptr()
return None


class MemoryBlock:
    """A contiguous memory region, spanning indexes ``[start, end)``."""

    def __init__(self, start, end):
        self.start = start
        self.end = end

    def size(self):
        # Length of the half-open interval [start, end).
        span = self.end - self.start
        return span

    def __repr__(self):
        return "Block(start={}, end={})".format(self.start, self.end)

    def can_merge(self, block: "MemoryBlock"):
        # Adjacent on either side: other ends where we begin, or begins where we end.
        return (block.start == self.end) or (self.start == block.end)


class MemoryManager:
    """Best-fit free-list allocator over a contiguous index range ``[0, total_size)``.

    Free blocks are tracked in two sorted views: one ordered by start offset
    (for locating neighbors on release) and one ordered by size (for best-fit
    allocation).
    """

    def __init__(self, total_size):
        """
        Initialize the memory manager.
        :param total_size: total size of the managed region.
        """
        self.total_size = total_size
        # Ordered by (start, size): used to find blocks adjacent to a released one.
        self.mem_set_by_start = SortedSet(key=lambda x: (x.start, x.size()))
        # Ordered by (size, start): used for best-fit allocation.
        self.mem_set_by_size = SortedSet(key=lambda x: (x.size(), x.start))
        total = MemoryBlock(0, total_size)
        self.__add(total)

    def alloc(self, need_size: int) -> Optional[MemoryBlock]:
        """Allocate a block of exactly ``need_size`` units; return None when nothing fits."""
        assert need_size > 0

        if len(self.mem_set_by_size) == 0:
            return None

        # Probe key: the smallest free block whose size >= need_size.
        # start=-1 sorts the probe before any real block of equal size.
        key = MemoryBlock(start=-1, end=-1 + need_size)
        find_index = self.mem_set_by_size.bisect_left(key)
        if find_index < len(self.mem_set_by_size):
            finded_mem_block: MemoryBlock = self.mem_set_by_size[find_index]
            self.__del(finded_mem_block)
            ret_mem_block = MemoryBlock(
                start=finded_mem_block.start,
                end=finded_mem_block.start + need_size,
            )
            # Return the unused tail of the chosen block to the free pool.
            left_block = MemoryBlock(
                start=finded_mem_block.start + need_size,
                end=finded_mem_block.end,
            )
            if left_block.size() > 0:
                self.__add(left_block)

            return ret_mem_block
        else:
            return None

    def release(self, block: MemoryBlock):
        """Return *block* to the free pool, coalescing with adjacent free blocks."""
        if block is None:
            return
        if len(self.mem_set_by_size) == 0:
            self.__add(block)
            return

        finded_index = self.mem_set_by_start.bisect_left(block)
        # Only the free blocks immediately before and at the insertion point can
        # be adjacent to the released block. The lower bound guard is required:
        # without it, index -1 would wrap around to the LAST element of the
        # sorted set and could trigger a bogus merge.
        for index in [finded_index - 1, finded_index]:
            if 0 <= index < len(self.mem_set_by_start):
                sub_block: MemoryBlock = self.mem_set_by_start[index]
                # merge
                if block.can_merge(sub_block):
                    self.__del(sub_block)
                    merge_block = MemoryBlock(
                        start=min(block.start, sub_block.start),
                        end=max(block.end, sub_block.end),
                    )
                    # Recurse so the merged block can coalesce further.
                    self.release(merge_block)
                    return
        # No adjacent free block: insert as-is.
        self.__add(block)
        return

    def __add(self, block):
        # Register a free block in both sorted views.
        self.mem_set_by_start.add(block)
        self.mem_set_by_size.add(block)

    def __del(self, block):
        # Remove a free block from both sorted views.
        self.mem_set_by_start.remove(block)
        self.mem_set_by_size.remove(block)


if __name__ == "__main__":
Expand Down
5 changes: 4 additions & 1 deletion lightllm/server/embed_cache/impl/naive_memory_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@
import multiprocessing.shared_memory as shm
from ..utils import get_shm_name_data, free_shm
from lightllm.utils.log_utils import init_logger
from ..embed_cache_client import CpuEmbedCacheClient, MemoryBlock, SortedSet
from sortedcontainers import SortedSet

from ..allocator import MemoryBlock
from ..embed_cache_client import CpuEmbedCacheClient

logger = init_logger(__name__)

Expand Down
Loading
Loading