Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review — refactor/queue-handler
严重问题 (Blocker)

  1. fc_stop_client.py 和 fc_stop_client_test.py 是冗余的文件

fc_stop_client.py 与 fc_client.py 的逻辑 100% 重复,唯一区别是没有提取 _create_client 辅助函数。init.py 已经正确地从 fc_client 导出,queue_handler 也使用 fc_client,fc_stop_client 是孤立的死文件。必须删除,否则后续维护两套代码。

  1. queue_handler.py 中 except Exception 吞掉了非超时异常

第 248-249 行

except Exception:
timed_out = True
as_completed(..., timeout=...) 超时抛出的是 concurrent.futures.TimeoutError,但这里捕获了所有 Exception,若 as_completed 内部出现其他异常(如 RuntimeError),也会被误认为超时。应改为:

from concurrent.futures import TimeoutError as FuturesTimeoutError
except FuturesTimeoutError:
timed_out = True
3. queue_handler.py 中 stop_async_task 是函数内延迟导入,违反编码规范

第 218 行(函数体内)

from services.fc_openapi.fc_client import stop_async_task
项目规则 1.2 要求导入统一放在模块顶部,这里放在函数体内会导致每次调用时才解析模块(虽有缓存,但逻辑不清晰),应移至文件顶部。

中等问题 (Major)
4. history_manager.py 中"解析 prompt_body"逻辑重复了两遍

_build_history_item(第 337-354 行)和 late_init_history_item(第 371-384 行)中几乎完全相同的一段解析代码:

outputs_to_execute = []
prompt_dict = prompt_body or {}
if isinstance(prompt_dict, dict):
if "prompt" in prompt_dict and ...:
outputs_to_execute = prompt_dict.get("outputs_to_execute", [])
prompt_dict = prompt_dict["prompt"]
if not outputs_to_execute and isinstance(prompt_dict, dict):
outputs_to_execute = infer_outputs_to_execute(prompt_dict)
raw = prompt_body if isinstance(prompt_body, dict) else {}
extra_data = dict(raw.get("extra_data", {})) ...
应提取为一个私有方法 _parse_prompt_body(prompt_body, client_id) -> (prompt_dict, extra_data, outputs_to_execute) 消除重复。

  1. interrupt_current_user 和 get_current_user_running_task_ids 是死代码

TaskManager 中实现了这两个方法,但 InterruptHandler 直接返回 403,完全没有调用。若这个功能是 "暂不支持",这两个方法不应该存在于此次 PR 中,要么删除,要么在 PR 描述中明确说明"为后续 interrupt 功能预留"。

  1. task_manager.py 中 delete_history_items 存在 TOCTOU race condition

item = self._history_manager.get_history_item(prompt_id) # 加锁读
if item and item.get("user_id") == user_id:
if self._history_manager.remove_history_item(prompt_id): # 加锁写
两次操作各自持锁,但中间无原子性保证,另一线程可能在 get 和 remove 之间删除或修改该 item。应在 HistoryManager 中提供一个原子的 remove_if_owned(prompt_id, user_id) 方法。

轻微问题 (Minor)
7. queue_handler.py 中 _build_task_info 有永远为真的冗余判断

prompt_dict = prompt_body if isinstance(prompt_body, dict) else {}
if isinstance(prompt_dict, dict): # ← 永远为真
if "prompt" in prompt_dict and isinstance(prompt_dict.get("prompt"), dict):
第二个 isinstance(prompt_dict, dict) 可直接删除。

  1. ThreadPoolExecutor 未使用 with 语法,资源管理不健壮

ex = ThreadPoolExecutor(max_workers=workers)
try:
...
finally:
ex.shutdown(wait=not timed_out)
若 ex.submit(...) 之前(futs = {ex.submit(...)} 这行)抛异常,finally 不会执行,线程池泄漏。若用 with ThreadPoolExecutor(...) as ex:,则 exit 保证一定会 shutdown。超时时不想等待的需求,可以通过先 cancel 再 ex.shutdown(wait=False) 实现,不需要手动管理。

  1. requirements.txt 末尾缺少换行符

alibabacloud_tea_util>=0.3.0
\ No newline at end of file
两个新依赖行末都没有换行,应补充 \n,避免工具解析问题。

Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,5 @@
.vscode/settings.json
reference
test
.cursor
.cursor
**/docs
5 changes: 5 additions & 0 deletions src/code/agent/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,11 @@
# GPU 函数的 URL,当 COMFYUI_MODE="cpu" 时使用
GPU_FUNCTION_URL = os.getenv("GPU_FUNCTION_URL", "")

