Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 42 additions & 1 deletion astrbot/builtin_stars/astrbot/process_llm_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,11 @@
from astrbot.core.agent.message import TextPart
from astrbot.core.pipeline.process_stage.utils import (
CHATUI_SPECIAL_DEFAULT_PERSONA_PROMPT,
LOCAL_EXECUTE_SHELL_TOOL,
LOCAL_PYTHON_TOOL,
)
from astrbot.core.provider.func_tool_manager import ToolSet
from astrbot.core.skills.skill_manager import SkillManager, build_skills_prompt


class ProcessLLMRequest:
Expand All @@ -25,6 +28,15 @@ def __init__(self, context: star.Context):
else:
logger.info(f"Timezone set to: {self.timezone}")

self.skill_manager = SkillManager()

def _apply_local_env_tools(self, req: ProviderRequest) -> None:
"""Add local environment tools to the provider request."""
if req.func_tool is None:
req.func_tool = ToolSet()
req.func_tool.add_tool(LOCAL_EXECUTE_SHELL_TOOL)
req.func_tool.add_tool(LOCAL_PYTHON_TOOL)

async def _ensure_persona(
self, req: ProviderRequest, cfg: dict, umo: str, platform_type: str
):
Expand Down Expand Up @@ -66,6 +78,30 @@ async def _ensure_persona(
if begin_dialogs := copy.deepcopy(persona["_begin_dialogs_processed"]):
req.contexts[:0] = begin_dialogs

# skills select and prompt
runtime = self.skills_cfg.get("runtime", "local")
skills = self.skill_manager.list_skills(active_only=True, runtime=runtime)
if runtime == "sandbox" and not self.sandbox_cfg.get("enable", False):
logger.warning(
"Skills runtime is set to sandbox, but sandbox mode is disabled, will skip skills prompt injection.",
)
req.system_prompt += "\n[Background: User added some skills, and skills runtime is set to sandbox, but sandbox mode is disabled. So skills will be unavailable.]\n"
elif skills:
# persona.skills == None means all skills are allowed
if persona and persona.get("skills") is not None:
if not persona["skills"]:
return
allowed = set(persona["skills"])
skills = [skill for skill in skills if skill.name in allowed]
if skills:
req.system_prompt += f"\n{build_skills_prompt(skills)}\n"

# if user wants to use skills in non-sandbox mode, apply local env tools
runtime = self.skills_cfg.get("runtime", "local")
sandbox_enabled = self.sandbox_cfg.get("enable", False)
if runtime == "local" and not sandbox_enabled:
self._apply_local_env_tools(req)

# tools select
tmgr = self.ctx.get_llm_tool_manager()
if (persona and persona.get("tools") is None) or not persona:
Expand All @@ -81,7 +117,10 @@ async def _ensure_persona(
tool = tmgr.get_func(tool_name)
if tool and tool.active:
toolset.add_tool(tool)
req.func_tool = toolset
if not req.func_tool:
req.func_tool = toolset
else:
req.func_tool.merge(toolset)
logger.debug(f"Tool set for persona {persona_id}: {toolset.names()}")

async def _ensure_img_caption(
Expand Down Expand Up @@ -134,6 +173,8 @@ async def process_llm_request(self, event: AstrMessageEvent, req: ProviderReques
cfg: dict = self.ctx.get_config(umo=event.unified_msg_origin)[
"provider_settings"
]
self.skills_cfg = cfg.get("skills", {})
self.sandbox_cfg = cfg.get("sandbox", {})

# prompt prefix
if prefix := cfg.get("prompt_prefix"):
Expand Down
5 changes: 5 additions & 0 deletions astrbot/core/agent/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,11 @@ def names(self) -> list[str]:
"""获取所有工具的名称列表"""
return [tool.name for tool in self.tools]

def merge(self, other: "ToolSet"):
"""Merge another ToolSet into this one."""
for tool in other.tools:
self.add_tool(tool)

def __len__(self):
return len(self.tools)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from ..olayer import FileSystemComponent, PythonComponent, ShellComponent


class SandboxBooter:
class ComputerBooter:
@property
def fs(self) -> FileSystemComponent: ...

Expand All @@ -16,16 +16,16 @@ async def boot(self, session_id: str) -> None: ...
async def shutdown(self) -> None: ...

async def upload_file(self, path: str, file_name: str) -> dict:
"""Upload file to sandbox.
"""Upload file to the computer.

Should return a dict with `success` (bool) and `file_path` (str) keys.
"""
...

async def download_file(self, remote_path: str, local_path: str):
"""Download file from sandbox."""
"""Download file from the computer."""
...

async def available(self) -> bool:
"""Check if the sandbox is available."""
"""Check if the computer is available."""
...
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from astrbot.api import logger

from ..olayer import FileSystemComponent, PythonComponent, ShellComponent
from .base import SandboxBooter
from .base import ComputerBooter


class MockShipyardSandboxClient:
Expand Down Expand Up @@ -124,7 +124,7 @@ async def wait_healthy(self, ship_id: str, session_id: str) -> None:
loop -= 1


class BoxliteBooter(SandboxBooter):
class BoxliteBooter(ComputerBooter):
async def boot(self, session_id: str) -> None:
logger.info(
f"Booting(Boxlite) for session: {session_id}, this may take a while..."
Expand Down
234 changes: 234 additions & 0 deletions astrbot/core/computer/booters/local.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,234 @@
from __future__ import annotations

import asyncio
import os
import shutil
import subprocess
import sys
from dataclasses import dataclass
from typing import Any

from astrbot.api import logger
from astrbot.core.utils.astrbot_path import (
get_astrbot_data_path,
get_astrbot_root,
get_astrbot_temp_path,
)

from ..olayer import FileSystemComponent, PythonComponent, ShellComponent
from .base import ComputerBooter

_BLOCKED_COMMAND_PATTERNS = [
" rm -rf ",
" rm -fr ",
" rm -r ",
" mkfs",
" dd if=",
" shutdown",
" reboot",
" poweroff",
" halt",
" sudo ",
":(){:|:&};:",
" kill -9 ",
" killall ",
]


def _is_safe_command(command: str) -> bool:
cmd = f" {command.strip().lower()} "
return not any(pat in cmd for pat in _BLOCKED_COMMAND_PATTERNS)


def _ensure_safe_path(path: str) -> str:
abs_path = os.path.abspath(path)
allowed_roots = [
os.path.abspath(get_astrbot_root()),
os.path.abspath(get_astrbot_data_path()),
os.path.abspath(get_astrbot_temp_path()),
]
if not any(abs_path.startswith(root) for root in allowed_roots):
raise PermissionError("Path is outside the allowed computer roots.")
return abs_path


@dataclass
class LocalShellComponent(ShellComponent):
async def exec(
self,
command: str,
cwd: str | None = None,
env: dict[str, str] | None = None,
timeout: int | None = 30,
shell: bool = True,
background: bool = False,
) -> dict[str, Any]:
if not _is_safe_command(command):
raise PermissionError("Blocked unsafe shell command.")

def _run() -> dict[str, Any]:
run_env = os.environ.copy()
if env:
run_env.update({str(k): str(v) for k, v in env.items()})
working_dir = _ensure_safe_path(cwd) if cwd else get_astrbot_root()
if background:
proc = subprocess.Popen(
command,
shell=shell,
cwd=working_dir,
env=run_env,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
)
return {"pid": proc.pid, "stdout": "", "stderr": "", "exit_code": None}
result = subprocess.run(
command,
shell=shell,
cwd=working_dir,
env=run_env,
timeout=timeout,
capture_output=True,
text=True,
)
return {
"stdout": result.stdout,
"stderr": result.stderr,
"exit_code": result.returncode,
}

return await asyncio.to_thread(_run)


@dataclass
class LocalPythonComponent(PythonComponent):
async def exec(
self,
code: str,
kernel_id: str | None = None,
timeout: int = 30,
silent: bool = False,
) -> dict[str, Any]:
def _run() -> dict[str, Any]:
try:
result = subprocess.run(
[os.environ.get("PYTHON", sys.executable), "-c", code],
timeout=timeout,
capture_output=True,
text=True,
)
stdout = "" if silent else result.stdout
stderr = result.stderr if result.returncode != 0 else ""
return {
"data": {
"output": {"text": stdout, "images": []},
"error": stderr,
}
}
except subprocess.TimeoutExpired:
return {
"data": {
"output": {"text": "", "images": []},
"error": "Execution timed out.",
}
}

return await asyncio.to_thread(_run)


@dataclass
class LocalFileSystemComponent(FileSystemComponent):
async def create_file(
self, path: str, content: str = "", mode: int = 0o644
) -> dict[str, Any]:
def _run() -> dict[str, Any]:
abs_path = _ensure_safe_path(path)
os.makedirs(os.path.dirname(abs_path), exist_ok=True)
with open(abs_path, "w", encoding="utf-8") as f:
f.write(content)
os.chmod(abs_path, mode)
return {"success": True, "path": abs_path}

return await asyncio.to_thread(_run)

async def read_file(self, path: str, encoding: str = "utf-8") -> dict[str, Any]:
def _run() -> dict[str, Any]:
abs_path = _ensure_safe_path(path)
with open(abs_path, encoding=encoding) as f:
content = f.read()
return {"success": True, "content": content}

return await asyncio.to_thread(_run)

async def write_file(
self, path: str, content: str, mode: str = "w", encoding: str = "utf-8"
) -> dict[str, Any]:
def _run() -> dict[str, Any]:
abs_path = _ensure_safe_path(path)
os.makedirs(os.path.dirname(abs_path), exist_ok=True)
with open(abs_path, mode, encoding=encoding) as f:
f.write(content)
return {"success": True, "path": abs_path}

return await asyncio.to_thread(_run)

async def delete_file(self, path: str) -> dict[str, Any]:
def _run() -> dict[str, Any]:
abs_path = _ensure_safe_path(path)
if os.path.isdir(abs_path):
shutil.rmtree(abs_path)
else:
os.remove(abs_path)
return {"success": True, "path": abs_path}

return await asyncio.to_thread(_run)

async def list_dir(
self, path: str = ".", show_hidden: bool = False
) -> dict[str, Any]:
def _run() -> dict[str, Any]:
abs_path = _ensure_safe_path(path)
entries = os.listdir(abs_path)
if not show_hidden:
entries = [e for e in entries if not e.startswith(".")]
return {"success": True, "entries": entries}

return await asyncio.to_thread(_run)


class LocalBooter(ComputerBooter):
def __init__(self) -> None:
self._fs = LocalFileSystemComponent()
self._python = LocalPythonComponent()
self._shell = LocalShellComponent()

async def boot(self, session_id: str) -> None:
logger.info(f"Local computer booter initialized for session: {session_id}")

async def shutdown(self) -> None:
logger.info("Local computer booter shutdown complete.")

@property
def fs(self) -> FileSystemComponent:
return self._fs

@property
def python(self) -> PythonComponent:
return self._python

@property
def shell(self) -> ShellComponent:
return self._shell

async def upload_file(self, path: str, file_name: str) -> dict:
raise NotImplementedError(
"LocalBooter does not support upload_file operation. Use shell instead."
)

async def download_file(self, remote_path: str, local_path: str):
raise NotImplementedError(
"LocalBooter does not support download_file operation. Use shell instead."
)

async def available(self) -> bool:
return True
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
from astrbot.api import logger

from ..olayer import FileSystemComponent, PythonComponent, ShellComponent
from .base import SandboxBooter
from .base import ComputerBooter


class ShipyardBooter(SandboxBooter):
class ShipyardBooter(ComputerBooter):
def __init__(
self,
endpoint_url: str,
Expand Down
Loading