26 changes: 24 additions & 2 deletions lightllm/server/api_http.py
@@ -47,7 +47,7 @@
from .api_lightllm import lightllm_get_score
from lightllm.utils.envs_utils import get_env_start_args, get_lightllm_websocket_max_message_size
from lightllm.utils.log_utils import init_logger
from lightllm.utils.error_utils import ServerBusyError
from lightllm.utils.error_utils import ClientDisconnected, ServerBusyError
from lightllm.server.metrics.manager import MetricClient
from lightllm.utils.envs_utils import get_unique_server_name
from dataclasses import dataclass
@@ -244,6 +244,9 @@ async def generate(request: Request) -> Response:
return create_error_response(HTTPStatus.SERVICE_UNAVAILABLE, str(e))
except ValueError as e:
return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
except ClientDisconnected as e:
logger.error(str(e))
return Response(status_code=499)
except Exception as e:
logger.error("An error occurred: %s", str(e), exc_info=True)
return create_error_response(HTTPStatus.EXPECTATION_FAILED, str(e))
@@ -263,6 +266,9 @@ async def generate_stream(request: Request) -> Response:
return create_error_response(HTTPStatus.SERVICE_UNAVAILABLE, str(e))
except ValueError as e:
return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
except ClientDisconnected as e:
logger.error(str(e))
return Response(status_code=499)
except Exception as e:
logger.error("An error occurred: %s", str(e), exc_info=True)
return create_error_response(HTTPStatus.EXPECTATION_FAILED, str(e))
@@ -277,6 +283,9 @@ async def get_score(request: Request) -> Response:

try:
return await lightllm_get_score(request, g_objs.httpserver_manager)
except ClientDisconnected as e:
logger.error(str(e))
return Response(status_code=499)
except Exception as e:
return create_error_response(HTTPStatus.EXPECTATION_FAILED, str(e))

@@ -307,6 +316,9 @@ async def chat_completions(request: ChatCompletionRequest, raw_request: Request)
resp = await chat_completions_impl(request, raw_request)
except ValueError as e:
return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
except ClientDisconnected as e:
logger.error(str(e))
return Response(status_code=499)
return resp


@@ -321,6 +333,9 @@ async def completions(request: CompletionRequest, raw_request: Request) -> Response:
resp = await completions_impl(request, raw_request)
except ValueError as e:
return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
except ClientDisconnected as e:
logger.error(str(e))
return Response(status_code=499)
return resp


@@ -332,7 +347,11 @@ async def anthropic_messages(raw_request: Request) -> Response:
)
from .api_anthropic import anthropic_messages_impl

return await anthropic_messages_impl(raw_request)
try:
return await anthropic_messages_impl(raw_request)
except ClientDisconnected as e:
logger.error(str(e))
return Response(status_code=499)


@app.get("/v1/models", response_model=ModelListResponse)
@@ -377,6 +396,9 @@ async def tokens(request: Request):
},
status_code=200,
)
except ClientDisconnected as e:
logger.error(str(e))
return Response(status_code=499)
except Exception as e:
return create_error_response(HTTPStatus.EXPECTATION_FAILED, f"error: {str(e)}")
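Review note: taken together, the api_http.py changes give every endpoint the same shape: ClientDisconnected is caught ahead of the generic Exception handler and mapped to HTTP 499, the unofficial "client closed request" status popularized by nginx. A minimal sketch of the pattern follows; the endpoint name generate_sketch and the run_inference helper are illustrative stand-ins, not code from this PR.

import logging
from http import HTTPStatus
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, Response
from lightllm.utils.error_utils import ClientDisconnected

logger = logging.getLogger(__name__)
app = FastAPI()

async def run_inference(request: Request) -> Response:
    # Hypothetical stand-in for the real handler body.
    if await request.is_disconnected():
        raise ClientDisconnected(reason="client disconnected before inference started")
    return JSONResponse({"generated_text": "..."})

@app.post("/generate_sketch")
async def generate_sketch(request: Request) -> Response:
    try:
        return await run_inference(request)
    except ClientDisconnected as e:
        # Expected control flow: one short log line, no stack trace, HTTP 499.
        logger.error(str(e))
        return Response(status_code=499)
    except Exception as e:
        # Everything else is still a real server error.
        logger.error("An error occurred: %s", str(e), exc_info=True)
        return JSONResponse({"error": str(e)}, status_code=HTTPStatus.EXPECTATION_FAILED)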

5 changes: 5 additions & 0 deletions lightllm/server/api_openai.py
@@ -30,6 +30,7 @@
from .httpserver_for_pd_master.manager import HttpServerManagerForPDMaster
from .api_lightllm import lightllm_get_score
from lightllm.utils.envs_utils import get_env_start_args, get_lightllm_websocket_max_message_size
from lightllm.utils.error_utils import ClientDisconnected

from lightllm.utils.log_utils import init_logger
from lightllm.server.metrics.manager import MetricClient
@@ -68,6 +69,10 @@ async def _safe_stream_wrapper(stream_generator):
except ValueError as e:
error_data = json.dumps({"error": {"message": str(e), "type": "invalid_request_error"}}, ensure_ascii=False)
yield f"data: {error_data}\n\n"
except ClientDisconnected as e:
logger.error(str(e))
# Client is gone; there is no point yielding more SSE chunks. Stop quietly.
return


def _serialize_sse_chunk(chunk, choice_nulls=(), response_nulls=()):
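Review note: for streaming responses the disconnect cannot become a 499, because the status line has already been sent; the wrapper simply stops the generator, which ends the SSE stream cleanly. A usage sketch under the assumption that the private helper _safe_stream_wrapper is importable and wraps the chunk generator handed to FastAPI's StreamingResponse; the endpoint and producer names here are illustrative.

from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
from lightllm.server.api_openai import _safe_stream_wrapper

app = FastAPI()

async def token_chunks(request: Request):
    # Illustrative producer: the real server yields OpenAI-format chunks here.
    for text in ("Hel", "lo"):
        yield f"data: {text}\n\n"
    yield "data: [DONE]\n\n"

@app.post("/v1/completions_sketch")
async def completions_sketch(request: Request):
    # _safe_stream_wrapper swallows ClientDisconnected with a plain `return`,
    # terminating the stream instead of raising into Starlette.
    return StreamingResponse(_safe_stream_wrapper(token_chunks(request)), media_type="text/event-stream")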
13 changes: 10 additions & 3 deletions lightllm/server/httpserver/manager.py
@@ -34,7 +34,7 @@
from lightllm.utils.statics_utils import MovingAverage
from lightllm.utils.config_utils import get_vocab_size
from lightllm.utils.envs_utils import get_unique_server_name
from lightllm.utils.error_utils import NixlPrefillNodeStopGenToken
from lightllm.utils.error_utils import ClientDisconnected, NixlPrefillNodeStopGenToken
from rpyc.utils.classic import obtain

logger = init_logger(__name__)
@@ -445,8 +445,13 @@ async def generate(

yield sub_req_id, request_output, metadata, finish_status

except Exception as e:
except (ClientDisconnected, Exception) as e:
logger.error(f"group_request_id: {group_request_id} has exception {str(e)}")

if isinstance(e, ClientDisconnected):
logger.warning(f"group_request_id: {group_request_id} {e.reason}")
logger.debug(f"group_request_id: {group_request_id} {e.reason}", exc_info=True)

# On error, multimodal resources must be released.
# Multimodal resources not yet attached to a formal request object must be released here manually.
# Request objects already placed into req_id_to_out_inf are reclaimed by the unified recycling loop
@@ -664,7 +669,9 @@ async def _wait_to_token_package(

if not self.disable_abort and request is not None and await request.is_disconnected():
await self.abort(group_request_id)
raise Exception(f"req_id {group_request_id} disconnected")
raise ClientDisconnected(
    group_request_id=group_request_id, reason="client disconnect detected in _wait_to_token_package"
)

async with req_status.lock:
event.clear()
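Review note: the core detection loop lives here. While waiting for tokens, the manager polls request.is_disconnected(), aborts the request group, and raises the typed exception so layers above can tell a disconnect from a genuine failure. A simplified sketch of that loop; the event, abort callable, and 5-second poll interval are stand-ins for the manager's internals.

import asyncio
from lightllm.utils.error_utils import ClientDisconnected

async def wait_for_tokens(request, group_request_id, event, abort):
    # Condensed form of the disconnect poll in _wait_to_token_package.
    while True:
        try:
            await asyncio.wait_for(event.wait(), timeout=5)  # assumed poll interval
            return
        except asyncio.TimeoutError:
            pass
        if request is not None and await request.is_disconnected():
            await abort(group_request_id)  # release server-side state first
            raise ClientDisconnected(
                group_request_id=group_request_id,
                reason="client disconnect detected in _wait_to_token_package",
            )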
30 changes: 23 additions & 7 deletions lightllm/server/httpserver_for_pd_master/manager.py
@@ -19,7 +19,7 @@
from lightllm.server.metrics.manager import MetricClient
from lightllm.utils.statics_utils import MovingAverage
from lightllm.server.httpserver.manager import AsyncQueue
from lightllm.utils.error_utils import ServerBusyError
from lightllm.utils.error_utils import ClientDisconnected, ServerBusyError
from lightllm.utils.envs_utils import get_pd_split_max_new_tokens
from .pd_selector import create_selector

@@ -163,8 +163,13 @@ async def generate(

await self.remove_req(group_request_id=block_group_request_id)

except BaseException as e:
except (ClientDisconnected, BaseException) as e:
logger.error(f"has exception {str(e)}")

if isinstance(e, ClientDisconnected):
logger.warning(f"group_request_id: {origin_group_request_id} {e.reason}")
logger.debug(f"group_request_id: {origin_group_request_id} {e.reason}", exc_info=True)

try:
await self.abort(block_group_request_id, p_node=p_node, d_node=d_node)
except:
@@ -221,7 +226,9 @@ async def fetch_stream(
while True:
await req_status.wait_to_ready()
if await request.is_disconnected():
raise Exception(f"req_id {group_request_id} disconnected")
raise ClientDisconnected(
    group_request_id=group_request_id, reason="client disconnect detected during prefill in fetch_stream"
)

if await req_status.can_read(self.req_id_to_out_inf):
token_list = await req_status.pop_all_tokens()
Expand Down Expand Up @@ -259,7 +266,9 @@ async def fetch_stream(
while True:
await req_status.wait_to_ready()
if await request.is_disconnected():
raise Exception(f"req_id {group_request_id} disconnected")
raise ClientDisconnected(
    group_request_id=group_request_id, reason="client disconnect detected during decode in fetch_stream"
)
if await req_status.can_read(self.req_id_to_out_inf):
token_list = await req_status.pop_all_tokens()
for sub_req_id, request_output, metadata, finish_status in token_list:
Expand Down Expand Up @@ -296,7 +305,9 @@ async def fetch_nixl_stream(
raise ServerBusyError()

if await request.is_disconnected():
raise Exception(f"req_id {group_request_id} disconnected")
raise ClientDisconnected(
    group_request_id=group_request_id, reason="client disconnect detected during prefill in fetch_nixl_stream"
)

prompt_ids = nixl_np_up_prompt_ids_event.prompt_ids
logger.info(f"group_request_id: {group_request_id} get np up prompt ids len {len(prompt_ids)}")
Expand Down Expand Up @@ -324,7 +335,10 @@ async def fetch_nixl_stream(
while True:
await req_status.wait_to_ready()
if await request.is_disconnected():
raise Exception(f"req_id {group_request_id} disconnected")
raise ClientDisconnected(
    group_request_id=group_request_id,
    reason="client disconnect detected during decode in fetch_nixl_stream",
)
if await req_status.can_read(self.req_id_to_out_inf):
token_list = await req_status.pop_all_tokens()
for sub_req_id, request_output, metadata, finish_status in token_list:
Expand Down Expand Up @@ -373,7 +387,9 @@ async def _wait_to_token_package(
p_node, d_node, prompt, sampling_params, multimodal_params, request
):
if await request.is_disconnected():
raise Exception(f"req_id {group_request_id} disconnected")
raise ClientDisconnected(
    group_request_id=group_request_id, reason="client disconnect detected in _wait_to_token_package"
)

prompt_tokens = metadata["prompt_tokens"]
out_token_counter += 1
9 changes: 9 additions & 0 deletions lightllm/server/multimodal_params.py
@@ -9,6 +9,7 @@
from PIL import Image
from fastapi import Request
from lightllm.server.embed_cache.utils import read_shm, get_shm_name_data
from lightllm.utils.error_utils import ClientDisconnected
from lightllm.utils.multimodal_utils import fetch_resource
from lightllm.utils.log_utils import init_logger

@@ -63,6 +64,10 @@ async def preload(self, request: Request):
self._preload_data = audio_values.tobytes()
return

except ClientDisconnected:
    # Preserve the client-disconnect signal so the API layer can return 499
    # instead of logging a noisy 'Failed to read audio' error.
    raise
except Exception as e:
raise ValueError(f"Failed to read audio type={self._type}, data[:100]={self._data[:100]}: {e}!")

@@ -148,6 +153,10 @@ async def preload(self, request: Request):
self._preload_data = img_data
return

except ClientDisconnected:
    # Preserve the client-disconnect signal so the API layer can return 499
    # instead of logging a noisy 'Failed to read image' error.
    raise
except Exception as e:
raise ValueError(f"Failed to read image type={self._type}, data[:100]={self._data[:100]}: {e}!")
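Review note: the ordering of the except clauses is the point of this change. Python matches clauses top to bottom, so the specific ClientDisconnected re-raise must precede the broad Exception clause that rewrites everything else into a ValueError. A minimal sketch of the rule; load_bytes is a hypothetical loader, not part of the PR.

from lightllm.utils.error_utils import ClientDisconnected

async def preload_sketch(load_bytes):
    try:
        return await load_bytes()
    except ClientDisconnected:
        # Must come first: the broad `except Exception` below would otherwise
        # swallow the disconnect and mislabel it as a data error.
        raise
    except Exception as e:
        raise ValueError(f"Failed to read resource: {e}!")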

15 changes: 15 additions & 0 deletions lightllm/utils/error_utils.py
@@ -1,4 +1,5 @@
from lightllm.utils.log_utils import init_logger
from typing import Optional

logger = init_logger(__name__)

@@ -23,6 +24,20 @@ def __str__(self):
return f"{self.message} (Status code: {self.status_code})"


class ClientDisconnected(Exception):
"""Raised when the client closed the HTTP connection mid-request, as
detected by ``request.is_disconnected()``. This is an expected control-flow
signal: handlers should clean up quietly without logging a stack trace.
Internal-module aborts (e.g. visual proxy failures) must NOT raise this;
they should surface as real server errors."""

def __init__(self, group_request_id: Optional[int] = None, reason: str = "client disconnected"):
prefix = f"req_id {group_request_id} " if group_request_id is not None else ""
super().__init__(f"{prefix}{reason}")
self.group_request_id = group_request_id
self.reason = reason


class NixlPrefillNodeStopGenToken(Exception):
def __init__(self, group_request_id, message="Nixl prefill node stop gen token"):
"""
3 changes: 2 additions & 1 deletion lightllm/utils/multimodal_utils.py
@@ -6,6 +6,7 @@
from io import BytesIO
from fastapi import Request
from functools import lru_cache
from lightllm.utils.error_utils import ClientDisconnected
from lightllm.utils.log_utils import init_logger

logger = init_logger(__name__)
@@ -53,7 +54,7 @@ async def fetch_resource(url, request: Request, timeout, proxy=None):
async for chunk in response.aiter_bytes(chunk_size=1024 * 1024):
if request is not None and await request.is_disconnected():
await response.aclose()
raise Exception("Request disconnected. User cancelled download.")
raise ClientDisconnected(reason=f"client disconnected during download of {url}")
ans_bytes.append(chunk)
# Downloaded data must not exceed 128 MB
if len(ans_bytes) > 128:
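Review note: the download path now distinguishes "the user hung up" from "the resource is too large". The disconnect check runs once per 1 MB chunk, so cancellation is detected promptly without per-byte overhead. A condensed sketch of the loop under the assumption that httpx streams the response; fetch_with_disconnect_check is an illustrative name, not the PR's function.

from typing import Optional

import httpx
from fastapi import Request
from lightllm.utils.error_utils import ClientDisconnected

async def fetch_with_disconnect_check(url: str, request: Optional[Request], timeout: float) -> bytes:
    chunks = []
    async with httpx.AsyncClient(timeout=timeout) as client:
        async with client.stream("GET", url) as response:
            async for chunk in response.aiter_bytes(chunk_size=1024 * 1024):
                if request is not None and await request.is_disconnected():
                    raise ClientDisconnected(reason=f"client disconnected during download of {url}")
                chunks.append(chunk)
                if len(chunks) > 128:  # 1 MB chunks, so this caps the download near 128 MB
                    raise ValueError(f"resource {url} larger than 128 MB")
    return b"".join(chunks)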