# FC OpenAPI(StopAsyncTask)配置
FC_ACCOUNT_ID = os.getenv("FC_ACCOUNT_ID", "")
FC_REGION = os.getenv("FC_REGION", "cn-hangzhou")
FC_FUNCTION_NAME = os.getenv("FC_FUNCTION_NAME", "")

# HTTP Header 常量
HEADER_SNAPSHOT_NAME = "X-FunArt-Snapshot"
HEADER_FC_INVOCATION_TYPE = "X-FC-Invocation-Type"
Expand Down
4 changes: 3 additions & 1 deletion src/code/agent/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,7 @@ oss2
flask-cors
chardet
watchdog
alibabacloud_tea_openapi>=0.3.0
alibabacloud_tea_util>=0.3.0
packaging
pydantic
pydantic
15 changes: 13 additions & 2 deletions src/code/agent/routes/gateway_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from services.gateway.handlers.prompt_handler import PromptHandler
from services.gateway.handlers.serverless_handler import ServerlessHandler
from services.gateway.handlers.history_handler import HistoryHandler
from services.gateway.handlers.interrupt_handler import InterruptHandler
from services.gateway.handlers.reboot_handler import RebootHandler
from services.gateway.handlers.userdata_handler import UserdataHandler
from services.gateway.handlers.ws_handler import WsHandler
Expand Down Expand Up @@ -46,6 +47,7 @@ def __init__(self):
self.prompt_handler = PromptHandler(task_manager)
self.serverless_handler = ServerlessHandler()
self.history_handler = HistoryHandler()
self.interrupt_handler = InterruptHandler()
self.userdata_handler = UserdataHandler()
self.ws_handler = WsHandler()
self.serverless_ws_handler = ServerlessWsHandler(constants.GPU_FUNCTION_URL)
Expand All @@ -70,6 +72,7 @@ def setup_routes(self):
self._register_prompt_handler()
self._register_serverless_run_handler()
self._register_history_handler()
self._register_interrupt_handler()
# 通过环境变量控制是否禁用工作流保存
if constants.DISABLE_FLOW_SAVE:
self._register_userdata_handler()
Expand Down Expand Up @@ -222,10 +225,18 @@ def handle_serverless_run():
return self.serverless_handler.handle_post_request()

def _register_history_handler(self):
@self.bp.route("/history", methods=["GET"])
@self.bp.route("/history", methods=["GET", "POST"])
@handle_exceptions(error_type="history_operation_error", log_prefix="History")
def handle_history():
return self.history_handler.handle_get_request()
if request.method == "GET":
return self.history_handler.handle_get_request()
return self.history_handler.handle_post_request()

def _register_interrupt_handler(self):
@self.bp.route("/interrupt", methods=["POST"])
@handle_exceptions(error_type="interrupt_operation_error", log_prefix="Interrupt")
def handle_interrupt():
return self.interrupt_handler.handle_post()

def _register_reboot_handler(self):
@self.bp.route("/manager/reboot", methods=["GET", "POST"])
Expand Down
3 changes: 3 additions & 0 deletions src/code/agent/services/fc_openapi/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .fc_client import stop_async_task

__all__ = ["stop_async_task"]
102 changes: 102 additions & 0 deletions src/code/agent/services/fc_openapi/fc_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
"""
FC OpenAPI 客户端
用于在 CPU 侧调用函数计算 OpenAPI,操作 GPU 上的异步任务。
使用实例 RAM 角色凭证:优先从请求头(FC 注入)获取,否则从环境变量获取。
"""
from flask import request

from alibabacloud_tea_openapi.client import Client as OpenApiClient
from alibabacloud_tea_openapi import models as open_api_models
from alibabacloud_tea_util import models as util_models

import constants
from utils.logger import log


def _get_credentials():
"""
获取阿里云凭证(与 ServerlessApiService.get_credentials 一致)。
优先从请求头获取(FC 实例 RAM 角色注入),否则从环境变量获取。
Returns:
tuple: (access_key_id, access_key_secret, security_token)
"""
ak = ""
sk = ""
sts = ""
try:
ak = request.headers.get(constants.HEADER_KEY_ACCESS_KEY_ID, "")
sk = request.headers.get(constants.HEADER_KEY_ACCESS_KEY_SECRET, "")
sts = request.headers.get(constants.HEADER_KEY_SECURITY_TOKEN, "")
except Exception as e:
log("WARNING", f"get credentials from header failed: {e}")
if not ak or not sk:
ak = getattr(constants, "ALIBABA_CLOUD_ACCESS_KEY_ID", "") or ""
sk = getattr(constants, "ALIBABA_CLOUD_ACCESS_KEY_SECRET", "") or ""
sts = getattr(constants, "ALIBABA_CLOUD_SECURITY_TOKEN", "") or ""
return ak, sk, sts


