Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,18 @@ public async Task<McpServerToolResultContent> InvokeToolAsync(
? null
: arguments as IReadOnlyDictionary<string, object?> ?? new Dictionary<string, object?>(arguments);

CallToolResult result = await client.CallToolAsync(
toolName,
readOnlyArguments,
cancellationToken: cancellationToken).ConfigureAwait(false);
CallToolResult result;
try
{
result = await client.CallToolAsync(
toolName,
readOnlyArguments,
cancellationToken: cancellationToken).ConfigureAwait(false);
}
catch (Exception ex) when (TryClassifyToolInvocationFailure(ex.Message, out string? failureCode))
{
throw new InvalidOperationException($"[{failureCode}] {ex.Message}", ex);
}

// Map MCP content blocks to MEAI AIContent types
PopulateResultContent(resultContent, result);
Expand Down Expand Up @@ -183,6 +191,35 @@ private static string ComputeHeadersHash(IDictionary<string, string>? headers)
return hashCode.ToString(CultureInfo.InvariantCulture);
}

internal static bool TryClassifyToolInvocationFailure(string? message, out string? failureCode)
{
if (string.IsNullOrWhiteSpace(message))
{
failureCode = null;
return false;
}

string normalized = message.ToLowerInvariant();
if (normalized.Contains("tool not found", StringComparison.Ordinal) ||
normalized.Contains("unknown tool", StringComparison.Ordinal) ||
normalized.Contains("no tool named", StringComparison.Ordinal))
{
failureCode = "mcp_tool_missing";
return true;
}

if (normalized.Contains("invalid params", StringComparison.Ordinal) ||
normalized.Contains("schema", StringComparison.Ordinal) ||
normalized.Contains("validation", StringComparison.Ordinal))
{
failureCode = "mcp_tool_schema_mismatch";
return true;
}

failureCode = null;
return false;
}

private static void PopulateResultContent(McpServerToolResultContent resultContent, CallToolResult result)
{
// Ensure Output list is initialized
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -488,5 +488,28 @@ public void ConvertContentBlock_AudioContentBlock_WithNullMimeType_ShouldDefault
dataContent.MediaType.Should().Be("audio/*");
}

[Theory]
[InlineData("Tool not found on remote server", "mcp_tool_missing")]
[InlineData("Invalid params: schema changed", "mcp_tool_schema_mismatch")]
[InlineData("Request failed validation", "mcp_tool_schema_mismatch")]
public void TryClassifyToolInvocationFailure_WithKnownSchemaOrToolMessages_ReturnsStableCode(
string message,
string expectedCode)
{
bool classified = DefaultMcpToolHandler.TryClassifyToolInvocationFailure(message, out string? code);

classified.Should().BeTrue();
code.Should().Be(expectedCode);
}

[Fact]
public void TryClassifyToolInvocationFailure_WithUnrelatedMessage_ReturnsFalse()
{
bool classified = DefaultMcpToolHandler.TryClassifyToolInvocationFailure("Socket closed unexpectedly", out string? code);

classified.Should().BeFalse();
code.Should().BeNull();
}

#endregion
}
26 changes: 26 additions & 0 deletions python/packages/core/agent_framework/_mcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,17 @@ def _parse_content_from_mcp(
return return_types


def _classify_mcp_tool_failure(message: str) -> str | None:
lowered = message.lower()

if "tool not found" in lowered or "unknown tool" in lowered or "no tool named" in lowered:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this feels very brittle, and is likely to get out of date as the library evolves, is there no better way to understand what type of McpError is raised (subclass or code)?

return "mcp_tool_missing"
if "invalid params" in lowered or "schema" in lowered or "validation" in lowered:
return "mcp_tool_schema_mismatch"

return None


def _prepare_content_for_mcp(
content: Content,
) -> types.TextContent | types.ImageContent | types.AudioContent | types.EmbeddedResource | types.ResourceLink | None:
Expand Down Expand Up @@ -637,6 +648,9 @@ async def _connect_on_owner(self, *, reset: bool = False) -> None:
self.session = None
self.is_connected = False
self._exit_stack = AsyncExitStack()
self._functions = []
self._tools_loaded = False
self._prompts_loaded = False
if not self.session:
try:
transport = await self._exit_stack.enter_async_context(self.get_mcp_client())
Expand Down Expand Up @@ -1054,6 +1068,18 @@ async def call_tool(self, tool_name: str, **kwargs: Any) -> str | list[Content]:
inner_exception=cl_ex,
) from cl_ex
except McpError as mcp_exc:
failure_code = _classify_mcp_tool_failure(mcp_exc.error.message)
if failure_code is not None:
try:
await self.connect(reset=True)
except Exception:
logger.debug(
"Failed to refresh MCP tool definitions after classified tool failure.", exc_info=True
)
raise ToolExecutionException(
f"[{failure_code}] {mcp_exc.error.message}",
inner_exception=mcp_exc,
) from mcp_exc
raise ToolExecutionException(mcp_exc.error.message, inner_exception=mcp_exc) from mcp_exc
except Exception as ex:
raise ToolExecutionException(f"Failed to call tool '{tool_name}'.", inner_exception=ex) from ex
Expand Down
50 changes: 50 additions & 0 deletions python/packages/core/tests/core/test_mcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
)
from agent_framework._mcp import (
MCPTool,
_classify_mcp_tool_failure,
_get_input_model_from_mcp_prompt,
_normalize_mcp_name,
_parse_content_from_mcp,
Expand Down Expand Up @@ -53,6 +54,12 @@ def test_normalize_mcp_name():
assert _normalize_mcp_name("name/with\\slashes") == "name-with-slashes"


def test_classify_mcp_tool_failure():
assert _classify_mcp_tool_failure("Tool not found on remote server") == "mcp_tool_missing"
assert _classify_mcp_tool_failure("Invalid params for schema validation") == "mcp_tool_schema_mismatch"
assert _classify_mcp_tool_failure("transport closed") is None


def test_mcp_transport_subclasses_accept_tool_name_prefix() -> None:
assert MCPStdioTool(name="stdio", command="python", tool_name_prefix="stdio").tool_name_prefix == "stdio"
assert (
Expand Down Expand Up @@ -1032,6 +1039,49 @@ def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]:
await func.invoke(param="test_value")


async def test_local_mcp_server_schema_drift_error_is_classified_and_refreshes():
"""Schema drift should fail closed with a stable marker and trigger a tool refresh."""

class TestServer(MCPTool):
async def connect(self):
self.session = Mock(spec=ClientSession)
self.session.list_tools = AsyncMock(
return_value=types.ListToolsResult(
tools=[
types.Tool(
name="test_tool",
description="Test tool",
inputSchema={
"type": "object",
"properties": {"param": {"type": "string"}},
"required": ["param"],
},
)
]
)
)
self.session.call_tool = AsyncMock(
side_effect=McpError(types.ErrorData(code=-32602, message="Invalid params: schema changed"))
)

def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]:
return None

server = TestServer(name="test_server")
async with server:
await server.load_tools()
func = server.functions[0]

with (
patch.object(server, "connect", new_callable=AsyncMock) as mock_connect,
pytest.raises(ToolExecutionException, match=r"\[mcp_tool_schema_mismatch\]") as exc_info,
):
await func.invoke(param="test_value")

mock_connect.assert_awaited_once_with(reset=True)
assert "schema changed" in str(exc_info.value)


async def test_mcp_tool_call_tool_raises_on_is_error():
"""Test that call_tool raises ToolExecutionException when MCP returns isError=True."""

Expand Down