Skip to content

Commit cbb827f

Browse files
committed
feat: MCP Server
1 parent a4510b6 commit cbb827f

5 files changed

Lines changed: 98 additions & 19 deletions

File tree

docs/tutorials/agent/agent.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ assistant.add_too_functions(get_weather)
7575
当然,在函数定义时就可以声明其为一个工具函数:
7676

7777
```python
78-
@assistant.tool
78+
@assistant.tool()
7979
def get_weather(location):
8080
...
8181
```

src/course_graph/agent/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,4 @@
88
from .agent import Agent
99
from .controller import Controller
1010
from .types import Result, ContextVariables, TraceEvent
11+
from .mcp import MCPServer

src/course_graph/agent/agent.py

Lines changed: 43 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,34 +13,44 @@
1313
from typing import Callable
1414
from typing import Literal
1515
from openai import NOT_GIVEN, NotGiven
16+
from .mcp import MCPServer
17+
import asyncio
1618

1719

1820
class Agent:
1921

2022
def __init__(
2123
self,
22-
name: str,
2324
llm: LLMBase,
25+
name: str = 'Assistant',
2426
functions: list[Callable] = None,
2527
tool_choice: str | NotGiven | Literal['required', 'auto', 'none'] = NOT_GIVEN,
2628
parallel_tool_calls: bool | NotGiven = NOT_GIVEN,
27-
instruction: str | Callable[[ContextVariables], str] | Callable[[], str] = 'You are a helpful assistant.'
29+
instruction: str | Callable[[ContextVariables], str] | Callable[[], str] = 'You are a helpful assistant.',
30+
mcp_server: list[MCPServer] = None,
31+
mcp_impl: Literal['function_call'] = 'function_call'
2832
) -> None:
2933
""" 智能体类
3034
31-
Args: name (str): 名称
32-
llm (LLMBase): 大模型
33-
functions: (list[Callable], optional): 工具函数. Defaults to None.
34-
parallel_tool_calls: (bool, optional): 允许工具并行调用. Defaults to False.
35-
tool_choice: (Literal['required', 'auto', 'none'] | NotGiven, optional). 强制使用工具函数, 选择模式或提供函数名称. Defaults to NOT_GIVEN.
36-
instruction (str | Callable[[ContextVariables], str] | Callable[[], str], optional): 指令. Defaults to 'You are a helpful assistant.'.
35+
Args:
36+
llm (LLMBase): 大模型
37+
name (str, optional): 名称. Defaults to 'Assistant'.
38+
functions: (list[Callable], optional): 工具函数. Defaults to None.
39+
parallel_tool_calls: (bool, optional): 允许工具并行调用. Defaults to False.
40+
tool_choice: (Literal['required', 'auto', 'none'] | NotGiven, optional). 强制使用工具函数, 选择模式或提供函数名称. Defaults to NOT_GIVEN.
41+
instruction (str | Callable[[ContextVariables], str] | Callable[[], str], optional): 指令. Defaults to 'You are a helpful assistant.'.
42+
mcp_server: (list[MCPServer], optional): MCP 服务器. Defaults to None.
43+
mcp_impl: (Literal['function_call'] | NotGiven, optional): MCP 协议实现方式, 目前只支持 'function_call'. Defaults to 'function_call'.
3744
"""
38-
self.name = name
3945
self.llm = llm
46+
self.name = name
4047
self.instruction = instruction
4148

