17 changes: 8 additions & 9 deletions fastdeploy/cache_manager/cache_transfer_manager.py
@@ -930,6 +930,14 @@ def _run_write_back_storage(
return block_num

elif self.storage_backend_type == "attention_store":
try:
if (self.rank == 0) and self.storage_backend_type == "attention_store":
self.storage_backend.flush_token_index(task_id, token_ids, 0, False)
🔴 Bug: flush_token_index now runs before the data is written, opening a timing window where the index becomes visible before the actual data is available.

Previously this call sat in the finally block of write_back_storage_task, which guaranteed the index was flushed only after the whole write flow had completed (whether it succeeded or not). Moving it before the write means that if the subsequent sdk.write fails, the index in storage has already been updated while the data was never written, so other requests can hit the index yet read no valid KV Cache.

Please confirm whether this change is an XPU-specific requirement; if so, consider adding a platform or backend guard:

if not current_platform.is_xpu():
    # flush after write (original behavior)

Alternatively, document the intended new ordering in the PR description.
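
For reference, a minimal sketch of one gated arrangement, assuming current_platform is importable here as it is elsewhere in fastdeploy and that the early flush is only needed on XPU (a suggestion, not the author's confirmed design):

# Pre-write flush (XPU only), in _run_write_back_storage:
if current_platform.is_xpu() and self.rank == 0 and self.storage_backend_type == "attention_store":
    self.storage_backend.flush_token_index(task_id, token_ids, 0, False)

# Post-write flush (original behavior), kept in write_back_storage_task's finally block:
if not current_platform.is_xpu() and self.rank == 0 and self.storage_backend_type == "attention_store":
    self.storage_backend.flush_token_index(task.task_id, task.token_ids, 0, False)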

logger.info(f"Report cache index out HBM to cache storage for task {task_id}")
except Exception as e:
logger.info(
f"Failed to report cache index out HBM to cache storage for task {task_id}, error: {e}"
)
key_cache = []
val_cache = []
for i in range(self.num_layers + self.num_extra_layers):
@@ -1040,15 +1048,6 @@ def write_back_storage_task(self, task: WriteStorageTask):
except Exception as e:
logger.error(f"Error in write back storage task: {e}, traceback:{traceback.format_exc()}")
gpu_block_ids = []
finally:
try:
if (self.rank == 0) and self.storage_backend_type == "attention_store":
self.storage_backend.flush_token_index(task.task_id, task.token_ids, 0, False)
logger.info(f"Report cache index out HBM to cache storage for task {task.task_id}")
except Exception as e:
logger.info(
f"Failed to report cache index out HBM to cache storage for task {task.task_id}, error: {e}"
)

result = (CacheStatus.GPU2STORAGE, task.task_id, task.keys, gpu_block_ids)
self.cache_task_queue.swap_to_storage_barrier.wait()
115 changes: 63 additions & 52 deletions fastdeploy/cache_manager/prefix_cache_manager.py
@@ -831,58 +831,71 @@ def request_match_blocks(self, task: Request, block_size, *args):

if self.kvcache_storage_backend and no_match_token_num >= block_size and not envs.FD_AS_ONLY_FLUSH:
if not self.can_allocate_gpu_blocks(num_blocks=no_match_block_num, try_free_gpu_blocks=False):
raise Exception(
"request_match_blocks: Not enough GPU memory to allocate cache for matched Storage Cache"
logger.warning(
"request_match_blocks: skip storage cache prefetch because GPU blocks are insufficient, "
f"req_id {req_id}, need {no_match_block_num}, free {len(self.gpu_free_block_list)}"
)

logger.debug(
f"request_match_blocks: req_id {req_id}, allocate {no_match_block_num} block to receive storage cache"
)
gpu_recv_storage_block_ids = self.allocate_gpu_blocks(no_match_block_num)

prefix_block_key = [] if match_block_node.hash_value is None else [match_block_node.hash_value]
cur_token_idx = match_token_num
no_match_block_keys = []
mm_idx = 0
while cur_token_idx <= input_token_num - block_size:
cur_block_token_ids = input_token_ids[cur_token_idx : cur_token_idx + block_size]
# Get extra hash keys for multimodal content (images, videos, etc.)
mm_idx, extra_keys = self.get_block_hash_extra_keys(
request=task,
start_idx=cur_token_idx,
end_idx=cur_token_idx + block_size,
mm_idx=mm_idx,
else:
logger.debug(
f"request_match_blocks: req_id {req_id}, allocate {no_match_block_num} block to receive storage cache"
)
prefix_block_key.extend(extra_keys)
cur_block_key = get_hash_str(cur_block_token_ids, prefix_block_key)
no_match_block_keys.append(cur_block_key)
cur_token_idx += block_size
prefix_block_key = [cur_block_key]
gpu_recv_storage_block_ids = self.allocate_gpu_blocks(no_match_block_num)

prefix_block_key = [] if match_block_node.hash_value is None else [match_block_node.hash_value]
cur_token_idx = match_token_num
no_match_block_keys = []
mm_idx = 0
while cur_token_idx <= input_token_num - block_size:
cur_block_token_ids = input_token_ids[cur_token_idx : cur_token_idx + block_size]
# Get extra hash keys for multimodal content (images, videos, etc.)
mm_idx, extra_keys = self.get_block_hash_extra_keys(
request=task,
start_idx=cur_token_idx,
end_idx=cur_token_idx + block_size,
mm_idx=mm_idx,
)
prefix_block_key.extend(extra_keys)
cur_block_key = get_hash_str(cur_block_token_ids, prefix_block_key)
no_match_block_keys.append(cur_block_key)
cur_token_idx += block_size
prefix_block_key = [cur_block_key]

logger.info(
f"start prefetch cache from storage, req_id: {req_id}, block num: {len(no_match_block_keys)}"
)
start_time = time.time()
read_storage_task = ReadStorageTask(
task_id=req_id,
keys=no_match_block_keys,
token_ids=input_token_ids if self.kvcache_storage_backend == "attention_store" else None,
gpu_block_ids=gpu_recv_storage_block_ids,
start_read_block_idx=match_token_num // block_size,
)
logger.debug(f"issue read storage task: {read_storage_task}")
storage_matched_block_ids = self.issue_prefetch_storage_task(read_storage_task)
storage_matched_block_num = len(storage_matched_block_ids)
storage_match_token_num = storage_matched_block_num * block_size
cost_time = time.time() - start_time
metrics["storage_cache_prepare_time"] = cost_time
logger.info(
f"finish prefetch cache from storage, req_id: {req_id}, "
f"matched block num: {storage_matched_block_num}, cost_time:{cost_time:.6f}s"
)
try:
logger.info(
f"start prefetch cache from storage, req_id: {req_id}, block num: {len(no_match_block_keys)}"
)
start_time = time.time()
read_storage_task = ReadStorageTask(
task_id=req_id,
keys=no_match_block_keys,
token_ids=(
input_token_ids if self.kvcache_storage_backend == "attention_store" else None
),
gpu_block_ids=gpu_recv_storage_block_ids,
start_read_block_idx=match_token_num // block_size,
)
logger.debug(f"issue read storage task: {read_storage_task}")
storage_matched_block_ids = self.issue_prefetch_storage_task(read_storage_task)
storage_matched_block_num = len(storage_matched_block_ids)
storage_match_token_num = storage_matched_block_num * block_size
cost_time = time.time() - start_time
metrics["storage_cache_prepare_time"] = cost_time
logger.info(
f"finish prefetch cache from storage, req_id: {req_id}, "
f"matched block num: {storage_matched_block_num}, cost_time:{cost_time:.6f}s"
)

match_storage_block_ids = gpu_recv_storage_block_ids[:storage_matched_block_num]
self.recycle_gpu_blocks(gpu_recv_storage_block_ids[storage_matched_block_num:])
match_storage_block_ids = gpu_recv_storage_block_ids[:storage_matched_block_num]
self.recycle_gpu_blocks(gpu_recv_storage_block_ids[storage_matched_block_num:])
except Exception as e:
logger.warning(
"request_match_blocks: storage cache prefetch failed, fallback to cache miss, "
f"req_id {req_id}, error: {type(e)} {e}"
)
self.recycle_gpu_blocks(gpu_recv_storage_block_ids, req_id)
gpu_recv_storage_block_ids = []
storage_match_token_num = 0
match_storage_block_ids = []

# 4. update metrics
match_token_num = gpu_match_token_num + cpu_match_token_num + storage_match_token_num
@@ -1127,10 +1140,7 @@ def write_cache_to_storage(self, request: Request):
if isinstance(token_ids, np.ndarray):
token_ids = token_ids.tolist()

if self.config.cache_config.enable_output_caching:
input_token_ids = token_ids + request.output_token_ids
else:
input_token_ids = token_ids
input_token_ids = token_ids + request.output_token_ids

🔴 Bug: with the enable_output_caching check removed, output_token_ids is appended to token_ids for the storage write regardless of whether the flag is enabled.

With enable_output_caching=False, only the KV Cache for the input tokens used to be written; now the output tokens are included as well, so the written data no longer matches what a prefix cache hit expects, which can pollute the cache or push token_ids past its valid length.

A later line adds the truncation input_token_ids = input_token_ids[: len(keys) * block_size], but that only prevents the overflow; it cannot restore the original semantic distinction. Please confirm whether this change is intended.
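
A minimal sketch of restoring the flag-gated behavior (it mirrors the deleted lines and is offered as a suggestion, not the confirmed fix):

if self.config.cache_config.enable_output_caching:
    # Persist the KV Cache for generated tokens as well.
    input_token_ids = token_ids + request.output_token_ids
else:
    # Persist only the KV Cache for the prompt tokens.
    input_token_ids = token_ids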


req_id = request.request_id
keys = []
@@ -1144,6 +1154,7 @@ def write_cache_to_storage(self, request: Request):

trace_print(LoggingEventName.WRITE_CACHE_TO_STORAGE_START, request.request_id, getattr(request, "user", ""))
gpu_block_ids = request.block_tables[: len(keys)]
input_token_ids = input_token_ids[: len(keys) * self.config.cache_config.block_size]
logger.info(f"start write cache back to storage, req_id: {req_id}, block num: {len(keys)}")
write_storage_task = WriteStorageTask(
task_id=req_id,
@@ -189,42 +189,96 @@ def write(
start_write_block_idx: int,
timeout: float = 30.0,
) -> int:
logger.debug(
f"[WRITE BEGIN] task_id: {task_id} token_ids: {token_ids} gpu_block_ids: {gpu_block_ids} start_write_block_idx: {start_write_block_idx} timeout: {timeout}"
)
tokens = Tokens(token_ids, self.config.block_token_size)
k_data_ptrs = [k.data_ptr() for k in key_cache]
v_data_ptrs = [v.data_ptr() for v in val_cache]
num = 0
try:
if current_platform.is_cuda():
num = self.sdk.write(
list(range(self.config.layer_num)),
tokens,
start_write_block_idx,
k_data_ptrs,
v_data_ptrs,
gpu_block_ids,
timeout,
h2h_copy=False,
params=None,
layer_ids = list(range(self.config.layer_num))
block_token_size = self.config.block_token_size

total_timeout = float(os.getenv("AS_WRITE_TOTAL_TIMEOUT", str(timeout)))
slice_block_num = int(os.getenv("AS_WRITE_SLICE_BLOCK_NUM", "100"))
slice_timeout = float(os.getenv("AS_WRITE_SLICE_TIMEOUT", "10"))
logger.debug(
f"[WRITE BEGIN] task_id: {task_id} token_ids: {token_ids} gpu_block_ids: {gpu_block_ids}"
f" start_write_block_idx: {start_write_block_idx} timeout: {total_timeout}"
)
total_blocks = len(gpu_block_ids)
total_written = 0
overall_start = time.time()

for slice_start in range(0, total_blocks, slice_block_num):
elapsed = time.time() - overall_start
remaining_timeout = total_timeout - elapsed
if remaining_timeout <= 0:
logger.warning(
f"[WRITE TIMEOUT] task_id: {task_id} total timeout {total_timeout}s reached, "
f"written {total_written}/{total_blocks} blocks"
)
else:
num = self.sdk.write(
list(range(self.config.layer_num)),
tokens,
start_write_block_idx,
k_data_ptrs,
v_data_ptrs,
gpu_block_ids,
timeout,
break

slice_end = min(slice_start + slice_block_num, total_blocks)
slice_gpu_block_ids = gpu_block_ids[slice_start:slice_end]
slice_write_block_idx = start_write_block_idx + slice_start
slice_token_ids = token_ids[: (start_write_block_idx + slice_end) * block_token_size]
slice_tokens = Tokens(slice_token_ids, block_token_size)

logger.debug(
f"[WRITE SLICE BEGIN] task_id: {task_id} slice [{slice_start}:{slice_end}] "
f"block_idx={slice_write_block_idx}, blocks={len(slice_gpu_block_ids)}, "
f"token_ids_len={len(slice_token_ids)}, timeout={slice_timeout:.2f}s"
)
slice_start_time = time.time()
try:
if current_platform.is_cuda():
written = self.sdk.write(
layer_ids,
slice_tokens,
slice_write_block_idx,
k_data_ptrs,
v_data_ptrs,
slice_gpu_block_ids,
slice_timeout,
h2h_copy=False,
params=None,
)
else:
written = self.sdk.write(
layer_ids,
slice_tokens,
slice_write_block_idx,
k_data_ptrs,
v_data_ptrs,
slice_gpu_block_ids,
slice_timeout,
)
except AttentionStoreSDKError:
logger.error(
f"[WRITE ERROR] task_id: {task_id} slice [{slice_start}:{slice_end}], "
f"traceback:\n{traceback.format_exc()}"
)
logger.debug(f"[WRITE END] task_id: {task_id} written_blocks: {num}")
except AttentionStoreSDKError:
logger.error(
f"[WRITE ERROR] failed to execute sdk write, task_id: {task_id}, traceback:\n{traceback.format_exc()}"
written = 0
slice_cost = time.time() - slice_start_time
total_written += written

if written < len(slice_gpu_block_ids):
logger.warning(
f"[WRITE SLICE INCOMPLETE] task_id: {task_id} slice [{slice_start}:{slice_end}] "
f"({written}/{len(slice_gpu_block_ids)}), cost={slice_cost:.6f}s, "
f"total written {total_written}/{total_blocks}, "
f"prefix cache continuity broken, skip remaining slices"
)
break

logger.debug(
f"[WRITE SLICE END] task_id: {task_id} slice [{slice_start}:{slice_end}] "
f"written={written}, cost={slice_cost:.6f}s"
)
return num

total_cost = time.time() - overall_start
logger.info(
f"[WRITE END] task_id: {task_id} total cost={total_cost:.6f}s, "
f"written {total_written}/{total_blocks} blocks"
)
return total_written

def query(self, task_id: str, token_ids: List[int], start_match_block_idx: int, timeout: float = 10.0):
"""
7 changes: 5 additions & 2 deletions fastdeploy/model_executor/layers/sample/sampler.py
@@ -77,7 +77,10 @@ def padding_sampling_params(top_p, top_k, infer_seed, seq_lens_this_time, seq_le
top_k_padding = paddle.repeat_interleave(top_k[:real_bsz], repeats).unsqueeze(1)
topp_seed = paddle.repeat_interleave(infer_seed[:real_bsz], repeats).unsqueeze(1)

MAX_INFER_SEED = 9223372036854775806
if current_platform.is_xpu():
MAX_INFER_SEED = 2147483646
else:
MAX_INFER_SEED = 9223372036854775806

token_lens = paddle.where(
seq_lens_encoder[:real_bsz] == 0,
@@ -97,7 +100,7 @@ def padding_sampling_params(top_p, top_k, infer_seed, seq_lens_this_time, seq_le

offsets = paddle.where(
is_decoder,
local_pos * 4,
local_pos * 32,

🔴 Bug: the local_pos * 32 change is not gated on the XPU platform, so it changes the sampling offset computation on every platform (including CUDA/GPU).

The MAX_INFER_SEED change in the same function is correctly gated with if current_platform.is_xpu(), but this spot lacks the same guard.

If * 32 is an XPU-only adaptation, consider changing it to:

offsets = paddle.where(
    is_decoder,
    local_pos * 32 if current_platform.is_xpu() else local_pos * 4,
    paddle.zeros_like(local_pos),
)

If the change is valid on all platforms, please state the reason in the PR description.

paddle.zeros_like(local_pos),
)

17 changes: 13 additions & 4 deletions fastdeploy/scheduler/local_scheduler.py
@@ -130,8 +130,14 @@ def _recycle(self, request_id: Optional[str] = None):
if request_id is not None:
self.requests.pop(request_id, None)
self.responses.pop(request_id, None)
self.ids.pop(self.ids.index(request_id))
self.ids_read_cursor -= 1
try:
idx = self.ids.index(request_id)
self.ids.pop(idx)
if idx < self.ids_read_cursor:
self.ids_read_cursor -= 1
except ValueError:
scheduler_logger.warning(f"_recycle error: request_id {request_id} not found in ids")
return

if self.max_size <= 0:
@@ -148,10 +154,10 @@ def _recycle(self, request_id: Optional[str] = None):
break
expired_ids.append(request.request_id)

for i, expired_id in enumerate(expired_ids):
for expired_id in expired_ids:
self.requests.pop(expired_id, None)
self.responses.pop(expired_id, None)
self.ids.pop(i)
self.ids = self.ids[len(expired_ids) :]

if len(expired_ids) > 0:
if len(expired_ids) - 1 >= self.ids_read_cursor:
@@ -234,6 +240,9 @@ def calc_required_blocks(self, token_num, block_size):
return (token_num + block_size - 1) // block_size

def get_unhandled_request_num(self):
scheduler_logger.debug(
f"get_unhandled_request_num len(self.ids):{len(self.ids)}, self.ids_read_cursor:{self.ids_read_cursor}"
)
return len(self.ids) - self.ids_read_cursor

def get_requests(
9 changes: 9 additions & 0 deletions fastdeploy/splitwise/internal_adapter_utils.py
@@ -105,6 +105,15 @@ def _recv_external_module_control_instruct(self):
with self.response_lock:
self.recv_control_cmd_server.response_for_control_cmd(task_id_str, result)

elif task["cmd"] == "interrupt_requests":
self.engine.resource_manager.add_abort_req_ids(task["req_ids"])
result = {
"task_id": task_id_str,
"result": {"success": True, "interrupted_req_ids": task["req_ids"]},
}
with self.response_lock:
self.recv_control_cmd_server.response_for_control_cmd(task_id_str, result)

except Exception as e:
logger.error(f"handle_control_cmd got error: {e}, {traceback.format_exc()!s}")
