216 changes: 192 additions & 24 deletions fastdeploy/engine/common_engine.py
@@ -28,12 +28,14 @@
import threading
import time
import traceback
import uuid
import weakref
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, List, Optional, Tuple

import numpy as np
import paddle
import paddle.distributed as dist
import requests
import zmq
from tqdm import tqdm
@@ -298,15 +300,6 @@ def _init_worker_monitor_signals(self):  # exist_task_signal is used by each worker process
create=True,
)

engine_forward_signal_data = np.zeros([1], dtype=np.int32)
self.engine_forward_signal = IPCSignal(
name="engine_forward_signal",
array=engine_forward_signal_data,
dtype=np.int32,
suffix=current_suffix,
create=True,
)

# worker_live_signal is used by the engine to detect whether each worker process is alive; it records the time of each step
worker_healthy_live_recorded_time_array = np.zeros(
shape=[min(self.cfg.worker_num_per_node, self.cfg.parallel_config.tensor_parallel_size)], dtype=np.int32
@@ -375,6 +368,34 @@ def _init_worker_monitor_signals(self):  # exist_task_signal is used by each worker process
create=True,
)

infer_finished_signal_data = np.zeros([1], dtype=np.int32)
self.infer_finished_signal = IPCSignal(
name="infer_finished_signal",
array=infer_finished_signal_data,
dtype=np.int32,
suffix=current_suffix,
create=True,
)
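The engine only reads and resets this flag in the scheduling loop further down; the worker side is presumably expected to attach to the same shared array and raise it after each forward pass. A minimal sketch of that assumption (the attach path and the signaling convention are not shown in this diff, and the import path is assumed):

import numpy as np
from fastdeploy.inter_communicator import IPCSignal  # assumed import path

ipc_suffix = 0  # hypothetical: the same value the engine passed as current_suffix

worker_infer_finished_signal = IPCSignal(
    name="infer_finished_signal",
    array=np.zeros([1], dtype=np.int32),
    dtype=np.int32,
    suffix=ipc_suffix,
    create=False,  # attach to the engine-created shared memory rather than creating it
)
worker_infer_finished_signal.value[0] = 1  # assumed convention: 1 means the forward pass has finished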

def init_parallel_env(self, start_port=6070):
local_data_parallel_size = len(self.cfg.parallel_config.engine_worker_queue_port)
global_data_parallel_id = self.cfg.node_rank * local_data_parallel_size + self.cfg.parallel_config.local_data_parallel_id
os.environ["PADDLE_TRAINER_ID"] = str(global_data_parallel_id)
os.environ["PADDLE_TRAINERS_NUM"] = str(self.cfg.parallel_config.data_parallel_size)
if self.cfg.ips is None:
os.environ["PADDLE_TRAINER_ENDPOINTS"] = ','.join([f"0.0.0.0:{int(start_port + i)}" for i in range(local_data_parallel_size)])
else:
os.environ["PADDLE_TRAINER_ENDPOINTS"] = ','.join(
[f"{ip}:{int(start_port + i)}" for i in range(local_data_parallel_size) for ip in self.cfg.ips]
)
os.environ["PADDLE_DISTRI_BACKEND"] = "gloo"

dist.init_parallel_env()
# Avoid bringing this env variable to workers
os.unsetenv("PADDLE_DISTRI_BACKEND")

paddle.set_device("cpu")

def start_worker_queue_service(self, start_queue):
"""
start queue service for engine worker communication
@@ -796,6 +817,9 @@ def _schedule_request_to_worker_v1(self):
tracing.trace_set_thread_info("Scheduler Task to Work")
get_request_pool = ThreadPoolExecutor(max_workers=1)
is_fetching = False
buffered_req_info = {}
req_info_lock = threading.Lock()
last_sched_batch_id, last_sched_batch_cnt, last_received_request_ids = -1, [], []

def _fetch_request():
try:
@@ -960,11 +984,77 @@ def _fetch_request():
else:
for task in tasks:
self.resource_manager.add_request(task)

if envs.FD_ENABLE_BATCH_SCHEDULER:
with req_info_lock:
for task in tasks:
if "batch_info" not in task.ic_req_data:
continue
batch_info = json.loads(task.ic_req_data["batch_info"])
self.llm_logger.info(f"sched batch info: {batch_info}")
buffered_req_info[task.request_id] = batch_info
is_fetching = False
except Exception as e:
self.llm_logger.error(f"fetching request error {e} {str(traceback.format_exc())}")
is_fetching = False
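Judging from how _check_recv_full_batch consumes it below, the batch_info each request carries is expected to look roughly like this (field names taken from the code, values illustrative):

# Hypothetical example of json.loads(task.ic_req_data["batch_info"]):
# {
#     "sched_batch_id": 42,         # id of the upstream-scheduled batch; larger means newer
#     "sched_batch_cnt": [3, 5],    # how many reqs of this batch were routed to each DP instance
#     "sched_batch_local_id": 1,    # this request's id within the batch
# }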

def _check_recv_full_batch():
with req_info_lock:
all_buffered_req_info = []
dist.all_gather_object(all_buffered_req_info, buffered_req_info)

nonlocal last_sched_batch_id, last_sched_batch_cnt, \
last_received_request_ids, start_time
has_recv_data = last_sched_batch_id != -1
last_received_request_ids = []
# Find the latest scheduled batch
for local_info in all_buffered_req_info:
for _, sched_info in local_info.items():
if sched_info["sched_batch_id"] > last_sched_batch_id:
last_sched_batch_id = sched_info["sched_batch_id"]
last_sched_batch_cnt = sched_info["sched_batch_cnt"]

# Currently no new reqs
if last_sched_batch_id == -1:
return False
# return True

if not has_recv_data:
start_time = time.time()
# Count req num of each DP instance
dp_size = len(last_sched_batch_cnt)
req_num_count = [0] * dp_size
for i, local_info in enumerate(all_buffered_req_info):
for _, sched_info in local_info.items():
if sched_info["sched_batch_id"] == last_sched_batch_id:
req_num_count[i] += 1
last_received_request_ids.append(sched_info["sched_batch_local_id"])

flag = True
for i in range(dp_size):
if req_num_count[i] < last_sched_batch_cnt[i]:
flag = False
break

if flag:
# All reqs in the latest batch have been received
last_received_request_ids = []
return flag
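Worked through on a small, illustrative example, the full-batch check reduces to a per-instance count comparison (names and values below are assumptions for illustration only):

# Illustrative only: two DP instances, the latest batch expects 3 + 5 reqs.
expected_cnt = [3, 5]   # reqs routed to DP0 / DP1 by the upstream scheduler
received_cnt = [3, 4]   # reqs of that batch buffered locally so far
full = all(got >= want for got, want in zip(received_cnt, expected_cnt))
print(full)  # False -> keep polling until the batch completes or FD_RECV_BATCH_TIMEOUT hits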

def _check_timeout():
# All DP instances should use the same timer.
# Otherwise, suppose we have two instances where instance 0 reaches the timeout while instance 1 does not.
# This will cause:
# instance 0: worker stuck at dist.barrier -> engine stuck waiting for infer_finished_signal
# instance 1: engine stuck at all_gather_object
# Now we have a deadlock!
nonlocal start_time
time_list = [time.time() - start_time]
dist.broadcast_object_list(time_list, src=0)

return time_list[0] * 1000 >= envs.FD_RECV_BATCH_TIMEOUT

start_time = time.time()
while self.running:
with self._pause_cond:
self._pause_cond.wait_for(lambda: not self.is_paused)
@@ -980,15 +1070,22 @@ def _fetch_request():
break
else:
raise
# Continue preprocessing incoming requests and accumulating them in the queue while the forward pass has not finished.
# Once the forward pass finishes, these accumulated requests can be scheduled in larger,
# more efficient batches.
if not (self.engine_worker_queue.num_tasks() == 0 and self.engine_forward_signal.value[0] == 0):
time.sleep(0.001)
continue

if envs.FD_ENABLE_BATCH_SCHEDULER:
# Some reqs of the scheduled batch are still in flight
if not _check_recv_full_batch() and not _check_timeout():
time.sleep(0.01)
continue
else:
if not (self.engine_worker_queue.num_tasks() == 0 and self.infer_finished_signal.value[0] == 0):
time.sleep(0.001)
continue

# 2. Schedule requests
tasks, error_tasks = self.resource_manager.schedule()
# Clear buffered reqs
with req_info_lock:
buffered_req_info.clear()

# 3. Send to engine
if tasks:
@@ -1033,15 +1130,16 @@ def _fetch_request():
if self.cfg.scheduler_config.splitwise_role == "decode":
task.metrics.decode_inference_start_time = time.time()
else:
task.metrics.inference_start_time = time.time()
if not task.metrics.inference_start_time:
task.metrics.inference_start_time = time.time()
self.engine_worker_queue.put_tasks((tasks, self.resource_manager.real_bsz))
else:
# When there are no actual tasks to schedule, send an empty task batch to EP workers.
# This prevents the EP workers' barrier for syncing tasks from hanging.
if self.cfg.parallel_config.enable_expert_parallel:
self.engine_worker_queue.put_tasks(
([], self.resource_manager.real_bsz)
) # Empty (as idle tasks for ep)
if envs.FD_ENABLE_BATCH_SCHEDULER or self.cfg.parallel_config.enable_expert_parallel:
# Insert IDLE task for synchronization
idle_task = Request.from_dict({"request_id": f"idle-{uuid.uuid4()}"})
idle_task.task_type = RequestType.IDLE
tasks = [idle_task]
self.engine_worker_queue.put_tasks((tasks, self.resource_manager.real_bsz))

# 4. Response error tasks
if error_tasks:
@@ -1051,8 +1149,25 @@ def _fetch_request():
continue
self._send_error_response(request_id, failed)

if not tasks and not error_tasks:
time.sleep(0.005)
if envs.FD_ENABLE_BATCH_SCHEDULER:
start_execute_time = time.time()
while self.infer_finished_signal.value[0] == 0:
# Wait for current forward to finish
time.sleep(0.01)
execute_time = int((time.time() - start_execute_time) * 1000)

# Report to IM
if last_sched_batch_id != -1:
self.report_infer_monitor(
last_sched_batch_id,
last_received_request_ids,
execute_time,
self.resource_manager.get_remain_token_num(),
)
self.infer_finished_signal.value[0] = 0
else:
if not tasks and not error_tasks:
time.sleep(0.005)

except RuntimeError as e:
if "cannot schedule new futures after shutdown" in str(e):
@@ -1061,6 +1176,9 @@ def _fetch_request():
err_msg = "Error happend while insert task to engine: {}, {}.".format(e, str(traceback.format_exc()))
self.llm_logger.error(err_msg)

last_sched_batch_id, last_sched_batch_cnt = -1, []
start_time = time.time()

def start_zmq_service(self, api_server_pid=None):
if api_server_pid is None:
return
@@ -1710,6 +1828,56 @@ def _register():
register_thread = threading.Thread(target=_register, daemon=True)
register_thread.start()

def report_infer_monitor(
self,
last_sched_batch_id,
last_received_request_ids,
last_run_batch_duration,
remain_token_num,
):
"""
Report info of the latest batch to the infer monitor
"""
report_info_list = []
dist.communication.stream.gather(
paddle.to_tensor([last_run_batch_duration, remain_token_num]),
report_info_list,
dst=0,
)
if self.cfg.node_rank == 0 and self.cfg.parallel_config.local_data_parallel_id == 0:
# Report by DP0
report_info = paddle.to_tensor(report_info_list)
local_data_parallel_size = len(self.cfg.parallel_config.engine_worker_queue_port)
global_data_parallel_id = self.cfg.node_rank * local_data_parallel_size + self.cfg.parallel_config.local_data_parallel_id
payload = {
"last_sched_batch_id": last_sched_batch_id,
"last_received_request_ids": last_received_request_ids,
"last_run_batch_duration": report_info[:, 0].max().item(),
"remain_token_num_per_dp": report_info[:, 1].tolist(),
# "fed_instance_name": os.getenv("FED_INSTANCE_NAME"),
"fed_instance_name": (
os.getenv("POD_NAMESPACE", "None")
+ "_"
+ os.getenv("FD_POD_NAME", "None")
+ "_"
+ os.getenv("HOST_IP", "None")
+ "_"
+ os.getenv("SPLITWISE_ROLE", "None")
+ "_"
+ str(global_data_parallel_id)
),
"model_id": os.getenv("MODEL_ID"),
}
llm_logger.info(f"report info: {payload}")

try:
url = f"http://10.25.77.31:{envs.FD_REPORT_IM_PORT}/end_forward"
response = requests.post(url, json=payload)
response.raise_for_status()
llm_logger.info(f"report IM successful: {response}")
except Exception as e:
llm_logger.info(f"report IM failed: {e}")

def _exit_sub_services(self):
"""
exit sub services
4 changes: 3 additions & 1 deletion fastdeploy/engine/engine.py
@@ -138,6 +138,8 @@ def start(self, api_server_pid=None):

self.engine.start()
self.engine.create_data_processor()
if envs.FD_ENABLE_BATCH_SCHEDULER:
self.engine.init_parallel_env()
self.data_processor = self.engine.data_processor

# If block numer is specified and model is deployed in mixed mode, start cache manager first
@@ -411,7 +413,7 @@ def _init_worker_signals(self):
suffix=self.ipc_signal_suffix,
create=True,
)

def _exit_sub_services(self):
"""
exit sub services
3 changes: 3 additions & 0 deletions fastdeploy/engine/expert_service.py
@@ -153,6 +153,9 @@ def start(
)
self.launched_expert_service_signal.value[local_rank] = 1

if envs.FD_ENABLE_BATCH_SCHEDULER:
self.engine.init_parallel_env()

if self.do_profile:
get_profile_block_num = np.zeros([1], dtype=np.int32)
while True:
2 changes: 2 additions & 0 deletions fastdeploy/engine/request.py
@@ -61,6 +61,7 @@ class RequestType(Enum):
DECODE = 1
PREEMPTED = 2
EXTEND = 3
IDLE = 4


@dataclass
@@ -351,6 +352,7 @@ def from_dict(cls, d: dict):
data_processor_logger.error(
f"Convert mm_positions to ImagePosition error: {e}, {str(traceback.format_exc())}"
)

return cls(
request_id=d["request_id"],
prompt=d.get("prompt"),
14 changes: 14 additions & 0 deletions fastdeploy/engine/sched/resource_manager_v1.py
@@ -1381,6 +1381,20 @@ def update_metrics(self):
main_process_metrics.num_requests_running.set(len(self.running))
main_process_metrics.num_requests_waiting.set(num_tasks - len(self.running))

def get_remain_token_num(self):
token_num = 0
with self.lock:
for req in self.running:
token_num += req.num_total_tokens - req.num_computed_tokens
for req in self.waiting:
token_num += req.num_total_tokens
for req_id in self.to_be_rescheduled_request_id_set:
# Preempted reqs are currently neither in the running queue nor in the waiting queue
req = self.requests[req_id]
token_num += req.num_total_tokens

return token_num

def log_status(self):
llm_logger.info(
f"ResourceManagerV1( "
6 changes: 6 additions & 0 deletions fastdeploy/envs.py
@@ -200,6 +200,12 @@
"FD_WORKER_ALIVE_TIMEOUT": lambda: int(os.getenv("FD_WORKER_ALIVE_TIMEOUT", "30")),
# File path for file storage backend
"FILE_BACKEND_STORAGE_DIR": lambda: str(os.getenv("FILE_BACKEND_STORAGE_DIR", "/tmp/fastdeploy")),
# Enable batch scheduler to increase batch size under DP+EP
"FD_ENABLE_BATCH_SCHEDULER": lambda: int(os.getenv("FD_ENABLE_BATCH_SCHEDULER", "0")),
# Timeout for batching reqs, 500 ms by default
"FD_RECV_BATCH_TIMEOUT": lambda: int(os.getenv("FD_RECV_BATCH_TIMEOUT", "500")),
# Port for IM reporting
"FD_REPORT_IM_PORT": lambda: int(os.getenv("FD_REPORT_IM_PORT", "9009")),
}
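A minimal sketch of turning the feature on (assumed usage; the launch command itself is not part of this diff):

# Illustrative only: set the new knobs in the environment before the engine starts.
import os

os.environ["FD_ENABLE_BATCH_SCHEDULER"] = "1"   # enable batch scheduling under DP+EP
os.environ["FD_RECV_BATCH_TIMEOUT"] = "800"     # wait up to 800 ms for a full batch
os.environ["FD_REPORT_IM_PORT"] = "9009"        # port of the infer-monitor /end_forward endpoint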


2 changes: 1 addition & 1 deletion fastdeploy/input/ernie4_5_vl_processor/process.py
@@ -630,7 +630,7 @@ def _extract_labels(self, outputs: Dict, tgts: List[str]) -> None:

def _load_and_process_video(self, url: str, item: Dict) -> List[Image.Image]:
reader, meta, path = read_video_decord(url, save_to_disk=False)

video_frame_args = dict()
video_frame_args["fps"] = item.get("fps", self.fps)
video_frame_args["min_frames"] = item.get("min_frames", self.min_frames)
2 changes: 1 addition & 1 deletion fastdeploy/input/ernie4_5_vl_processor/process_video.py
@@ -137,7 +137,7 @@ def read_frames_decord(
fix_start=fix_start,
input_fps=video_meta["fps"],
)

frames = []
for frame_indice_index in range(0, len(frame_indices)):
frame_indice = frame_indices[frame_indice_index]