Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
481 changes: 419 additions & 62 deletions astrbot/builtin_stars/web_searcher/main.py

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion astrbot/core/astr_agent_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from astrbot.core.astr_agent_context import AstrAgentContext
from astrbot.core.pipeline.context_utils import call_event_hook
from astrbot.core.star.star_handler import EventType
from astrbot.core.utils.web_search_utils import WEB_SEARCH_REFERENCE_TOOLS


class MainAgentHooks(BaseAgentRunHooks[AstrAgentContext]):
Expand Down Expand Up @@ -59,7 +60,7 @@ async def on_tool_end(
platform_name = run_context.context.event.get_platform_name()
if (
platform_name == "webchat"
and tool.name in ["web_search_tavily", "web_search_bocha"]
and tool.name in WEB_SEARCH_REFERENCE_TOOLS
and len(run_context.messages) > 0
and tool_result
and len(tool_result.content)
Expand Down
39 changes: 38 additions & 1 deletion astrbot/core/config/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,11 @@
"web_search": False,
"websearch_provider": "default",
"websearch_tavily_key": [],
"websearch_tavily_base_url": "https://api.tavily.com",
"websearch_bocha_key": [],
"websearch_baidu_app_builder_key": "",
"websearch_exa_key": [],
"websearch_exa_base_url": "https://api.exa.ai",
"web_search_link": False,
"display_reasoning_text": False,
"identifier": False,
Expand Down Expand Up @@ -3084,7 +3087,13 @@ class ChatProviderTemplate(TypedDict):
"provider_settings.websearch_provider": {
"description": "网页搜索提供商",
"type": "string",
"options": ["default", "tavily", "baidu_ai_search", "bocha"],
"options": [
"default",
"tavily",
"baidu_ai_search",
"bocha",
"exa",
],
"condition": {
"provider_settings.web_search": True,
},
Expand Down Expand Up @@ -3117,6 +3126,34 @@ class ChatProviderTemplate(TypedDict):
"provider_settings.websearch_provider": "baidu_ai_search",
},
},
"provider_settings.websearch_tavily_base_url": {
"description": "Tavily API Base URL",
"type": "string",
"hint": "默认为 https://api.tavily.com,可改为代理地址。",
"condition": {
"provider_settings.websearch_provider": "tavily",
"provider_settings.web_search": True,
},
},
"provider_settings.websearch_exa_key": {
"description": "Exa API Key",
"type": "list",
"items": {"type": "string"},
"hint": "可添加多个 Key 进行轮询。",
"condition": {
"provider_settings.websearch_provider": "exa",
"provider_settings.web_search": True,
},
},
"provider_settings.websearch_exa_base_url": {
"description": "Exa API Base URL",
"type": "string",
"hint": "默认为 https://api.exa.ai,可改为代理地址。",
"condition": {
"provider_settings.websearch_provider": "exa",
"provider_settings.web_search": True,
},
},
"provider_settings.web_search_link": {
"description": "显示来源引用",
"type": "bool",
Expand Down
8 changes: 7 additions & 1 deletion astrbot/core/knowledge_base/kb_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,12 +518,18 @@ async def upload_from_url(
"Error: Tavily API key is not configured in provider_settings."
)

tavily_base_url = config.get("provider_settings", {}).get(
"websearch_tavily_base_url", "https://api.tavily.com"
)

# 阶段1: 从 URL 提取内容
if progress_callback:
await progress_callback("extracting", 0, 100)

try:
text_content = await extract_text_from_url(url, tavily_keys)
text_content = await extract_text_from_url(
url, tavily_keys, tavily_base_url
)
except Exception as e:
logger.error(f"Failed to extract content from URL {url}: {e}")
raise OSError(f"Failed to extract content from URL {url}: {e}") from e
Expand Down
26 changes: 21 additions & 5 deletions astrbot/core/knowledge_base/parsers/url_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,33 @@

import aiohttp

from astrbot.core.utils.web_search_utils import normalize_web_search_base_url


class URLExtractor:
"""URL 内容提取器,封装了 Tavily API 调用和密钥管理"""

def __init__(self, tavily_keys: list[str]) -> None:
def __init__(
self, tavily_keys: list[str], tavily_base_url: str = "https://api.tavily.com"
) -> None:
"""
初始化 URL 提取器

Args:
tavily_keys: Tavily API 密钥列表
tavily_base_url: Tavily API 基础 URL
"""
if not tavily_keys:
raise ValueError("Error: Tavily API keys are not configured.")

self.tavily_keys = tavily_keys
self.tavily_key_index = 0
self.tavily_key_lock = asyncio.Lock()
self.tavily_base_url = normalize_web_search_base_url(
tavily_base_url,
default="https://api.tavily.com",
provider_name="Tavily",
)

async def _get_tavily_key(self) -> str:
"""并发安全的从列表中获取并轮换Tavily API密钥。"""
Expand Down Expand Up @@ -47,7 +57,7 @@ async def extract_text_from_url(self, url: str) -> str:
raise ValueError("Error: url must be a non-empty string.")

tavily_key = await self._get_tavily_key()
api_url = "https://api.tavily.com/extract"
api_url = f"{self.tavily_base_url}/extract"
headers = {
"Authorization": f"Bearer {tavily_key}",
"Content-Type": "application/json",
Expand All @@ -69,7 +79,10 @@ async def extract_text_from_url(self, url: str) -> str:
if response.status != 200:
reason = await response.text()
raise OSError(
f"Tavily web extraction failed: {reason}, status: {response.status}"
f"Tavily web extraction failed for URL {api_url}: "
f"{reason}, status: {response.status}. If you configured "
"a Tavily API Base URL, make sure it is a base URL or "
"proxy prefix rather than a specific endpoint path."
)

data = await response.json()
Expand All @@ -88,16 +101,19 @@ async def extract_text_from_url(self, url: str) -> str:


# 为了向后兼容,提供一个简单的函数接口
async def extract_text_from_url(url: str, tavily_keys: list[str]) -> str:
async def extract_text_from_url(
url: str, tavily_keys: list[str], tavily_base_url: str = "https://api.tavily.com"
) -> str:
"""
简单的函数接口,用于从 URL 提取文本内容

Args:
url: 要提取内容的网页 URL
tavily_keys: Tavily API 密钥列表
tavily_base_url: Tavily API 基础 URL

Returns:
提取的文本内容
"""
extractor = URLExtractor(tavily_keys)
extractor = URLExtractor(tavily_keys, tavily_base_url)
return await extractor.extract_text_from_url(url)
131 changes: 131 additions & 0 deletions astrbot/core/utils/web_search_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
import json
import re
from typing import Any
from urllib.parse import urlparse

# Names of web-search tools whose JSON results carry citable reference
# entries; consumed by the ref-collection helpers below (and by the agent
# hooks that decide when to surface source links in the web chat).
WEB_SEARCH_REFERENCE_TOOLS = (
    "web_search_tavily",
    "web_search_bocha",
    "web_search_exa",
    "exa_find_similar",
)


def normalize_web_search_base_url(
base_url: str | None,
*,
default: str,
provider_name: str,
) -> str:
normalized = (base_url or "").strip()
if not normalized:
normalized = default
normalized = normalized.rstrip("/")

parsed = urlparse(normalized)
if parsed.scheme not in {"http", "https"} or not parsed.netloc:
raise ValueError(
f"Error: {provider_name} API Base URL must start with http:// or "
f"https://. Proxy base paths are allowed. Received: {normalized!r}.",
)
return normalized


def _iter_web_search_result_items(
    accumulated_parts: list[dict[str, Any]],
):
    """Yield raw result dicts from web-search tool calls in *accumulated_parts*.

    Only parts of type ``tool_call`` are inspected; within them, only calls
    whose name is in ``WEB_SEARCH_REFERENCE_TOOLS`` and that carry a truthy
    result. String results are parsed as JSON (unparsable ones are skipped),
    and only dict-shaped items under the ``results`` key are yielded.
    """
    tool_call_parts = (
        part
        for part in accumulated_parts
        if part.get("type") == "tool_call" and part.get("tool_calls")
    )
    for part in tool_call_parts:
        for call in part["tool_calls"]:
            if call.get("name") not in WEB_SEARCH_REFERENCE_TOOLS:
                continue
            raw = call.get("result")
            if not raw:
                continue
            if isinstance(raw, str):
                try:
                    parsed = json.loads(raw)
                except json.JSONDecodeError:
                    # Malformed tool output: skip rather than fail the stream.
                    continue
            else:
                parsed = raw
            if not isinstance(parsed, dict):
                continue
            for entry in parsed.get("results", []):
                if isinstance(entry, dict):
                    yield entry


def _extract_ref_indices(accumulated_text: str) -> list[str]:
ref_indices: list[str] = []
seen_indices: set[str] = set()

for match in re.finditer(r"<ref>(.*?)</ref>", accumulated_text):
ref_index = match.group(1).strip()
if not ref_index or ref_index in seen_indices:
continue
ref_indices.append(ref_index)
seen_indices.add(ref_index)

return ref_indices


def collect_web_search_ref_items(
    accumulated_parts: list[dict[str, Any]],
    favicon_cache: dict[str, str] | None = None,
) -> list[dict[str, Any]]:
    """Build deduplicated reference payloads from web-search tool results.

    Args:
        accumulated_parts: Streamed message parts that may contain tool calls.
        favicon_cache: Optional URL-to-favicon mapping; when a result's URL is
            cached, a ``favicon`` field is attached to its payload.

    Returns:
        Reference dicts (``index``/``url``/``title``/``snippet`` and optional
        ``favicon``) in encounter order, one per unique truthy ``index``.
    """
    refs: list[dict[str, Any]] = []
    taken: set[str] = set()

    for result in _iter_web_search_result_items(accumulated_parts):
        index = result.get("index")
        if not index or index in taken:
            continue
        taken.add(index)

        entry = {
            "index": index,
            "url": result.get("url"),
            "title": result.get("title"),
            "snippet": result.get("snippet"),
        }
        if favicon_cache and entry["url"] in favicon_cache:
            entry["favicon"] = favicon_cache[entry["url"]]
        refs.append(entry)

    return refs


def build_web_search_refs(
    accumulated_text: str,
    accumulated_parts: list[dict[str, Any]],
    favicon_cache: dict[str, str] | None = None,
) -> dict:
    """Assemble the reference payload for a streamed web-search answer.

    References actually cited in *accumulated_text* via ``<ref>`` markers are
    returned (in citation order); when the text cites none of the collected
    references, every collected reference is returned instead.

    Returns:
        ``{"used": [...]}`` with the selected references, or ``{}`` when no
        references were collected at all.
    """
    collected = collect_web_search_ref_items(accumulated_parts, favicon_cache)
    if not collected:
        return {}

    by_index = {ref["index"]: ref for ref in collected}
    cited = [
        by_index[idx]
        for idx in _extract_ref_indices(accumulated_text)
        if idx in by_index
    ]
    # Fall back to everything collected when the text cites nothing known.
    return {"used": cited or collected}


def collect_web_search_results(accumulated_parts: list[dict[str, Any]]) -> dict:
    """Map each unique reference index to its url/title/snippet payload."""
    # Indices are already deduplicated by collect_web_search_ref_items, so a
    # plain dict comprehension preserves one entry per index.
    return {
        ref["index"]: {
            "url": ref.get("url"),
            "title": ref.get("title"),
            "snippet": ref.get("snippet"),
        }
        for ref in collect_web_search_ref_items(accumulated_parts)
    }
Loading
Loading