|
23 | 23 | ) |
24 | 24 | from agent_framework._mcp import ( |
25 | 25 | MCPTool, |
| 26 | + _classify_mcp_tool_failure, |
26 | 27 | _get_input_model_from_mcp_prompt, |
27 | 28 | _normalize_mcp_name, |
28 | 29 | _parse_content_from_mcp, |
@@ -53,6 +54,12 @@ def test_normalize_mcp_name(): |
53 | 54 | assert _normalize_mcp_name("name/with\\slashes") == "name-with-slashes" |
54 | 55 |
|
55 | 56 |
|
| 57 | +def test_classify_mcp_tool_failure(): |
| 58 | + assert _classify_mcp_tool_failure("Tool not found on remote server") == "mcp_tool_missing" |
| 59 | + assert _classify_mcp_tool_failure("Invalid params for schema validation") == "mcp_tool_schema_mismatch" |
| 60 | + assert _classify_mcp_tool_failure("transport closed") is None |
| 61 | + |
| 62 | + |
56 | 63 | def test_mcp_transport_subclasses_accept_tool_name_prefix() -> None: |
57 | 64 | assert MCPStdioTool(name="stdio", command="python", tool_name_prefix="stdio").tool_name_prefix == "stdio" |
58 | 65 | assert ( |
@@ -1032,6 +1039,49 @@ def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]: |
1032 | 1039 | await func.invoke(param="test_value") |
1033 | 1040 |
|
1034 | 1041 |
|
| 1042 | +async def test_local_mcp_server_schema_drift_error_is_classified_and_refreshes(): |
| 1043 | + """Schema drift should fail closed with a stable marker and trigger a tool refresh.""" |
| 1044 | + |
| 1045 | + class TestServer(MCPTool): |
| 1046 | + async def connect(self): |
| 1047 | + self.session = Mock(spec=ClientSession) |
| 1048 | + self.session.list_tools = AsyncMock( |
| 1049 | + return_value=types.ListToolsResult( |
| 1050 | + tools=[ |
| 1051 | + types.Tool( |
| 1052 | + name="test_tool", |
| 1053 | + description="Test tool", |
| 1054 | + inputSchema={ |
| 1055 | + "type": "object", |
| 1056 | + "properties": {"param": {"type": "string"}}, |
| 1057 | + "required": ["param"], |
| 1058 | + }, |
| 1059 | + ) |
| 1060 | + ] |
| 1061 | + ) |
| 1062 | + ) |
| 1063 | + self.session.call_tool = AsyncMock( |
| 1064 | + side_effect=McpError(types.ErrorData(code=-32602, message="Invalid params: schema changed")) |
| 1065 | + ) |
| 1066 | + |
| 1067 | + def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]: |
| 1068 | + return None |
| 1069 | + |
| 1070 | + server = TestServer(name="test_server") |
| 1071 | + async with server: |
| 1072 | + await server.load_tools() |
| 1073 | + func = server.functions[0] |
| 1074 | + |
| 1075 | + with ( |
| 1076 | + patch.object(server, "connect", new_callable=AsyncMock) as mock_connect, |
| 1077 | + pytest.raises(ToolExecutionException, match=r"\[mcp_tool_schema_mismatch\]") as exc_info, |
| 1078 | + ): |
| 1079 | + await func.invoke(param="test_value") |
| 1080 | + |
| 1081 | + mock_connect.assert_awaited_once_with(reset=True) |
| 1082 | + assert "schema changed" in str(exc_info.value) |
| 1083 | + |
| 1084 | + |
1035 | 1085 | async def test_mcp_tool_call_tool_raises_on_is_error(): |
1036 | 1086 | """Test that call_tool raises ToolExecutionException when MCP returns isError=True.""" |
1037 | 1087 |
|
|
0 commit comments