Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 38 additions & 5 deletions astrbot/core/astr_agent_tool_exec.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,35 @@
from astrbot.core.platform.message_session import MessageSession
from astrbot.core.provider.entites import ProviderRequest
from astrbot.core.provider.register import llm_tools
from astrbot.core.star.session_plugin_manager import SessionPluginManager
from astrbot.core.star.star import star_map
from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from astrbot.core.utils.history_saver import persist_agent_history
from astrbot.core.utils.image_ref_utils import is_supported_image_ref
from astrbot.core.utils.string_utils import normalize_and_dedupe_strings


class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
@classmethod
def _tool_enabled_for_session(
cls,
tool: FunctionTool,
session_config: dict | None,
) -> bool:
mp = tool.handler_module_path
if not mp:
return True

plugin = star_map.get(mp)
if not plugin:
return True

return SessionPluginManager.is_plugin_enabled_for_session_config(
plugin.name,
session_config,
reserved=plugin.reserved,
)

@classmethod
def _collect_image_urls_from_args(cls, image_urls_raw: T.Any) -> list[str]:
if image_urls_raw is None:
Expand Down Expand Up @@ -193,14 +215,17 @@ def _get_runtime_computer_tools(cls, runtime: str) -> dict[str, FunctionTool]:
return {}

@classmethod
def _build_handoff_toolset(
async def _build_handoff_toolset(
cls,
run_context: ContextWrapper[AstrAgentContext],
tools: list[str | FunctionTool] | None,
) -> ToolSet | None:
ctx = run_context.context.context
event = run_context.context.event
cfg = ctx.get_config(umo=event.unified_msg_origin)
session_config = await SessionPluginManager.get_session_plugin_config(
event.unified_msg_origin
)
provider_settings = cfg.get("provider_settings", {})
runtime = str(provider_settings.get("computer_use_runtime", "local"))
runtime_computer_tools = cls._get_runtime_computer_tools(runtime)
Expand All @@ -212,7 +237,10 @@ def _build_handoff_toolset(
for registered_tool in llm_tools.func_list:
if isinstance(registered_tool, HandoffTool):
continue
if registered_tool.active:
if registered_tool.active and cls._tool_enabled_for_session(
registered_tool,
session_config,
):
toolset.add_tool(registered_tool)
for runtime_tool in runtime_computer_tools.values():
toolset.add_tool(runtime_tool)
Expand All @@ -225,14 +253,19 @@ def _build_handoff_toolset(
for tool_name_or_obj in tools:
if isinstance(tool_name_or_obj, str):
registered_tool = llm_tools.get_func(tool_name_or_obj)
if registered_tool and registered_tool.active:
if (
registered_tool
and registered_tool.active
and cls._tool_enabled_for_session(registered_tool, session_config)
):
toolset.add_tool(registered_tool)
continue
runtime_tool = runtime_computer_tools.get(tool_name_or_obj)
if runtime_tool:
toolset.add_tool(runtime_tool)
elif isinstance(tool_name_or_obj, FunctionTool):
toolset.add_tool(tool_name_or_obj)
if cls._tool_enabled_for_session(tool_name_or_obj, session_config):
toolset.add_tool(tool_name_or_obj)
return None if toolset.empty() else toolset

@classmethod
Expand Down Expand Up @@ -264,7 +297,7 @@ async def _execute_handoff(
tool_args["image_urls"] = image_urls

# Build handoff toolset from registered tools plus runtime computer tools.
toolset = cls._build_handoff_toolset(run_context, tool.agent.tools)
toolset = await cls._build_handoff_toolset(run_context, tool.agent.tools)

ctx = run_context.context.context
event = run_context.context.event
Expand Down
63 changes: 40 additions & 23 deletions astrbot/core/astr_main_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
from astrbot.core.provider.entities import ProviderRequest
from astrbot.core.skills.skill_manager import SkillManager, build_skills_prompt
from astrbot.core.star.context import Context
from astrbot.core.star.session_plugin_manager import SessionPluginManager
from astrbot.core.star.star_handler import star_map
from astrbot.core.tools.cron_tools import (
CREATE_CRON_JOB_TOOL,
Expand Down Expand Up @@ -846,33 +847,49 @@ def _sanitize_context_by_modalities(
req.contexts = sanitized_contexts


def _plugin_tool_fix(event: AstrMessageEvent, req: ProviderRequest) -> None:
async def _plugin_tool_fix(event: AstrMessageEvent, req: ProviderRequest) -> None:
"""根据事件中的插件设置,过滤请求中的工具列表。

注意:没有 handler_module_path 的工具(如 MCP 工具)会被保留,
因为它们不属于任何插件,不应被插件过滤逻辑影响。
"""
if event.plugins_name is not None and req.func_tool:
new_tool_set = ToolSet()
for tool in req.func_tool.tools:
if isinstance(tool, MCPTool):
# 保留 MCP 工具
new_tool_set.add_tool(tool)
continue
mp = tool.handler_module_path
if not mp:
# 没有 plugin 归属信息的工具(如 subagent transfer_to_*)
# 不应受到会话插件过滤影响。
new_tool_set.add_tool(tool)
continue
plugin = star_map.get(mp)
if not plugin:
# 无法解析插件归属时,保守保留工具,避免误过滤。
new_tool_set.add_tool(tool)
continue
if plugin.name in event.plugins_name or plugin.reserved:
new_tool_set.add_tool(tool)
req.func_tool = new_tool_set
if not req.func_tool:
return

session_config = await SessionPluginManager.get_session_plugin_config(
event.unified_msg_origin
)
new_tool_set = ToolSet()
for tool in req.func_tool.tools:
if isinstance(tool, MCPTool):
# 保留 MCP 工具
new_tool_set.add_tool(tool)
continue
mp = tool.handler_module_path
if not mp:
# 没有 plugin 归属信息的工具(如 subagent transfer_to_*)
# 不应受到会话插件过滤影响。
new_tool_set.add_tool(tool)
continue
plugin = star_map.get(mp)
if not plugin:
# 无法解析插件归属时,保守保留工具,避免误过滤。
new_tool_set.add_tool(tool)
continue
if (
event.plugins_name is not None
and not plugin.reserved
and plugin.name not in event.plugins_name
):
continue
if not SessionPluginManager.is_plugin_enabled_for_session_config(
plugin.name,
session_config,
reserved=plugin.reserved,
):
continue
new_tool_set.add_tool(tool)
req.func_tool = new_tool_set


async def _handle_webchat(
Expand Down Expand Up @@ -1243,7 +1260,7 @@ async def build_main_agent(
req.session_id = event.unified_msg_origin

_modalities_fix(provider, req)
_plugin_tool_fix(event, req)
await _plugin_tool_fix(event, req)
_sanitize_context_by_modalities(config, provider, req)

if config.llm_safety_mode:
Expand Down
18 changes: 16 additions & 2 deletions astrbot/core/pipeline/context_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from astrbot import logger
from astrbot.core.message.message_event_result import CommandResult, MessageEventResult
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.star.session_plugin_manager import SessionPluginManager
from astrbot.core.star.star import star_map
from astrbot.core.star.star_handler import EventType, star_handlers_registry

Expand Down Expand Up @@ -89,19 +90,32 @@ async def call_event_hook(
hook_type,
plugins_name=event.plugins_name,
)
session_config = await SessionPluginManager.get_session_plugin_config(
event.unified_msg_origin
)
for handler in handlers:
plugin = star_map.get(handler.handler_module_path)
if plugin and not SessionPluginManager.is_plugin_enabled_for_session_config(
plugin.name,
session_config,
reserved=plugin.reserved,
):
logger.debug(
f"插件 {plugin.name} 在会话 {event.unified_msg_origin} 中被禁用,跳过 hook {handler.handler_name}",
)
continue
try:
assert inspect.iscoroutinefunction(handler.handler)
logger.debug(
f"hook({hook_type.name}) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}",
f"hook({hook_type.name}) -> {plugin.name if plugin else handler.handler_module_path} - {handler.handler_name}",
)
await handler.handler(event, *args, **kwargs)
except BaseException:
logger.error(traceback.format_exc())

if event.is_stopped():
logger.info(
f"{star_map[handler.handler_module_path].name} - {handler.handler_name} 终止了事件传播。",
f"{plugin.name if plugin else handler.handler_module_path} - {handler.handler_name} 终止了事件传播。",
)
return True

Expand Down
77 changes: 50 additions & 27 deletions astrbot/core/star/session_plugin_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,43 +8,67 @@ class SessionPluginManager:
"""管理会话级别的插件启停状态"""

@staticmethod
async def is_plugin_enabled_for_session(
session_id: str,
plugin_name: str,
) -> bool:
"""检查插件是否在指定会话中启用

Args:
session_id: 会话ID (unified_msg_origin)
plugin_name: 插件名称

Returns:
bool: True表示启用,False表示禁用

"""
# 获取会话插件配置
async def get_session_plugin_config(session_id: str) -> dict:
"""获取指定会话的插件配置。"""
session_plugin_config = await sp.get_async(
scope="umo",
scope_id=session_id,
key="session_plugin_config",
default={},
)
session_config = session_plugin_config.get(session_id, {})
return session_plugin_config.get(session_id, {})

@staticmethod
def is_plugin_enabled_for_session_config(
plugin_name: str | None,
session_config: dict | None,
*,
reserved: bool = False,
) -> bool:
"""检查插件是否在指定会话配置中启用。"""
if reserved or not plugin_name:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

question (bug_risk): Re-check whether reserved plugins should ignore session disable configuration.

This now always returns True for reserved plugins, so they can no longer be disabled via disabled_plugins, which is a behavior change from the previous implementation. Please clarify whether reserved plugins are meant to be session-disableable:

  • If yes, remove reserved from this short-circuit and rely solely on enabled/disabled_plugins.
  • If no, confirm that this “non-disableable” behavior is consistent with other call sites and that nothing depends on disabling reserved plugins per session.

return True

if not session_config:
return True

enabled_plugins = session_config.get("enabled_plugins", [])
disabled_plugins = session_config.get("disabled_plugins", [])

# 如果插件在禁用列表中,返回False
if plugin_name in disabled_plugins:
return False

# 如果插件在启用列表中,返回True
if plugin_name in enabled_plugins:
return True

# 如果都没有配置,默认为启用(兼容性考虑)
return True

@staticmethod
async def is_plugin_enabled_for_session(
session_id: str,
plugin_name: str,
*,
reserved: bool = False,
) -> bool:
"""检查插件是否在指定会话中启用

Args:
session_id: 会话ID (unified_msg_origin)
plugin_name: 插件名称

Returns:
bool: True表示启用,False表示禁用

"""
session_config = await SessionPluginManager.get_session_plugin_config(
session_id
)
return SessionPluginManager.is_plugin_enabled_for_session_config(
plugin_name,
session_config,
reserved=reserved,
)

@staticmethod
async def filter_handlers_by_session(
event: AstrMessageEvent,
Expand All @@ -65,14 +89,9 @@ async def filter_handlers_by_session(
session_id = event.unified_msg_origin
filtered_handlers = []

session_plugin_config = await sp.get_async(
scope="umo",
scope_id=session_id,
key="session_plugin_config",
default={},
session_config = await SessionPluginManager.get_session_plugin_config(
session_id
)
session_config = session_plugin_config.get(session_id, {})
disabled_plugins = session_config.get("disabled_plugins", [])

for handler in handlers:
# 获取处理器对应的插件
Expand All @@ -91,7 +110,11 @@ async def filter_handlers_by_session(
continue

# 检查插件是否在当前会话中启用
if plugin.name in disabled_plugins:
if not SessionPluginManager.is_plugin_enabled_for_session_config(
plugin.name,
session_config,
reserved=plugin.reserved,
):
logger.debug(
f"插件 {plugin.name} 在会话 {session_id} 中被禁用,跳过处理器 {handler.handler_name}",
)
Expand Down
Loading
Loading