Skip to content

Commit eef4f8e

Browse files
committed
validation options
1 parent ef4e167 commit eef4f8e

File tree

3 files changed

+300
-8
lines changed

3 files changed

+300
-8
lines changed

src/mcp/client/session.py

Lines changed: 35 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import anyio.lowlevel
66
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
77
from jsonschema import SchemaError, ValidationError, validate
8-
from pydantic import AnyUrl, TypeAdapter
8+
from pydantic import AnyUrl, BaseModel, Field, TypeAdapter
99

1010
import mcp.types as types
1111
from mcp.shared.context import RequestContext
@@ -18,6 +18,17 @@
1818
logger = logging.getLogger("client")
1919

2020

21+
class ValidationOptions(BaseModel):
22+
"""Options for controlling validation behavior in MCP client sessions."""
23+
24+
strict_output_validation: bool = Field(
25+
default=True,
26+
description="Whether to raise exceptions when tools don't return structured "
27+
"content as specified by their output schema. When False, validation "
28+
"errors are logged as warnings and execution continues.",
29+
)
30+
31+
2132
class SamplingFnT(Protocol):
2233
async def __call__(
2334
self,
@@ -118,6 +129,7 @@ def __init__(
118129
logging_callback: LoggingFnT | None = None,
119130
message_handler: MessageHandlerFnT | None = None,
120131
client_info: types.Implementation | None = None,
132+
validation_options: ValidationOptions | None = None,
121133
) -> None:
122134
super().__init__(
123135
read_stream,
@@ -133,6 +145,7 @@ def __init__(
133145
self._logging_callback = logging_callback or _default_logging_callback
134146
self._message_handler = message_handler or _default_message_handler
135147
self._tool_output_schemas: dict[str, dict[str, Any] | None] = {}
148+
self._validation_options = validation_options or ValidationOptions()
136149

137150
async def initialize(self) -> types.InitializeResult:
138151
sampling = types.SamplingCapability() if self._sampling_callback is not _default_sampling_callback else None
@@ -324,13 +337,27 @@ async def _validate_tool_result(self, name: str, result: types.CallToolResult) -
324337

325338
if output_schema is not None:
326339
if result.structuredContent is None:
327-
raise RuntimeError(f"Tool {name} has an output schema but did not return structured content")
328-
try:
329-
validate(result.structuredContent, output_schema)
330-
except ValidationError as e:
331-
raise RuntimeError(f"Invalid structured content returned by tool {name}: {e}")
332-
except SchemaError as e:
333-
raise RuntimeError(f"Invalid schema for tool {name}: {e}")
340+
if self._validation_options.strict_output_validation:
341+
raise RuntimeError(f"Tool {name} has an output schema but did not return structured content")
342+
else:
343+
logger.warning(
344+
f"Tool {name} has an output schema but did not return structured content. "
345+
f"Continuing without structured content validation due to lenient validation mode."
346+
)
347+
else:
348+
try:
349+
validate(result.structuredContent, output_schema)
350+
except ValidationError as e:
351+
if self._validation_options.strict_output_validation:
352+
raise RuntimeError(f"Invalid structured content returned by tool {name}: {e}")
353+
else:
354+
logger.warning(
355+
f"Invalid structured content returned by tool {name}: {e}. "
356+
f"Continuing due to lenient validation mode."
357+
)
358+
except SchemaError as e:
359+
# Schema errors are always raised - they indicate a problem with the schema itself
360+
raise RuntimeError(f"Invalid schema for tool {name}: {e}")
334361

335362
async def list_prompts(self, cursor: str | None = None) -> types.ListPromptsResult:
336363
"""Send a prompts/list request."""

src/mcp/shared/memory.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ async def create_connected_server_and_client_session(
6161
client_info: types.Implementation | None = None,
6262
raise_exceptions: bool = False,
6363
elicitation_callback: ElicitationFnT | None = None,
64+
validation_options: Any | None = None,
6465
) -> AsyncGenerator[ClientSession, None]:
6566
"""Creates a ClientSession that is connected to a running MCP server."""
6667
async with create_client_server_memory_streams() as (
@@ -92,6 +93,7 @@ async def create_connected_server_and_client_session(
9293
message_handler=message_handler,
9394
client_info=client_info,
9495
elicitation_callback=elicitation_callback,
96+
validation_options=validation_options,
9597
) as client_session:
9698
await client_session.initialize()
9799
yield client_session
Lines changed: 263 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,263 @@
1+
"""Tests for client-side validation options."""
2+
3+
import logging
4+
from unittest.mock import AsyncMock, MagicMock
5+
6+
import pytest
7+
8+
from mcp.client.session import ValidationOptions, ClientSession
9+
from mcp.types import Tool, CallToolResult, TextContent
10+
import mcp.types as types
11+
12+
13+
class TestValidationOptions:
14+
"""Test validation options for MCP client sessions."""
15+
16+
@pytest.mark.anyio
17+
async def test_strict_validation_default(self):
18+
"""Test that strict validation is enabled by default."""
19+
# Create a mock client session
20+
read_stream = MagicMock()
21+
write_stream = MagicMock()
22+
23+
client = ClientSession(read_stream, write_stream)
24+
25+
# Set up tool with output schema
26+
client._tool_output_schemas = {
27+
"test_tool": {
28+
"type": "object",
29+
"properties": {"result": {"type": "integer"}},
30+
"required": ["result"],
31+
}
32+
}
33+
34+
# Mock send_request to return a result without structured content
35+
mock_result = CallToolResult(
36+
content=[TextContent(type="text", text="This is unstructured text content")],
37+
structuredContent=None,
38+
isError=False
39+
)
40+
41+
client.send_request = AsyncMock(return_value=mock_result)
42+
43+
# Should raise by default when structured content is missing
44+
with pytest.raises(RuntimeError) as exc_info:
45+
await client.call_tool("test_tool", {})
46+
assert "has an output schema but did not return structured content" in str(exc_info.value)
47+
48+
@pytest.mark.anyio
49+
async def test_lenient_validation_missing_content(self, caplog):
50+
"""Test lenient validation when structured content is missing."""
51+
# Set logging level to capture warnings
52+
caplog.set_level(logging.WARNING)
53+
54+
# Create client with lenient validation
55+
validation_options = ValidationOptions(strict_output_validation=False)
56+
57+
read_stream = MagicMock()
58+
write_stream = MagicMock()
59+
60+
client = ClientSession(
61+
read_stream,
62+
write_stream,
63+
validation_options=validation_options
64+
)
65+
66+
# Set up tool with output schema
67+
client._tool_output_schemas = {
68+
"test_tool": {
69+
"type": "object",
70+
"properties": {"result": {"type": "integer"}},
71+
"required": ["result"],
72+
}
73+
}
74+
75+
# Mock send_request to return a result without structured content
76+
mock_result = CallToolResult(
77+
content=[TextContent(type="text", text="This is unstructured text content")],
78+
structuredContent=None,
79+
isError=False
80+
)
81+
82+
client.send_request = AsyncMock(return_value=mock_result)
83+
84+
# Should not raise with lenient validation
85+
result = await client.call_tool("test_tool", {})
86+
87+
# Should have logged a warning
88+
assert "has an output schema but did not return structured content" in caplog.text
89+
assert "Continuing without structured content validation" in caplog.text
90+
91+
# Result should still be returned
92+
assert result.isError is False
93+
assert result.structuredContent is None
94+
95+
@pytest.mark.anyio
96+
async def test_lenient_validation_invalid_content(self, caplog):
97+
"""Test lenient validation when structured content is invalid."""
98+
# Set logging level to capture warnings
99+
caplog.set_level(logging.WARNING)
100+
101+
# Create client with lenient validation
102+
validation_options = ValidationOptions(strict_output_validation=False)
103+
104+
read_stream = MagicMock()
105+
write_stream = MagicMock()
106+
107+
client = ClientSession(
108+
read_stream,
109+
write_stream,
110+
validation_options=validation_options
111+
)
112+
113+
# Set up tool with output schema
114+
client._tool_output_schemas = {
115+
"test_tool": {
116+
"type": "object",
117+
"properties": {"result": {"type": "integer"}},
118+
"required": ["result"],
119+
}
120+
}
121+
122+
# Mock send_request to return a result with invalid structured content
123+
mock_result = CallToolResult(
124+
content=[TextContent(type="text", text='{"result": "not_an_integer"}')],
125+
structuredContent={"result": "not_an_integer"}, # Invalid: string instead of integer
126+
isError=False
127+
)
128+
129+
client.send_request = AsyncMock(return_value=mock_result)
130+
131+
# Should not raise with lenient validation
132+
result = await client.call_tool("test_tool", {})
133+
134+
# Should have logged a warning
135+
assert "Invalid structured content returned by tool test_tool" in caplog.text
136+
assert "Continuing due to lenient validation mode" in caplog.text
137+
138+
# Result should still be returned with the invalid content
139+
assert result.isError is False
140+
assert result.structuredContent == {"result": "not_an_integer"}
141+
142+
@pytest.mark.anyio
143+
async def test_strict_validation_with_valid_content(self):
144+
"""Test that valid structured content passes validation."""
145+
read_stream = MagicMock()
146+
write_stream = MagicMock()
147+
148+
client = ClientSession(read_stream, write_stream)
149+
150+
# Set up tool with output schema
151+
client._tool_output_schemas = {
152+
"test_tool": {
153+
"type": "object",
154+
"properties": {"result": {"type": "integer"}},
155+
"required": ["result"],
156+
}
157+
}
158+
159+
# Mock send_request to return a result with valid structured content
160+
mock_result = CallToolResult(
161+
content=[TextContent(type="text", text='{"result": 42}')],
162+
structuredContent={"result": 42},
163+
isError=False
164+
)
165+
166+
client.send_request = AsyncMock(return_value=mock_result)
167+
168+
# Should not raise with valid content
169+
result = await client.call_tool("test_tool", {})
170+
assert result.isError is False
171+
assert result.structuredContent == {"result": 42}
172+
173+
@pytest.mark.anyio
174+
async def test_schema_errors_always_raised(self):
175+
"""Test that schema errors are always raised regardless of validation mode."""
176+
# Create client with lenient validation
177+
validation_options = ValidationOptions(strict_output_validation=False)
178+
179+
read_stream = MagicMock()
180+
write_stream = MagicMock()
181+
182+
client = ClientSession(
183+
read_stream,
184+
write_stream,
185+
validation_options=validation_options
186+
)
187+
188+
# Set up tool with invalid output schema
189+
client._tool_output_schemas = {
190+
"test_tool": "not a valid schema" # Invalid schema
191+
}
192+
193+
# Mock send_request to return a result with structured content
194+
mock_result = CallToolResult(
195+
content=[TextContent(type="text", text='{"result": 42}')],
196+
structuredContent={"result": 42},
197+
isError=False
198+
)
199+
200+
client.send_request = AsyncMock(return_value=mock_result)
201+
202+
# Should still raise for schema errors even in lenient mode
203+
with pytest.raises(RuntimeError) as exc_info:
204+
await client.call_tool("test_tool", {})
205+
assert "Invalid schema for tool test_tool" in str(exc_info.value)
206+
207+
@pytest.mark.anyio
208+
async def test_error_results_not_validated(self):
209+
"""Test that error results are not validated."""
210+
read_stream = MagicMock()
211+
write_stream = MagicMock()
212+
213+
client = ClientSession(read_stream, write_stream)
214+
215+
# Set up tool with output schema
216+
client._tool_output_schemas = {
217+
"test_tool": {
218+
"type": "object",
219+
"properties": {"result": {"type": "integer"}},
220+
"required": ["result"],
221+
}
222+
}
223+
224+
# Mock send_request to return an error result
225+
mock_result = CallToolResult(
226+
content=[TextContent(type="text", text="Tool execution failed")],
227+
structuredContent=None,
228+
isError=True # Error result
229+
)
230+
231+
client.send_request = AsyncMock(return_value=mock_result)
232+
233+
# Should not validate error results
234+
result = await client.call_tool("test_tool", {})
235+
assert result.isError is True
236+
# No exception should be raised
237+
238+
@pytest.mark.anyio
239+
async def test_tool_without_output_schema(self):
240+
"""Test that tools without output schema don't trigger validation."""
241+
read_stream = MagicMock()
242+
write_stream = MagicMock()
243+
244+
client = ClientSession(read_stream, write_stream)
245+
246+
# Tool has no output schema
247+
client._tool_output_schemas = {
248+
"test_tool": None
249+
}
250+
251+
# Mock send_request to return a result without structured content
252+
mock_result = CallToolResult(
253+
content=[TextContent(type="text", text="This is unstructured text content")],
254+
structuredContent=None,
255+
isError=False
256+
)
257+
258+
client.send_request = AsyncMock(return_value=mock_result)
259+
260+
# Should not raise when there's no output schema
261+
result = await client.call_tool("test_tool", {})
262+
assert result.isError is False
263+
assert result.structuredContent is None

0 commit comments

Comments
 (0)