15 changes: 13 additions & 2 deletions lightllm/common/basemodel/cuda_graph.py
@@ -4,6 +4,7 @@
import bisect
import triton
from typing import Optional
from tqdm import tqdm
from lightllm.utils.log_utils import init_logger
from lightllm.utils.envs_utils import get_env_start_args
from lightllm.distributed import dist_group_manager, lightllm_capture_graph, CustomProcessGroup
@@ -203,7 +204,11 @@ def warmup(self, model):
model: TpPartBaseModel = model

# decode cuda graph init
for batch_size in self.cuda_graph_batch_sizes[::-1]:
progress_bar = tqdm(self.cuda_graph_batch_sizes[::-1], desc="Capturing CUDA graphs")
for batch_size in progress_bar:
avail_mem, _ = torch.cuda.mem_get_info()
avail_mem_gb = avail_mem / (1024 ** 3)
progress_bar.set_description(f"Capturing CUDA graphs - Batch: {batch_size}, AvailMem: {avail_mem_gb:.2f}GB")
seq_len = 2
total_token_num = batch_size * seq_len
max_len_in_batch = self.graph_max_len_in_batch
@@ -258,7 +263,13 @@ def warmup_overlap(self, model):

model: TpPartBaseModel = model

for batch_size in self.cuda_graph_batch_sizes[::-1]:
progress_bar = tqdm(self.cuda_graph_batch_sizes[::-1], desc="Capturing overlap CUDA graphs")
for batch_size in progress_bar:
avail_mem, _ = torch.cuda.mem_get_info()
avail_mem_gb = avail_mem / (1024 ** 3)
progress_bar.set_description(
f"Capturing overlap CUDA graphs - Batch: {batch_size}, AvailMem: {avail_mem_gb:.2f}GB"
)
decode_batches = []
for micro_batch_index in [0, 1]:
# dummy decoding, capture the cudagraph
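For reference, a minimal standalone sketch of the progress-bar pattern this hunk introduces: wrap the capture loop in tqdm and refresh the description with the free GPU memory reported by torch.cuda.mem_get_info(). The batch sizes and the capture step below are placeholders, not the actual TpPartBaseModel capture code.

```python
# Sketch only, assuming a CUDA device is available; the real batch-size list
# comes from cuda_graph_batch_sizes in the diff above.
import torch
from tqdm import tqdm

batch_sizes = [64, 32, 16, 8, 4, 2, 1]  # hypothetical capture sizes, largest first
progress_bar = tqdm(batch_sizes, desc="Capturing CUDA graphs")
for batch_size in progress_bar:
    # mem_get_info() returns (free_bytes, total_bytes) for the current device.
    avail_mem, _ = torch.cuda.mem_get_info()
    avail_mem_gb = avail_mem / (1024 ** 3)
    # Refresh the bar text so free memory can be watched shrinking as graphs are captured.
    progress_bar.set_description(
        f"Capturing CUDA graphs - Batch: {batch_size}, AvailMem: {avail_mem_gb:.2f}GB"
    )
    # ... capture the decode CUDA graph for this batch size here ...
```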
6 changes: 3 additions & 3 deletions lightllm/common/triton_utils/autotuner.py
@@ -215,7 +215,7 @@ def _try_load_cache(self, static_key):

cache_file = os.path.join(self.cache_dir, KernelConfigs.get_config_file_name(static_key))
if os.path.exists(cache_file):
logger.info(f"Loading cached configs for {self.kernel_name} - {static_key}")
logger.info(f"Loading cached configs for {self.kernel_name} - {dict(static_key)}")
with open(cache_file, "rb") as f:
self.cached_configs[static_key] = orjson.loads(f.read())
return True
@@ -353,9 +353,9 @@ def _autotune(self, args, kwargs, static_key, run_key, rank_id, world_size):
option=orjson.OPT_INDENT_2 | orjson.OPT_SORT_KEYS | orjson.OPT_NON_STR_KEYS,
)
)
logger.info(f"Saved configs for {self.kernel_name} - {_static_key}")
logger.info(f"Saved configs for {self.kernel_name} - {dict(_static_key)}")

logger.info(f"rank {rank_id} tuning {self.kernel_name} _static_key {static_key} finished")
logger.info(f"rank {rank_id} tuning {self.kernel_name} _static_key {dict(static_key)} finished")

def _mutate_args_clone(self, args, kwargs):
origin_list = []
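The autotuner change only affects log readability. Assuming static_key is a hashable sequence of (name, value) pairs (so it can serve as a cache key), wrapping it in dict() prints it as a mapping instead of nested tuples; a small illustration with a made-up key:

```python
# Hypothetical static_key; the real one is built by the autotuner from kernel parameters.
static_key = (("dtype", "float16"), ("head_dim", 128))

print(f"Loading cached configs for my_kernel - {static_key}")
# -> Loading cached configs for my_kernel - (('dtype', 'float16'), ('head_dim', 128))
print(f"Loading cached configs for my_kernel - {dict(static_key)}")
# -> Loading cached configs for my_kernel - {'dtype': 'float16', 'head_dim': 128}
```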
1 change: 0 additions & 1 deletion lightllm/server/api_http.py
@@ -116,7 +116,6 @@ def set_args(self, args: StartArgs):
app = FastAPI()
g_objs.app = app

_ACCESS_LOG_STATUS_COLORS = {2: "\033[32m", 3: "\033[36m", 4: "\033[33m", 5: "\033[31m"}
_ACCESS_LOG_STATUS_COLORS = {2: "\033[32m", 3: "\033[36m", 4: "\033[33m", 5: "\033[31m"}
_ACCESS_LOG_RESET = "\033[0m"

6 changes: 0 additions & 6 deletions lightllm/server/api_start.py
@@ -444,8 +444,6 @@ def normal_or_p_d_start(args):
f"{args.host}:{args.port}",
"--log-level",
"info",
"--access-logfile",
"-",
"--error-logfile",
"-",
"lightllm.server.api_http:app",
@@ -513,8 +511,6 @@ def pd_master_start(args):
f"{args.host}:{args.port}",
"--log-level",
"info",
"--access-logfile",
"-",
"--error-logfile",
"-",
"lightllm.server.api_http:app",
@@ -609,8 +605,6 @@ def config_server_start(args):
f"{args.config_server_host}:{args.config_server_port}",
"--log-level",
"info",
"--access-logfile",
"-",
"--error-logfile",
"-",
"lightllm.server.config_server.api_http:app",
7 changes: 5 additions & 2 deletions lightllm/server/core/objs/req.py
@@ -25,19 +25,20 @@ class FinishStatus(ctypes.Structure):
NO_FINISH = 0
FINISHED_STOP = 1
FINISHED_LENGTH = 2
FINISHED_ERROR = 3

def __init__(self, init_state=NO_FINISH):
self.status = init_state

def set_status(self, new_status):
assert 0 <= new_status <= 2
assert 0 <= new_status <= 3
self.status = new_status

def get_status(self):
return self.status

def is_finished(self):
return self.FINISHED_STOP <= self.status <= self.FINISHED_LENGTH
return self.FINISHED_STOP <= self.status <= self.FINISHED_ERROR

def is_stopped(self):
return self.status == self.FINISHED_STOP
@@ -50,6 +51,8 @@ def get_finish_reason(self):
return "stop"
elif self.status == self.FINISHED_LENGTH:
return "length"
elif self.status == self.FINISHED_ERROR:
return "error"
return None


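A short usage sketch of the extended FinishStatus based on the class shown above: the new FINISHED_ERROR state passes the widened assert, counts as finished, and is reported downstream as the finish reason "error".

```python
# Sketch only: relies on the FinishStatus definition from the diff above.
status = FinishStatus()
status.set_status(FinishStatus.FINISHED_ERROR)  # 3 now passes `assert 0 <= new_status <= 3`
assert status.is_finished()                     # FINISHED_STOP <= 3 <= FINISHED_ERROR
assert not status.is_stopped()                  # an error is distinct from a normal stop
assert status.get_finish_reason() == "error"    # surfaced to clients as finish_reason
```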
46 changes: 28 additions & 18 deletions lightllm/server/detokenization/manager.py
@@ -47,24 +47,30 @@ def _init_get_token_id_to_token_str(self):
return

def _add_new_group_req_index(self, recv_obj: GroupReqIndexes):
from lightllm.server.core.objs import FinishStatus

for req_index in recv_obj.shm_req_indexes:
req = self.shm_req_manager.get_req_obj_by_index(req_index)
req.link_prompt_ids_shm_array()
req.link_logprobs_shm_array()

logger.info(
f"detokenization recv req id {req.request_id} " f"cost time {time.time() - recv_obj.time_mark} s"
)

# In P/D disaggregated mode, decoding on the decode node needs some special fixes.
decode_req = DecodeReq(req, self.is_pd_decode_mode)
if self.is_pd_decode_mode:
decode_req = decode_mode_fix(decode_req, self.tokenizer, self.eos_id)
# Special initialization for token_healing mode
if self.args.token_healing_mode:
decode_req.init_token_healing_prefix_str(self.token_id_to_token, self.tokenizer)

self.req_id_to_out[req.request_id] = decode_req
try:
req.link_prompt_ids_shm_array()
req.link_logprobs_shm_array()

logger.debug(
f"detokenization recv req id {req.request_id} " f"cost time {time.time() - recv_obj.time_mark} s"
)

# In P/D disaggregated mode, decoding on the decode node needs some special fixes.
decode_req = DecodeReq(req, self.is_pd_decode_mode)
if self.is_pd_decode_mode:
decode_req = decode_mode_fix(decode_req, self.tokenizer, self.eos_id)
# Special initialization for token_healing mode
if self.args.token_healing_mode:
decode_req.init_token_healing_prefix_str(self.token_id_to_token, self.tokenizer)

self.req_id_to_out[req.request_id] = decode_req
except Exception as e:
req.finish_status.set_status(FinishStatus.FINISHED_ERROR)
raise e
return

def handle_loop(self):
@@ -76,7 +82,11 @@ def handle_loop(self):
for _ in range(recv_max_count):
recv_obj: GroupReqIndexes = self.zmq_recv_socket.recv_pyobj(zmq.NOBLOCK)
assert isinstance(recv_obj, GroupReqIndexes)
self._add_new_group_req_index(recv_obj=recv_obj)
try:
self._add_new_group_req_index(recv_obj=recv_obj)
except Exception:
logger.exception("add new group req index has exception")
self.pub_to_httpserver.send_pyobj(None, protocol=pickle.HIGHEST_PROTOCOL)

# When many requests are queued, raise the number accepted per iteration
recv_max_count = min(int(recv_max_count * 1.3), 256)
@@ -160,7 +170,7 @@ def remove_finished_reqs(self):

for decode_req in finished_reqs:
decode_req.req.can_released_mark = True
logger.info(f"detoken release req id {decode_req.req.request_id}")
logger.debug(f"detoken release req id {decode_req.req.request_id}")
self.shm_req_manager.put_back_req_obj(decode_req.req)
self.req_id_to_out.pop(decode_req.request_id, None)
return
25 changes: 21 additions & 4 deletions lightllm/server/httpserver/manager.py
@@ -298,6 +298,17 @@ async def generate(
# Used to wait for the exchange info sent down by the pd_master
nixl_pd_event: asyncio.Event = None,
) -> AsyncGenerator[Tuple[int, str, dict, FinishStatus], None]:
group_request_id = None
if isinstance(prompt, str):
# Guard against extremely long string prompts that might stall the tokenizer
# or cause excessive memory usage before tokenization.
# 8 characters per token is a conservative heuristic (avg is ~4).
max_prompt_chars = self.max_req_total_len * 8
if len(prompt) > max_prompt_chars:
raise ValueError(
f"prompt text length {len(prompt)} exceeds the character limit {max_prompt_chars}, "
f"the request is rejected before tokenization."
)

start_time = time.time()
request_headers = request.headers if request is not None else {}
@@ -445,6 +456,12 @@ async def generate(

yield sub_req_id, request_output, metadata, finish_status

except ValueError as e:
logger.warning(f"group_request_id: {group_request_id} request invalid: {str(e)}")
if group_request_id not in self.req_id_to_out_inf:
await self._release_multimodal_resources(multimodal_params)
await self.abort(group_request_id)
raise e
except Exception as e:
logger.error(f"group_request_id: {group_request_id} has exception {str(e)}")
# error need to release multimodel resources.
@@ -477,7 +494,7 @@ async def _log_req_header(self, request_headers, group_request_id: int):
x_session_id = request_headers.get("X-Session-Id", "")

format_in_time = datetime.datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d %H:%M:%S")
logger.info(
logger.debug(
f"received req X-Request-Id:{x_request_id} "
f"X-Session-Id:{x_session_id} start_time:{format_in_time} "
f"lightllm_req_id:{group_request_id} "
@@ -713,7 +730,7 @@ async def _wait_to_token_package(
(out_token_counter - sum(sub_req_id_to_mtp_accepted_token_num.values())), 1
)
format_start_time = datetime.datetime.fromtimestamp(start_time).strftime("%Y-%m-%d %H:%M:%S")
logger.info(
logger.debug(
f"X-Request-Id:{x_request_id} "
f"X-Session-Id:{x_session_id} start_time:{format_start_time} "
f"lightllm_req_id:{group_request_id} first_token_cost:{first_token_cost_ms}ms "
@@ -806,8 +823,8 @@ async def recycle_resource_loop(self):
if req_status is None:
continue

logger.info(
f"left req id {req_status.group_req_objs.group_req_id}"
logger.debug(
f"left req id {req_status.group_req_objs.group_req_id} "
f"can release {req_status.group_req_objs.shm_req_objs[0].can_released_mark} "
f"refcount {req_status.group_req_objs.shm_req_objs[0].ref_count}"
)
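The guard added to generate() rejects extremely long string prompts before tokenization. A self-contained sketch of the same check, assuming a hypothetical max_req_total_len of 16384 tokens and keeping the PR's 8-characters-per-token budget:

```python
# Standalone sketch of the pre-tokenization length guard; MAX_REQ_TOTAL_LEN is a
# hypothetical server limit, and 8 chars/token mirrors the heuristic in the diff.
MAX_REQ_TOTAL_LEN = 16384
MAX_PROMPT_CHARS = MAX_REQ_TOTAL_LEN * 8  # 131072 characters

def check_prompt_length(prompt) -> None:
    if isinstance(prompt, str) and len(prompt) > MAX_PROMPT_CHARS:
        raise ValueError(
            f"prompt text length {len(prompt)} exceeds the character limit {MAX_PROMPT_CHARS}, "
            f"the request is rejected before tokenization."
        )

check_prompt_length("hello world")      # fine
# check_prompt_length("x" * 200_000)    # would raise ValueError
```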
2 changes: 1 addition & 1 deletion lightllm/server/httpserver_for_pd_master/manager.py
@@ -179,7 +179,7 @@ async def _log_req_header(self, request: Request, group_request_id: int):
x_request_id = request.headers.get("X-Request-Id", "")
x_session_id = request.headers.get("X-Session-Id", "")
format_in_time = datetime.datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d %H:%M:%S")
logger.info(
logger.debug(
f"received req X-Request-Id:{x_request_id} "
f"X-Session-Id:{x_session_id} start_time:{format_in_time} "
f"lightllm_req_id:{group_request_id} "
8 changes: 2 additions & 6 deletions lightllm/server/router/batch.py
@@ -3,7 +3,6 @@
from typing import Dict, List, Optional, Tuple, Union
from lightllm.server.core.objs import ShmReqManager, Req
from lightllm.utils.log_utils import init_logger
from .stats import RouterStatics

logger = init_logger(__name__)

@@ -50,14 +49,11 @@ def get_all_dp_req_num(self) -> List[int]:
all_dp_req_num[req.sample_params.suggested_dp_index] += 1
return all_dp_req_num

def filter_out_finished_req(self, shm_req_manager: ShmReqManager, router_statics: RouterStatics):
def filter_out_finished_req(self, shm_req_manager: ShmReqManager):
unfinished_req_ids = []
for req in self.reqs:
if req.shm_infer_released:
logger.info(f"router release req id {req.request_id}")
if not req.is_aborted:
router_statics.update(req.candetoken_out_len)

logger.debug(f"router release req id {req.request_id}")
shm_req_manager.put_back_req_obj(req)
req = None
else: