Skip to content

Commit 0e40e12

Browse files
committed
mcp: fail closed on remote schema drift (#4723)
1 parent 1b7940c commit 0e40e12

4 files changed

Lines changed: 140 additions & 4 deletions

File tree

dotnet/src/Microsoft.Agents.AI.Workflows.Declarative.Mcp/DefaultMcpToolHandler.cs

Lines changed: 41 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -61,10 +61,18 @@ public async Task<McpServerToolResultContent> InvokeToolAsync(
6161
? null
6262
: arguments as IReadOnlyDictionary<string, object?> ?? new Dictionary<string, object?>(arguments);
6363

64-
CallToolResult result = await client.CallToolAsync(
65-
toolName,
66-
readOnlyArguments,
67-
cancellationToken: cancellationToken).ConfigureAwait(false);
64+
CallToolResult result;
65+
try
66+
{
67+
result = await client.CallToolAsync(
68+
toolName,
69+
readOnlyArguments,
70+
cancellationToken: cancellationToken).ConfigureAwait(false);
71+
}
72+
catch (Exception ex) when (TryClassifyToolInvocationFailure(ex.Message, out string? failureCode))
73+
{
74+
throw new InvalidOperationException($"[{failureCode}] {ex.Message}", ex);
75+
}
6876

6977
// Map MCP content blocks to MEAI AIContent types
7078
PopulateResultContent(resultContent, result);
@@ -183,6 +191,35 @@ private static string ComputeHeadersHash(IDictionary<string, string>? headers)
183191
return hashCode.ToString(CultureInfo.InvariantCulture);
184192
}
185193

194+
internal static bool TryClassifyToolInvocationFailure(string? message, out string? failureCode)
195+
{
196+
if (string.IsNullOrWhiteSpace(message))
197+
{
198+
failureCode = null;
199+
return false;
200+
}
201+
202+
string normalized = message.ToLowerInvariant();
203+
if (normalized.Contains("tool not found", StringComparison.Ordinal) ||
204+
normalized.Contains("unknown tool", StringComparison.Ordinal) ||
205+
normalized.Contains("no tool named", StringComparison.Ordinal))
206+
{
207+
failureCode = "mcp_tool_missing";
208+
return true;
209+
}
210+
211+
if (normalized.Contains("invalid params", StringComparison.Ordinal) ||
212+
normalized.Contains("schema", StringComparison.Ordinal) ||
213+
normalized.Contains("validation", StringComparison.Ordinal))
214+
{
215+
failureCode = "mcp_tool_schema_mismatch";
216+
return true;
217+
}
218+
219+
failureCode = null;
220+
return false;
221+
}
222+
186223
private static void PopulateResultContent(McpServerToolResultContent resultContent, CallToolResult result)
187224
{
188225
// Ensure Output list is initialized

dotnet/tests/Microsoft.Agents.AI.Workflows.Declarative.Mcp.UnitTests/DefaultMcpToolHandlerTests.cs

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -488,5 +488,28 @@ public void ConvertContentBlock_AudioContentBlock_WithNullMimeType_ShouldDefault
488488
dataContent.MediaType.Should().Be("audio/*");
489489
}
490490

491+
[Theory]
492+
[InlineData("Tool not found on remote server", "mcp_tool_missing")]
493+
[InlineData("Invalid params: schema changed", "mcp_tool_schema_mismatch")]
494+
[InlineData("Request failed validation", "mcp_tool_schema_mismatch")]
495+
public void TryClassifyToolInvocationFailure_WithKnownSchemaOrToolMessages_ReturnsStableCode(
496+
string message,
497+
string expectedCode)
498+
{
499+
bool classified = DefaultMcpToolHandler.TryClassifyToolInvocationFailure(message, out string? code);
500+
501+
classified.Should().BeTrue();
502+
code.Should().Be(expectedCode);
503+
}
504+
505+
[Fact]
506+
public void TryClassifyToolInvocationFailure_WithUnrelatedMessage_ReturnsFalse()
507+
{
508+
bool classified = DefaultMcpToolHandler.TryClassifyToolInvocationFailure("Socket closed unexpectedly", out string? code);
509+
510+
classified.Should().BeFalse();
511+
code.Should().BeNull();
512+
}
513+
491514
#endregion
492515
}

