118 changes: 114 additions & 4 deletions astrbot/core/agent/runners/tool_loop_agent_runner.py
@@ -1,3 +1,4 @@
import copy
import sys
import time
import traceback
@@ -14,6 +15,7 @@

from astrbot import logger
from astrbot.core.agent.message import TextPart, ThinkPart
from astrbot.core.agent.tool import ToolSet
from astrbot.core.message.components import Json
from astrbot.core.message.message_event_result import (
MessageChain,
@@ -43,6 +45,103 @@


class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
def _build_active_tool_set(self, tool_set: ToolSet) -> ToolSet:
active_set = ToolSet()
for tool in tool_set.tools:
if getattr(tool, "active", True):
active_set.add_tool(tool)
return active_set

def _apply_tool_schema_mode(self, tool_schema_mode: str | None) -> None:
tool_set = self.req.func_tool
if not isinstance(tool_set, ToolSet):
return

active_set = self._build_active_tool_set(tool_set)
if not active_set.tools:
self.req.func_tool = active_set
return

self._tool_schema_full_set = active_set

if tool_schema_mode in (None, "full"):
self.req.func_tool = active_set
return

light_set = active_set.get_light_tool_set()
self._tool_schema_param_set = active_set.get_param_only_tool_set()
self.req.func_tool = light_set

def _build_tool_requery_context(
self, tool_names: list[str]
) -> list[dict[str, T.Any]]:
contexts: list[dict[str, T.Any]] = []
for msg in self.run_context.messages:
if hasattr(msg, "model_dump"):
contexts.append(msg.model_dump()) # type: ignore[call-arg]
elif isinstance(msg, dict):
contexts.append(copy.deepcopy(msg))
instruction = (
"You have decided to call tool(s): "
+ ", ".join(tool_names)
+ ". Now call the tool(s) with required arguments using the tool schema, "
"and follow the existing tool-use rules."
)
if contexts and contexts[0].get("role") == "system":
content = contexts[0].get("content") or ""
contexts[0]["content"] = f"{content}\n{instruction}"
else:
contexts.insert(0, {"role": "system", "content": instruction})
return contexts

def _build_tool_subset(self, tool_set: ToolSet, tool_names: list[str]) -> ToolSet:
subset = ToolSet()
for name in tool_names:
tool = tool_set.get_tool(name)
if tool:
subset.add_tool(tool)
return subset

async def _resolve_tool_exec(
self,
llm_resp: LLMResponse,
) -> tuple[LLMResponse, ToolSet | None]:
tool_names = llm_resp.tools_call_name
if not tool_names:
return llm_resp, self.req.func_tool
full_tool_set = self._tool_schema_full_set or self.req.func_tool
if not isinstance(full_tool_set, ToolSet):
return llm_resp, self.req.func_tool

subset = self._build_tool_subset(full_tool_set, tool_names)
if not subset.tools:
return llm_resp, full_tool_set

if isinstance(self._tool_schema_param_set, ToolSet):
param_subset = self._build_tool_subset(
self._tool_schema_param_set, tool_names
)
if param_subset.tools:
requery_resp = await self._requery_tool_calls(param_subset, tool_names)
if requery_resp:
llm_resp = requery_resp

return llm_resp, subset

async def _requery_tool_calls(
self, tool_set: ToolSet, tool_names: list[str]
) -> LLMResponse | None:
if not tool_set or not tool_names:
return None
contexts = self._build_tool_requery_context(tool_names)
return await self.provider.text_chat(
prompt=None,
contexts=contexts,
func_tool=tool_set,
model=self.req.model,
session_id=self.req.session_id,
)

@override
async def reset(
self,
@@ -64,6 +163,7 @@ async def reset(
# customize
custom_token_counter: TokenCounter | None = None,
custom_compressor: ContextCompressor | None = None,
tool_schema_mode: str | None = None,
**kwargs: T.Any,
) -> None:
self.req = request
@@ -98,6 +198,8 @@ async def reset(
self.tool_executor = tool_executor
self.agent_hooks = agent_hooks
self.run_context = run_context
self._tool_schema_full_set = None
self._tool_schema_param_set = None

messages = []
# append existing messages in the run context
@@ -112,6 +214,7 @@
Message(role="system", content=request.system_prompt),
)
self.run_context.messages = messages
self._apply_tool_schema_mode(tool_schema_mode)

self.stats = AgentStats()
self.stats.start_time = time.time()
@@ -253,8 +356,14 @@ async def step(self):

# If the response contains tool calls, they still need to be handled
if llm_resp.tools_call_name:
llm_resp, exec_tool_set = await self._resolve_tool_exec(llm_resp)

tool_call_result_blocks = []
async for result in self._handle_function_tools(self.req, llm_resp):
async for result in self._handle_function_tools(
self.req,
llm_resp,
tool_set=exec_tool_set,
):
if isinstance(result, list):
tool_call_result_blocks = result
elif isinstance(result, MessageChain):
@@ -269,6 +378,7 @@ async def step(self):
type=ar_type,
data=AgentResponseData(chain=result),
)

# Append the results to the context
parts = []
if llm_resp.reasoning_content or llm_resp.reasoning_signature:
@@ -327,6 +437,7 @@ async def _handle_function_tools(
self,
req: ProviderRequest,
llm_response: LLMResponse,
tool_set: ToolSet | None = None,
) -> T.AsyncGenerator[MessageChain | list[ToolCallMessageSegment], None]:
"""处理函数工具调用。"""
tool_call_result_blocks: list[ToolCallMessageSegment] = []
@@ -352,9 +463,8 @@
],
)
try:
if not req.func_tool:
return
func_tool = req.func_tool.get_func(func_tool_name)
active_tool_set = tool_set or req.func_tool
func_tool = active_tool_set.get_tool(func_tool_name)
logger.info(f"使用工具:{func_tool_name},参数:{func_tool_args}")

if not func_tool:
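
A minimal, self-contained sketch of the two-stage ("skills_like") flow the runner implements above: the model first sees tool names and descriptions only, and parameter schemas are sent in a follow-up query restricted to the tools it picked. The stubs below (TOOLS registry, fake_llm) are stand-ins for AstrBot's ToolSet and provider, and the get_weather tool is hypothetical; only the staging order mirrors ToolLoopAgentRunner.

import asyncio
import copy

# Hypothetical tool registry; in AstrBot this role is played by ToolSet.
TOOLS = {
    "get_weather": {
        "description": "Look up the current weather for a city.",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    },
}

def light_schema(tools: dict) -> list[dict]:
    # Stage 1 payload: names and descriptions only (cf. get_light_tool_set()).
    return [{"name": n, "description": t["description"]} for n, t in tools.items()]

def param_schema(tools: dict, names: list[str]) -> list[dict]:
    # Stage 2 payload: parameter schemas only, limited to the chosen tools
    # (cf. get_param_only_tool_set() plus _build_tool_subset()).
    return [
        {"name": n, "parameters": copy.deepcopy(tools[n]["parameters"])}
        for n in names
        if n in tools
    ]

async def fake_llm(messages: list[dict], tool_schemas: list[dict]) -> dict:
    # Stand-in for the provider call: it "decides" on a tool first and only
    # fills in arguments once it has been shown a parameter schema.
    if any("parameters" in s for s in tool_schemas):
        return {"tool_calls": [{"name": "get_weather", "args": {"city": "Berlin"}}]}
    return {"tool_calls": [{"name": "get_weather"}]}

async def step(messages: list[dict]) -> dict:
    # Stage 1: advertise tools by name/description and let the model choose.
    first = await fake_llm(messages, light_schema(TOOLS))
    chosen = [c["name"] for c in first["tool_calls"]]
    if not chosen:
        return first
    # Stage 2: re-query with the parameter schemas of the chosen tools,
    # prefixing an instruction much like _build_tool_requery_context() does.
    requery = [
        {
            "role": "system",
            "content": "You have decided to call tool(s): "
            + ", ".join(chosen)
            + ". Now call the tool(s) with required arguments.",
        }
    ] + messages
    return await fake_llm(requery, param_schema(TOOLS, chosen))

print(asyncio.run(step([{"role": "user", "content": "Weather in Berlin?"}])))

Under "full" mode the first request already carries the complete schemas, so the second query never happens and execution proceeds directly with the returned arguments.
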
76 changes: 56 additions & 20 deletions astrbot/core/agent/tool.py
@@ -1,3 +1,4 @@
import copy
from collections.abc import AsyncGenerator, Awaitable, Callable
from typing import Any, Generic

@@ -102,6 +103,47 @@ def get_tool(self, name: str) -> FunctionTool | None:
return tool
return None

def get_light_tool_set(self) -> "ToolSet":
"""Return a light tool set with only name/description."""
light_tools = []
for tool in self.tools:
if hasattr(tool, "active") and not tool.active:
continue
light_params = {
"type": "object",
"properties": {},
}
light_tools.append(
FunctionTool(
name=tool.name,
parameters=light_params,
description=tool.description,
handler=None,
)
)
return ToolSet(light_tools)

def get_param_only_tool_set(self) -> "ToolSet":
"""Return a tool set with name/parameters only (no description)."""
param_tools = []
for tool in self.tools:
if hasattr(tool, "active") and not tool.active:
continue
params = (
copy.deepcopy(tool.parameters)
if tool.parameters
else {"type": "object", "properties": {}}
)
param_tools.append(
FunctionTool(
name=tool.name,
parameters=params,
description="",
handler=None,
)
)
return ToolSet(param_tools)

@deprecated(reason="Use add_tool() instead", version="4.0.0")
def add_func(
self,
@@ -147,18 +189,15 @@ def openai_schema(self, omit_empty_parameter_field: bool = False) -> list[dict]:
"""Convert tools to OpenAI API function calling schema format."""
result = []
for tool in self.tools:
func_def = {
"type": "function",
"function": {
"name": tool.name,
"description": tool.description,
},
}
func_def = {"type": "function", "function": {"name": tool.name}}
if tool.description:
func_def["function"]["description"] = tool.description

if (
tool.parameters and tool.parameters.get("properties")
) or not omit_empty_parameter_field:
func_def["function"]["parameters"] = tool.parameters
if tool.parameters is not None:
if (
tool.parameters and tool.parameters.get("properties")
) or not omit_empty_parameter_field:
func_def["function"]["parameters"] = tool.parameters

result.append(func_def)
return result
@@ -171,11 +210,9 @@ def anthropic_schema(self) -> list[dict]:
if tool.parameters:
input_schema["properties"] = tool.parameters.get("properties", {})
input_schema["required"] = tool.parameters.get("required", [])
tool_def = {
"name": tool.name,
"description": tool.description,
"input_schema": input_schema,
}
tool_def = {"name": tool.name, "input_schema": input_schema}
if tool.description:
tool_def["description"] = tool.description
result.append(tool_def)
return result

Expand Down Expand Up @@ -245,10 +282,9 @@ def convert_schema(schema: dict) -> dict:

tools = []
for tool in self.tools:
d: dict[str, Any] = {
"name": tool.name,
"description": tool.description,
}
d: dict[str, Any] = {"name": tool.name}
if tool.description:
d["description"] = tool.description
if tool.parameters:
d["parameters"] = convert_schema(tool.parameters)
tools.append(d)
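
To make the new serialization behaviour concrete, here is a small sketch of what the two derived tool sets emit as OpenAI-style schemas. It assumes FunctionTool and ToolSet are importable from astrbot.core.agent.tool (the runner above imports ToolSet from there), and the weather tool is hypothetical.

from astrbot.core.agent.tool import FunctionTool, ToolSet

# Hypothetical tool; the constructor keywords mirror those used by
# get_light_tool_set()/get_param_only_tool_set() above.
weather = FunctionTool(
    name="get_weather",
    description="Look up the current weather for a city.",
    parameters={
        "type": "object",
        "properties": {"city": {"type": "string"}},
        "required": ["city"],
    },
    handler=None,
)
tools = ToolSet([weather])

# Stage 1: name + description only; with omit_empty_parameter_field=True the
# empty parameter object is dropped from the schema entirely.
print(tools.get_light_tool_set().openai_schema(omit_empty_parameter_field=True))
# -> [{"type": "function", "function": {"name": "get_weather",
#      "description": "Look up the current weather for a city."}}]

# Stage 2: parameters only; the empty description is omitted instead.
print(tools.get_param_only_tool_set().openai_schema())
# -> [{"type": "function", "function": {"name": "get_weather",
#      "parameters": {"type": "object", "properties": {"city": {"type": "string"}},
#                     "required": ["city"]}}}]

The anthropic_schema() and Gemini conversions follow the same rule: an empty description is simply left out of the tool definition.
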
14 changes: 14 additions & 0 deletions astrbot/core/config/default.py
@@ -106,6 +106,7 @@
"reachability_check": False,
"max_agent_step": 30,
"tool_call_timeout": 60,
"tool_schema_mode": "full",
"llm_safety_mode": True,
"safety_mode_strategy": "system_prompt", # TODO: llm judge
"file_extract": {
@@ -2188,6 +2189,9 @@ class ChatProviderTemplate(TypedDict):
"tool_call_timeout": {
"type": "int",
},
"tool_schema_mode": {
"type": "string",
},
"file_extract": {
"type": "object",
"items": {
@@ -2779,6 +2783,16 @@ class ChatProviderTemplate(TypedDict):
"provider_settings.agent_runner_type": "local",
},
},
"provider_settings.tool_schema_mode": {
"description": "工具调用模式",
"type": "string",
"options": ["skills_like", "full"],
"labels": ["Skills-like(两阶段)", "Full(完整参数)"],
"hint": "skills-like 先下发工具名称与描述,再下发参数;full 一次性下发完整参数。",
"condition": {
"provider_settings.agent_runner_type": "local",
},
},
"provider_settings.wake_prefix": {
"description": "LLM 聊天额外唤醒前缀 ",
"type": "string",
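
For reference, a sketch of how the new option might be set, assuming it lives under provider_settings alongside the existing agent options (only the tool_schema_mode key and its two accepted values come from this change; the surrounding keys are taken from the defaults above).

provider_settings = {
    "agent_runner_type": "local",       # the option is only surfaced for the local runner
    "max_agent_step": 30,
    "tool_call_timeout": 60,
    # "skills_like": two-stage delivery (names/descriptions first, parameters on demand)
    # "full": one-shot delivery of the complete schemas (the shipped default)
    "tool_schema_mode": "skills_like",
}
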
Original file line number Diff line number Diff line change
@@ -46,6 +46,7 @@
PYTHON_TOOL,
SANDBOX_MODE_PROMPT,
TOOL_CALL_PROMPT,
TOOL_CALL_PROMPT_FULL,
decoded_blocked,
retrieve_knowledge_base,
)
@@ -62,6 +63,13 @@ async def initialize(self, ctx: PipelineContext) -> None:
]
self.max_step: int = settings.get("max_agent_step", 30)
self.tool_call_timeout: int = settings.get("tool_call_timeout", 60)
self.tool_schema_mode: str = settings.get("tool_schema_mode", "skills_like")
if self.tool_schema_mode not in ("skills_like", "full"):
logger.warning(
"Unsupported tool_schema_mode: %s, fallback to skills_like",
self.tool_schema_mode,
)
self.tool_schema_mode = "skills_like"
if isinstance(self.max_step, bool): # workaround: #2622
self.max_step = 30
self.show_tool_use: bool = settings.get("show_tool_use_status", True)
@@ -672,7 +680,12 @@ async def process(

# Inject the base prompt
if req.func_tool and req.func_tool.tools:
req.system_prompt += f"\n{TOOL_CALL_PROMPT}\n"
tool_prompt = (
TOOL_CALL_PROMPT_FULL
if self.tool_schema_mode == "full"
else TOOL_CALL_PROMPT
)
req.system_prompt += f"\n{tool_prompt}\n"

action_type = event.get_extra("action_type")
if action_type == "live":
@@ -693,6 +706,7 @@
llm_compress_provider=self._get_compress_provider(),
truncate_turns=self.dequeue_context_length,
enforce_max_turns=self.max_context_length,
tool_schema_mode=self.tool_schema_mode,
)

# Detect Live Mode
17 changes: 14 additions & 3 deletions astrbot/core/pipeline/process_stage/utils.py
@@ -39,9 +39,20 @@

TOOL_CALL_PROMPT = (
"You MUST NOT return an empty response, especially after invoking a tool."
"Before calling any tool, provide a brief explanatory message to the user stating the purpose of the tool call."
"After the tool call is completed, you must briefly summarize the results returned by the tool for the user."
"Keep the role-play and style consistent throughout the conversation."
" Before calling any tool, provide a brief explanatory message to the user stating the purpose of the tool call."
" Tool schemas are provided in two stages: first only name and description; "
"if you decide to use a tool, the full parameter schema will be provided in "
"a follow-up step. Do not guess arguments before you see the schema."
" After the tool call is completed, you must briefly summarize the results returned by the tool for the user."
" Keep the role-play and style consistent throughout the conversation."
)

TOOL_CALL_PROMPT_FULL = (
"You MUST NOT return an empty response, especially after invoking a tool."
" Before calling any tool, provide a brief explanatory message to the user stating the purpose of the tool call."
" Use the provided tool schema to format arguments and do not guess parameters that are not defined."
" After the tool call is completed, you must briefly summarize the results returned by the tool for the user."
" Keep the role-play and style consistent throughout the conversation."
)

CHATUI_SPECIAL_DEFAULT_PERSONA_PROMPT = (