38 changes: 26 additions & 12 deletions fastdeploy/cache_manager/cache_messager.py
@@ -613,12 +613,16 @@ def __init__(
)

self.gpu_id = gpu_id
self.cache_info = dict()
self.cache_info = dict() # {'request_id': cache_info_dict}
self.rank_id = self.rank + local_data_parallel_id * self.nranks
self.engine_cache_task_thread_lock = threading.Lock()
self.engine_cache_tasks = [dict() for _ in range(512)]
self.idx_cache_task_dict = {}
self.cache_prefilled_engine_ids_queue = queue.Queue() # keep batch slot index for each prefill step
self.engine_cache_tasks = [
dict() for _ in range(512)
] # {'layer_id': {'prefilled_layer_idx': xx, 'prefilled_block_num': xx}}
self.idx_cache_task_dict = {} # {'slot_idx': cache_info_dict}
self.cache_prefilled_engine_ids_queue = (
queue.Queue()
) # [(slot_idx1, prefilled_token_num1), (slot_idx2, prefilled_token_num2)]
if splitwise_role == "prefill":
consume_signals_thread = threading.Thread(target=self.consume_signals)
consume_signals_thread.daemon = True
@@ -638,7 +642,6 @@ def _add_cache_task_thread(self):
while True:
try:
cache_info = self.engine_worker_queue.get_cache_info()
finished_add_cache_task_req_ids = []
if cache_info:
logger.debug(f"Get cache info from engine worker queue, {cache_info}")
self.engine_worker_queue.cache_info_barrier.wait()
@@ -647,7 +650,6 @@
self.cache_info[info["request_id"]].update(info)
current_info = self.cache_info[info["request_id"]]
assert "dest_block_ids" in current_info and "src_block_ids" in current_info
finished_add_cache_task_req_ids.append(info["request_id"])
decode_cached_block_num = len(current_info["src_block_ids"]) - len(
current_info["dest_block_ids"]
)
@@ -659,17 +661,13 @@
current_info["sended_layer_id"] = -1
current_info["sended_block_num"] = current_info["decode_cached_tokens"] // self.block_size
current_info["status"] = "init"
logger.info(f"Get cache info from D: finish add cache task: {current_info}")
logger.info(f"Get cache info and finish add cache task: {current_info}")
self.cache_info[info["request_id"]] = current_info
self.idx_cache_task_dict[current_info["current_id"]] = current_info
else:
logger.info(f"Get cache info from P: {info}")
logger.info(f"Get cache info: {info}")
self.cache_info[info["request_id"]] = info

if finished_add_cache_task_req_ids:
logger.info(f"Put processed tasks into engine worker queue: {finished_add_cache_task_req_ids}")
self.engine_worker_queue.put_finished_add_cache_task_req(finished_add_cache_task_req_ids)
self.engine_worker_queue.finish_add_cache_task_barrier.wait()
else:
time.sleep(0.001)
except Exception as e:
@@ -687,10 +685,12 @@ def prefill_layerwise_send_cache_thread(self):
block_start_end_list = []
current_prefilled_token_num_list = []
for engine_index, current_step_prefilled_token_num in batch_engine_signals:
self._maybe_wait_for_cache_task(engine_index)
assert (
engine_index in self.idx_cache_task_dict
), f"engine_index {engine_index} not in self.idx_cache_task_dict {self.idx_cache_task_dict}"
block_id_start = self.idx_cache_task_dict[engine_index]["sended_block_num"]

prefilled_token_num = current_step_prefilled_token_num
if (
prefilled_token_num == self.idx_cache_task_dict[engine_index]["need_prefill_tokens"]
@@ -917,6 +917,20 @@ def _handle_connect_task(self):
except Exception as e:
logger.error(f"handle_connect_task has exception: {e}, {traceback.format_exc()}")

def _maybe_wait_for_cache_task(self, engine_index):
# If the cache messager has not yet received the cache task from the engine, wait here for now
wait_step = 1
sleep_seconds = 0.005

while engine_index not in self.idx_cache_task_dict:

time.sleep(sleep_seconds)
wait_step += 1

if wait_step % 400 == 0:
logger.warning(
f"waiting cache task for engine_index: {engine_index}, cost_time: {wait_step * 0.005:.2f} s"
)


def main():
device = args.device_id
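Before the second file, a minimal self-contained sketch of the producer/consumer pattern the changes above rely on. The class and method names below (SlotQueueDemo, report_prefill_step, sender_loop) are simplified stand-ins, not the real FastDeploy API: prefill steps push (slot_idx, prefilled_token_num) tuples into a queue, the sender thread looks each slot up in a dict keyed by slot index, and it spin-waits briefly when a signal arrives before the cache task is registered, which is the same race the new _maybe_wait_for_cache_task guards against.

```python
import queue
import threading
import time


class SlotQueueDemo:
    """Toy producer/consumer mirroring the queue layout used in the diff above."""

    def __init__(self):
        # slot_idx -> cache task info (simplified stand-in for idx_cache_task_dict)
        self.idx_cache_task_dict = {}
        # items are (slot_idx, prefilled_token_num) tuples, one per prefill step
        self.prefilled_queue = queue.Queue()

    def add_cache_task(self, slot_idx, need_prefill_tokens):
        # producer side: register the task so the sender thread can look it up
        self.idx_cache_task_dict[slot_idx] = {"need_prefill_tokens": need_prefill_tokens}

    def report_prefill_step(self, slot_idx, prefilled_token_num):
        # producer side: signal that a prefill step finished for this slot
        self.prefilled_queue.put((slot_idx, prefilled_token_num))

    def _maybe_wait_for_cache_task(self, slot_idx, sleep_seconds=0.005):
        # consumer side: the signal can arrive before the cache task is registered,
        # so spin-wait (with a periodic warning) until the task shows up
        wait_step = 0
        while slot_idx not in self.idx_cache_task_dict:
            time.sleep(sleep_seconds)
            wait_step += 1
            if wait_step % 400 == 0:
                print(f"still waiting for cache task of slot {slot_idx}")

    def sender_loop(self):
        # consumer side: drain one prefill signal at a time, then send the new blocks
        while True:
            slot_idx, prefilled_token_num = self.prefilled_queue.get()
            self._maybe_wait_for_cache_task(slot_idx)
            task = self.idx_cache_task_dict[slot_idx]
            print(f"slot {slot_idx}: {prefilled_token_num}/{task['need_prefill_tokens']} tokens prefilled")
            if prefilled_token_num >= task["need_prefill_tokens"]:
                break


demo = SlotQueueDemo()
threading.Thread(target=demo.sender_loop, daemon=True).start()
# deliberately report the prefill step before registering the task,
# so the sender thread has to wait in _maybe_wait_for_cache_task
demo.report_prefill_step(slot_idx=0, prefilled_token_num=8)
time.sleep(0.05)
demo.add_cache_task(slot_idx=0, need_prefill_tokens=8)
time.sleep(0.1)  # let the daemon thread print before the process exits
```

Running it prints one progress line once the task is registered; in the real cache messager that step would instead trigger the layerwise cache send.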
218 changes: 13 additions & 205 deletions fastdeploy/engine/common_engine.py
@@ -30,7 +30,6 @@
import time
import traceback
import weakref
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import Dict, List, Optional, Tuple

@@ -42,6 +41,7 @@
import fastdeploy.metrics.trace as tracing
from fastdeploy.cache_manager.cache_data import CacheStatus
from fastdeploy.config import FDConfig
from fastdeploy.engine.common_engine_prepare_mixin import EngineServicePrepareMixin
from fastdeploy.engine.register_manager import RegisterManager
from fastdeploy.engine.request import (
CompletionOutput,
@@ -115,7 +115,7 @@ def _format_worker_launch_failure_message(log_dir: str) -> str:
return message


class EngineService:
class EngineService(EngineServicePrepareMixin):
"""
Base class containing common engine functionality
"""
@@ -251,12 +251,13 @@ def start(self, async_llm_pid=None):
self.start_worker_service(async_llm_pid)

if envs.ENABLE_V1_KVCACHE_SCHEDULER:
self.insert_task_to_worker_thread = threading.Thread(
target=self._schedule_request_to_worker_v1, daemon=True
)
self.prepare_request_thread = threading.Thread(target=self._prepare_request_v1, daemon=True)
self.prepare_request_thread.start()
self.schedule_request_thread = threading.Thread(target=self._schedule_request_to_worker_v1, daemon=True)
self.schedule_request_thread.start()
else:
self.insert_task_to_worker_thread = threading.Thread(target=self._schedule_request_to_worker, daemon=True)
self.insert_task_to_worker_thread.start()
self.schedule_request_thread = threading.Thread(target=self._schedule_request_to_worker, daemon=True)
self.schedule_request_thread.start()
self.token_processor.tasks_queue = self.engine_worker_queue
self.token_processor.run()
if self.cfg.scheduler_config.splitwise_role == "decode":
@@ -879,215 +880,19 @@ def _schedule_request_to_worker_v1(self):
Insert tasks to worker with scheduler v1 (ENABLE_V1_KVCACHE_SCHEDULER=1).
"""
tracing.trace_set_thread_info("Scheduler Task to Work")
get_request_pool = ThreadPoolExecutor(max_workers=1)
is_fetching = False

def _fetch_request():
try:
with self._pause_cond:
self._pause_cond.wait_for(lambda: not self.is_paused)
nonlocal is_fetching
num_prefill_batch = min(
int(self.resource_manager.available_batch()),
self.cfg.max_prefill_batch,
)

if self.cfg.scheduler_config.splitwise_role != "mixed":
max_num_batched_tokens = self.cfg.scheduler_config.max_num_batched_tokens
else:
max_num_batched_tokens = self.cfg.model_config.max_model_len

available_blocks = self.cfg.cache_config.max_block_num_per_seq
tasks = self.scheduler.get_requests(
available_blocks=available_blocks,
block_size=self.cfg.cache_config.block_size,
reserved_output_blocks=0, # self.cfg.cache_config.enc_dec_block_num
max_num_batched_tokens=max_num_batched_tokens,
batch=num_prefill_batch,
)
for task in tasks:
task.metrics.engine_get_req_time = time.time()
trace_print(LoggingEventName.REQUEST_QUEUE_END, task.request_id, getattr(task, "user", ""))

if self.cfg.scheduler_config.splitwise_role == "decode":
# TODO: refine scheduler to remove this limitation
# Decode will process and schedule the request sent by prefill to engine,
# so the same request sent by the decode api server will be ignored
is_fetching = False
return

if tasks:
self.llm_logger.debug(
f"Engine has fetched tasks from {self.scheduler.__class__.__name__}: {[task.request_id for task in tasks]}"
)

if self.cfg.scheduler_config.splitwise_role == "prefill":
for task in tasks:
# start async preprocess
self.resource_manager.apply_async_preprocess(task)
need_delete_tasks = []
if envs.PREFILL_CONTINUOUS_REQUEST_DECODE_RESOURCES:
for task in tasks:
# assure can allocate block ids in P
while not self.resource_manager.preallocate_resource_in_p(task):
time.sleep(0.005)
self.llm_logger.debug(
f"P has allocated resources and then ask D resource for request: {task.request_id}"
)
trace_print(
LoggingEventName.ASK_DECODE_RESOURCE_START, task.request_id, getattr(task, "user", "")
)
task.metrics.ask_decode_resource_start_time = time.time()
while True:
self.split_connector.send_splitwise_tasks([task], task.idx)
status, msg = self.split_connector.check_decode_allocated(task)
if not status:
self.llm_logger.warning(
f"D failed to allocate resource for request {task.request_id}, try again."
)
time.sleep(0.05)
else:
task.metrics.ask_decode_resource_finish_time = time.time()
trace_print(
LoggingEventName.ASK_DECODE_RESOURCE_END,
task.request_id,
getattr(task, "user", ""),
)
break
self.llm_logger.debug(f"D has allocated resource for request: {task.request_id}")
else:
for task in tasks:
# assure can allocate block ids in P
while not self.resource_manager.preallocate_resource_in_p(task):
time.sleep(0.005)

self.llm_logger.debug(
f"P has allocated resources and then ask D resource for req_id: {task.request_id}"
)
trace_print(
LoggingEventName.ASK_DECODE_RESOURCE_START, task.request_id, getattr(task, "user", "")
)
task.metrics.ask_decode_resource_start_time = time.time()
self.split_connector.send_splitwise_tasks([task], task.idx)

for task in tasks:
# assure fetch block ids from D
status, msg = self.split_connector.check_decode_allocated(task)
task.metrics.ask_decode_resource_finish_time = time.time()
trace_print(
LoggingEventName.ASK_DECODE_RESOURCE_END, task.request_id, getattr(task, "user", "")
)
if not status:
error_msg = (
f"PD Error: prefill failed to apply for resource from decode, "
f"req: {task.request_id}, msg:{msg}."
)
self.llm_logger.error(error_msg)
self.scheduler.put_results(
[
RequestOutput(
request_id=task.request_id,
finished=True,
error_code=500,
error_msg=error_msg,
)
]
)
main_process_metrics.reschedule_req_num.inc()
need_delete_tasks.append(task)
continue
for tmp_task in need_delete_tasks:
tasks.remove(tmp_task)
# release resource in P
self.resource_manager.pre_recycle_resource(tmp_task.request_id)

# to send cache info to cache messager
if tasks:
need_check_req_ids = [task.request_id for task in tasks]
self.split_connector.send_cache_info_to_messager(tasks, 0)
# ensure cache tasks has sent to cache_messager
need_check_req_ids = [task.request_id for task in tasks]
finished_ids, delete_tasks_list = [], []
while need_check_req_ids:
finished_ids.extend(self.engine_worker_queue.get_finished_add_cache_task_req())
self.llm_logger.debug(
f"P has successfully sent cache infos to cache messager for requests: {finished_ids}"
)
if finished_ids:
for task in tasks:
result = self.resource_manager.waiting_async_process(task)
if result is None:
self.scheduler.put_results(
[
RequestOutput(
request_id=task.request_id,
finished=True,
error_code=task.error_code,
error_msg=task.error_message,
)
]
)
need_check_req_ids.remove(task.request_id)
delete_tasks_list.append(task)
elif result is False:
if task.request_id in finished_ids:
need_check_req_ids.remove(task.request_id)
finished_ids.remove(task.request_id)
else:
time.sleep(0.001)

for tmp_task in delete_tasks_list:
tasks.remove(tmp_task)
# release resource in P
self.resource_manager.pre_recycle_resource(tmp_task.request_id)

# Fetch requests and add them to the scheduling queue
if tasks:
for task in tasks:
task.metrics.add_req_to_resource_manager_time = time.time()
trace_print(
LoggingEventName.RESOURCE_ALLOCATE_START, task.request_id, getattr(task, "user", "")
)
if self.cfg.scheduler_config.splitwise_role == "prefill":
self.resource_manager.add_request_in_p(tasks)
self.llm_logger.info(
f"P add requests into running queue: {[task.request_id for task in tasks]}"
)
else:
for task in tasks:
self.resource_manager.add_request(task)
is_fetching = False
except Exception as e:
self.llm_logger.error(f"fetching request error {e} {str(traceback.format_exc())}")
is_fetching = False

while self.running:
with self._pause_cond:
self._pause_cond.wait_for(lambda: not self.is_paused)

try:
if self.engine_worker_queue.exist_tasks():
time.sleep(0.001)
continue
if self.cfg.scheduler_config.splitwise_role != "mixed":
if not is_fetching:
is_fetching = True
get_request_pool.submit(_fetch_request)

else:
if len(self.resource_manager.waiting) == 0 and (not is_fetching):
# Check if the thread pool is still available to avoid submitting tasks to a shutdown thread pool.
try:
is_fetching = True
get_request_pool.submit(_fetch_request)
except RuntimeError as e:
if "shutdown" in str(e):
self.llm_logger.info("Thread pool shutdown detected, exiting scheduler loop")
break
else:
raise

if hasattr(self.resource_manager, "scheduler_unhandled_request_num"):
self.resource_manager.scheduler_unhandled_request_num = self._get_scheduler_unhandled_request_num()

# 2. Schedule requests
tasks, error_tasks = self.resource_manager.schedule()

@@ -2228,6 +2033,9 @@ def _exit_sub_services(self):
self.llm_logger.info("Exit sub services.....")
self.running = False

if hasattr(self, "_fetch_pool"):
self._fetch_pool.shutdown(wait=False)

if self.use_async_llm:
# Clean up worker processes first (before closing multiprocessing services)
if hasattr(self, "worker_proc") and self.worker_proc is not None:
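To close, a minimal sketch of the start/exit flow after this refactor. The names below (MiniEngineService, _prepare_loop, _schedule_loop, _fetch_pool) are placeholders, not the real EngineService API: two daemon threads mirror prepare_request_thread and schedule_request_thread, fetch work goes through a one-worker pool, and shutdown guards the pool with hasattr, matching the new check in _exit_sub_services. The RuntimeError handling plays the role of the "shutdown" check carried by the removed in-place loop.

```python
import threading
import time
from concurrent.futures import ThreadPoolExecutor


class MiniEngineService:
    """Toy version of the start/exit flow sketched in the diff above."""

    def __init__(self):
        self.running = True
        self._fetch_pool = ThreadPoolExecutor(max_workers=1)  # stands in for the fetch pool

    def start(self):
        # two independent daemon loops, mirroring prepare_request_thread / schedule_request_thread
        threading.Thread(target=self._prepare_loop, daemon=True).start()
        threading.Thread(target=self._schedule_loop, daemon=True).start()

    def _prepare_loop(self):
        while self.running:
            try:
                self._fetch_pool.submit(lambda: None)  # fetch/preprocess work would go here
            except RuntimeError:
                break  # pool already shut down, exit the loop
            time.sleep(0.01)

    def _schedule_loop(self):
        while self.running:
            time.sleep(0.01)  # scheduling work would go here

    def stop(self):
        self.running = False
        # guard the shutdown: the pool may not exist if start() was never called
        if hasattr(self, "_fetch_pool"):
            self._fetch_pool.shutdown(wait=False)


svc = MiniEngineService()
svc.start()
time.sleep(0.05)
svc.stop()
```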