def _create_client(ak: str, sk: str, sts: str, endpoint: str) -> OpenApiClient:
"""
构造 FC OpenApiClient,供各 API 调用方复用,避免重复构建 Config。
"""
config = open_api_models.Config(
access_key_id=ak,
access_key_secret=sk,
security_token=sts or None,
)
config.endpoint = endpoint
return OpenApiClient(config)


def stop_async_task(task_id: str) -> bool:
"""
调用 FC OpenAPI StopAsyncTask 停止指定异步任务(带签名,使用实例 RAM 角色凭证)。

Args:
task_id: 异步任务 ID(与我们的 task_id/prompt_id 一致)。

Returns:
bool: 请求成功且为 2xx 返回 True,否则 False。
"""
account_id = getattr(constants, "FC_ACCOUNT_ID", "") or ""
region = getattr(constants, "FC_REGION", "cn-hangzhou") or "cn-hangzhou"
cpu_function_name = getattr(constants, "FC_FUNCTION_NAME", "") or ""

if not account_id or not cpu_function_name:
return False

ak, sk, sts = _get_credentials()
if not ak or not sk:
log("WARNING", "StopAsyncTask: no credentials (header or env)")
return False

gpu_function_name = cpu_function_name.replace("-gw", "")
endpoint = f"{account_id}.{region}.fc.aliyuncs.com"

try:
client = _create_client(ak, sk, sts, endpoint)

params = open_api_models.Params(
action="StopAsyncTask",
version="2023-03-30",
protocol="HTTPS",
method="PUT",
auth_type="AK",
style="FC",
pathname=f"/2023-03-30/functions/{gpu_function_name}/async-tasks/{task_id}/stop",
req_body_type="json",
body_type="json",
)
req = open_api_models.OpenApiRequest(query={"qualifier": "LATEST"})
runtime = util_models.RuntimeOptions(read_timeout=5000, connect_timeout=5000)

resp = client.call_api(params, req, runtime)
status_code = resp.get("statusCode") or 0
if status_code >= 200 and status_code < 300:
return True
log("WARNING", f"StopAsyncTask failed: task_id={task_id}, status={status_code}, body={resp.get('body', '')[:200]}")
return False
except Exception as e:
log("WARNING", f"StopAsyncTask request error: task_id={task_id}, error={e}")
return False
2 changes: 2 additions & 0 deletions src/code/agent/services/gateway/handlers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .history_handler import HistoryHandler
from .interrupt_handler import InterruptHandler
from .reboot_handler import RebootHandler
from .queue_handler import QueueHandler
from .prompt_handler import PromptHandler
Expand All @@ -8,6 +9,7 @@