42-
self.tools: list[ChatCompletionToolParam] = []
43-
self.tool_functions: dict[str, Callable] = {}
49+
self.tools: list[ChatCompletionToolParam] = [] # for LLM
50+
51+
self.tool_functions: dict[str, Callable] = {} # for local function call
52+
self.mcp_functions: dict[str, MCPServer] = {} # for remote function call
53+
4454
self.parallel_tool_calls = parallel_tool_calls
4555
self.use_context_variables: dict[str, str] = {} # 使用了上下文变量的函数以及相应的形参名称
4656
self.use_agent_variables: dict[str, str] = {} # 使用了Agent变量的函数以及相应的形参名称
@@ -59,6 +69,23 @@ def __init__(
5969
self.tool_choice = tool_choice
6070

6171
self.messages: list[ChatCompletionMessageParam] = []
72+
self.mcp_server = mcp_server
73+
74+
async def initialize(self):
75+
""" 等待初始化智能体
76+
"""
77+
for server in self.mcp_server:
78+
tools = await server.list_tools()
79+
for tool in tools:
80+
self.tools.append({
81+
'type': 'function',
82+
'function': {
83+
'name': tool.name,
84+
'description': tool.description,
85+
'parameters': tool.inputSchema
86+
}
87+
}) # 注意不能使用 add_tools 方法
88+
self.mcp_functions[tool.name] = server
6289

6390
def chat(self, message: str = None) -> ChatCompletionMessage:
6491
""" Agent 多轮对话
@@ -123,11 +150,13 @@ def add_tool_call_message(self, tool_content: str, tool_call_id: str) -> None:
123150
}
124151
self.messages.append(message)
125152

126-
def tool(self, function: Callable) -> Callable:
153+
def tool(self) -> Callable:
127154
""" 标记一个外部工具函数
128155
"""
129-
self.add_tool_functions(function)
130-
return function
156+
def wrapper(function: Callable) -> Callable:
157+
self.add_tool_functions(function)
158+
return function
159+
return wrapper
131160

132161
def add_tools(self, *tools: 'Tool') -> 'Agent':
133162
""" 添加外部工具

src/course_graph/agent/mcp.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
# -*- coding: utf-8 -*-
2+
# Create Date: 2025/03/28
3+
# Author: wangtao <wangtao.cpu@gmail.com>
4+
# File Name: course_graph/agent/mcp.py
5+
# Description: MCP Server 和 Client 实现相关
6+
7+
from typing import Literal
8+
from mcp import ClientSession, StdioServerParameters
9+
from mcp.client.stdio import stdio_client
10+
from contextlib import AsyncExitStack
11+
from mcp.types import Tool
12+
13+
class MCPServer:
14+
def __init__(self, type: Literal['stdio'], command: str, args: list[str], envs: dict[str, str] = None):
15+
""" MCP 服务器
16+
17+
Args:
18+
type (Literal['stdio']): 服务器类型, 目前只支持 'stdio'
19+
command (str): 命令
20+
args (list[str]): 参数
21+
envs (dict[str, str], optional): 环境变量. Defaults to None.
22+
"""
23+
self.params = StdioServerParameters(command=command, args=args, envs=envs)
24+
self.stack = AsyncExitStack()
25+
self.session = None
26+
27+
async def __aenter__(self):
28+
stdio_transport = await self.stack.enter_async_context(stdio_client(self.params))
29+
self.stdio, self.write = stdio_transport
30+
self.session = await self.stack.enter_async_context(ClientSession(self.stdio, self.write))
31+
await self.session.initialize()
32+
return self
33+
34+
async def list_tools(self) -> list[Tool]:
35+
return (await self.session.list_tools()).tools
36+
37+
async def __aexit__(self, exc_type, exc_value, traceback):
38+
await self.stack.aclose()
39+
40+
41+
42+
43+

src/course_graph/agent/types.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,49 +50,55 @@ class Result:
5050
content: str = 'Function call successfully.'
5151
context_variables: ContextVariables | dict = ContextVariables()
5252
message: bool = True
53-
53+
5454
def __repr__(self):
5555
return str({
5656
'agent': self.agent.name,
5757
'content': self.content,
5858
'context_variables': self.context_variables,
5959
'message': self.message
6060
})
61-
61+
6262

6363
@dataclass
6464
class TraceEvent:
6565
timestamp: float
6666
agent_name: str
6767

68+
6869
@dataclass
6970
class TraceEventUserMessage(TraceEvent):
7071
message: str
7172

73+
7274
@dataclass
7375
class TraceEventAgentMessage(TraceEvent):
7476
message: str
7577

78+
7679
@dataclass
7780
class TraceEventAgentSwitch(TraceEvent):
7881
to_agent: str
79-
82+
83+
8084
@dataclass
8185
class TraceEventToolCall(TraceEvent):
8286
function: str
8387
arguments: Dict[str, Any]
8488

89+
8590
@dataclass
8691
class TraceEventToolResult(TraceEvent):
8792
function: str
8893
result: Any
8994

95+
9096
@dataclass
9197
class TraceEventContextUpdate(TraceEvent):
9298
old_context: ContextVariables
9399
new_context: ContextVariables
94100

95-
101+
96102
class Trace(TypedDict):
97103
trace_id: str
98104
events: List[TraceEvent]

0 commit comments

Comments
 (0)