python/packages/core/agent_framework/_mcp.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -286,6 +286,17 @@ def _parse_content_from_mcp(
286286
return return_types
287287

288288

289+
def _classify_mcp_tool_failure(message: str) -> str | None:
290+
lowered = message.lower()
291+
292+
if "tool not found" in lowered or "unknown tool" in lowered or "no tool named" in lowered:
293+
return "mcp_tool_missing"
294+
if "invalid params" in lowered or "schema" in lowered or "validation" in lowered:
295+
return "mcp_tool_schema_mismatch"
296+
297+
return None
298+
299+
289300
def _prepare_content_for_mcp(
290301
content: Content,
291302
) -> types.TextContent | types.ImageContent | types.AudioContent | types.EmbeddedResource | types.ResourceLink | None:
@@ -637,6 +648,9 @@ async def _connect_on_owner(self, *, reset: bool = False) -> None:
637648
self.session = None
638649
self.is_connected = False
639650
self._exit_stack = AsyncExitStack()
651+
self._functions = []
652+
self._tools_loaded = False
653+
self._prompts_loaded = False
640654
if not self.session:
641655
try:
642656
transport = await self._exit_stack.enter_async_context(self.get_mcp_client())
@@ -1054,6 +1068,18 @@ async def call_tool(self, tool_name: str, **kwargs: Any) -> str | list[Content]:
10541068
inner_exception=cl_ex,
10551069
) from cl_ex
10561070
except McpError as mcp_exc:
1071+
failure_code = _classify_mcp_tool_failure(mcp_exc.error.message)
1072+
if failure_code is not None:
1073+
try:
1074+
await self.connect(reset=True)
1075+
except Exception:
1076+
logger.debug(
1077+
"Failed to refresh MCP tool definitions after classified tool failure.", exc_info=True
1078+
)
1079+
raise ToolExecutionException(
1080+
f"[{failure_code}] {mcp_exc.error.message}",
1081+
inner_exception=mcp_exc,
1082+
) from mcp_exc
10571083
raise ToolExecutionException(mcp_exc.error.message, inner_exception=mcp_exc) from mcp_exc
10581084
except Exception as ex:
10591085
raise ToolExecutionException(f"Failed to call tool '{tool_name}'.", inner_exception=ex) from ex

python/packages/core/tests/core/test_mcp.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
)
2424
from agent_framework._mcp import (
2525
MCPTool,
26+
_classify_mcp_tool_failure,
2627
_get_input_model_from_mcp_prompt,
2728
_normalize_mcp_name,
2829
_parse_content_from_mcp,
@@ -53,6 +54,12 @@ def test_normalize_mcp_name():
5354
assert _normalize_mcp_name("name/with\\slashes") == "name-with-slashes"
5455

5556

57+
def test_classify_mcp_tool_failure():
58+
assert _classify_mcp_tool_failure("Tool not found on remote server") == "mcp_tool_missing"
59+
assert _classify_mcp_tool_failure("Invalid params for schema validation") == "mcp_tool_schema_mismatch"
60+
assert _classify_mcp_tool_failure("transport closed") is None
61+
62+
5663
def test_mcp_transport_subclasses_accept_tool_name_prefix() -> None:
5764
assert MCPStdioTool(name="stdio", command="python", tool_name_prefix="stdio").tool_name_prefix == "stdio"
5865
assert (
@@ -1032,6 +1039,49 @@ def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]:
10321039
await func.invoke(param="test_value")
10331040

10341041

1042+
async def test_local_mcp_server_schema_drift_error_is_classified_and_refreshes():
1043+
"""Schema drift should fail closed with a stable marker and trigger a tool refresh."""
1044+
1045+
class TestServer(MCPTool):
1046+
async def connect(self):
1047+
self.session = Mock(spec=ClientSession)
1048+
self.session.list_tools = AsyncMock(
1049+
return_value=types.ListToolsResult(
1050+
tools=[
1051+
types.Tool(
1052+
name="test_tool",
1053+
description="Test tool",
1054+
inputSchema={
1055+
"type": "object",
1056+
"properties": {"param": {"type": "string"}},
1057+
"required": ["param"],
1058+
},
1059+
)
1060+
]
1061+
)
1062+
)
1063+
self.session.call_tool = AsyncMock(
1064+
side_effect=McpError(types.ErrorData(code=-32602, message="Invalid params: schema changed"))
1065+
)
1066+
1067+
def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]:
1068+
return None
1069+
1070+
server = TestServer(name="test_server")
1071+
async with server:
1072+
await server.load_tools()
1073+
func = server.functions[0]
1074+
1075+
with (
1076+
patch.object(server, "connect", new_callable=AsyncMock) as mock_connect,
1077+
pytest.raises(ToolExecutionException, match=r"\[mcp_tool_schema_mismatch\]") as exc_info,
1078+
):
1079+
await func.invoke(param="test_value")
1080+
1081+
mock_connect.assert_awaited_once_with(reset=True)
1082+
assert "schema changed" in str(exc_info.value)
1083+
1084+
10351085
async def test_mcp_tool_call_tool_raises_on_is_error():
10361086
"""Test that call_tool raises ToolExecutionException when MCP returns isError=True."""
10371087

0 commit comments

Comments
 (0)