__all__ = [
'HistoryHandler',
'InterruptHandler',
'RebootHandler',
'QueueHandler',
'PromptHandler',
Expand Down
16 changes: 15 additions & 1 deletion src/code/agent/services/gateway/handlers/history_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import traceback
from collections import OrderedDict
from typing import Optional, Dict, Any, List
from flask import request, jsonify
from flask import request, jsonify, g

from utils.logger import log

Expand Down Expand Up @@ -36,6 +36,20 @@ def handle_get_request(self):
history = self.task_manager.get_history(max_items=max_items)
return jsonify(history)

def handle_post_request(self):
"""处理 POST /api/history 请求(clear、delete,与 ComfyUI 对齐)"""
if not self._is_initialized():
return "", 503
data = request.get_json(silent=True) or {}
user_id = getattr(g, "user_id", "default")
if data.get("clear"):
self.task_manager.clear_history(user_id)
if "delete" in data:
to_delete = data["delete"]
if isinstance(to_delete, list):
self.task_manager.delete_history_items(to_delete, user_id)
return "", 200

def _is_initialized(self) -> bool:
"""检查服务是否已正确初始化"""
return self.task_manager is not None and self.TaskStatus is not None
Expand Down
21 changes: 21 additions & 0 deletions src/code/agent/services/gateway/handlers/interrupt_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
"""
Interrupt Handler
处理 POST /api/interrupt 请求
"""
from flask import jsonify


class InterruptHandler:
"""处理 POST /api/interrupt"""

def handle_post(self):
"""
POST /api/interrupt 暂不支持:FC StopAsyncTask 会带来非预期体验,
禁止客户端中止正在运行的任务,返回 403。
"""
return jsonify({
"error": {
"type": "not_supported",
"message": "Interrupting a running task is not supported.",
}
}), 403
95 changes: 74 additions & 21 deletions src/code/agent/services/gateway/handlers/queue_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,20 @@
Queue Handler
处理队列相关的请求逻辑
"""
from flask import jsonify, Response, request
from concurrent.futures import ThreadPoolExecutor, TimeoutError as FuturesTimeoutError, as_completed

from flask import Response, copy_current_request_context, jsonify, request

from utils.logger import log
from services.fc_openapi.fc_client import stop_async_task
from services.gateway.task.utils.prompt_utils import parse_prompt_body

# 单次 clear 最多对多少个 PENDING 调 StopAsyncTask,避免请求阻塞过久;其余仍由 clear_queue 本地清除
MAX_PENDING_STOP_ON_CLEAR = 50
# 并行 stop 的最大线程数
MAX_STOP_WORKERS = 10
# 等待所有 stop 调用的总超时(秒),超时后仍会执行 clear_queue
STOP_BATCH_TIMEOUT_SEC = 10


class QueueHandler:
Expand Down Expand Up @@ -38,28 +49,16 @@ def handle_get_request(self):
"queue_pending": [] # 等待中的任务列表
}

# 构造任务信息的辅助函数
# 构造任务信息的辅助函数(与 history 的 prompt 数组格式一致:含 outputs_to_execute)
def _build_task_info(task):
"""构造ComfyUI兼容的任务信息格式"""
# 安全地提取 prompt 和 extra_data
prompt_body = task.prompt_body or {}

# 兼容两种格式:
# 1. 新格式: {"prompt": {...}, "extra_data": {...}}
# 2. 旧格式: 直接是 prompt 工作流定义
if isinstance(prompt_body, dict) and "prompt" in prompt_body:
prompt = prompt_body.get("prompt", {})
extra_data = prompt_body.get("extra_data", {})
else:
prompt = prompt_body
extra_data = {}

prompt_dict, outputs_to_execute, extra_data = parse_prompt_body(task.prompt_body or {})
return [
1, # number - 任务优先级
task.task_id, # prompt_id
prompt or {}, # prompt - 避免None导致序列化失败
prompt_dict or {}, # prompt - 避免None导致序列化失败
extra_data or {}, # extra_data
[] # outputs_to_execute
outputs_to_execute,
]

for task in all_tasks:
Expand All @@ -82,12 +81,66 @@ def handle_post_request(self):
request_data = request.get_json() or {}

if "clear" in request_data and request_data["clear"]:
# 清空队列
# 清空队列:先对当前用户 PENDING 任务并行调 StopAsyncTask(数量上限 + 总超时),再清本地 PENDING(与 ComfyUI 只清 pending 对齐,不停 RUNNING)
log("INFO", f"Clearing task queue")

pending_ids = self.task_manager.get_current_user_pending_task_ids()
ids_to_stop = pending_ids[:MAX_PENDING_STOP_ON_CLEAR]
workers = min(MAX_STOP_WORKERS, len(ids_to_stop) or 1)

@copy_current_request_context
def stop_in_context(task_id):
return stop_async_task(task_id)

stopped = 0
failed_ids = []

if ids_to_stop:
ex = ThreadPoolExecutor(max_workers=workers)
timed_out = False
try:
futs = {ex.submit(stop_in_context, tid): tid for tid in ids_to_stop}
try:
for fut in as_completed(futs, timeout=STOP_BATCH_TIMEOUT_SEC):
tid = futs[fut]
try:
ok = fut.result()
if ok:
stopped += 1
else:
failed_ids.append(tid)
except Exception as e:
log("WARNING", f"StopAsyncTask exception: task_id={tid}, error={e}")
failed_ids.append(tid)
except FuturesTimeoutError:
timed_out = True
# 取消尚未开始的 future,已在运行的计入失败
timed_out_ids = []
for fut, tid in futs.items():
if not fut.done():
fut.cancel()
timed_out_ids.append(tid)
failed_ids.append(tid)
log("WARNING", f"StopAsyncTask batch timed out after {STOP_BATCH_TIMEOUT_SEC}s; timed-out tasks: {sorted(timed_out_ids)}")
finally:
# 超时时 wait=False,不阻塞等待仍在运行的 stop 调用
ex.shutdown(wait=not timed_out)

if failed_ids:
sample = failed_ids[:10]
log(
"WARNING",
f"StopAsyncTask failed for {len(failed_ids)}/{len(ids_to_stop)} tasks "
f"(GPU may still be running). Failed ids (up to 10): {sample}",
)

cleared_count = self.task_manager.clear_queue()
log("INFO", f"Cleared {cleared_count} tasks from queue")

skipped = len(pending_ids) - len(ids_to_stop)
log(
"INFO",
f"Cleared {cleared_count} tasks from queue; "
f"StopAsyncTask stopped={stopped} failed={len(failed_ids)}"
+ (f" skipped(over limit)={skipped}" if skipped > 0 else ""),
)
return Response(status=200)

elif "delete" in request_data:
Expand Down
Loading