Skip to content

Commit e417b74

Browse files
committed
fix sampling in streamable http
1 parent ed25167 commit e417b74

File tree

5 files changed

+145
-19
lines changed

5 files changed

+145
-19
lines changed

src/mcp/client/streamable_http.py

Lines changed: 17 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -15,6 +15,7 @@
1515

1616
import anyio
1717
import httpx
18+
from anyio.abc import TaskGroup
1819
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
1920
from httpx_sse import EventSource, ServerSentEvent, aconnect_sse
2021

@@ -239,7 +240,7 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None:
239240
break
240241

241242
async def _handle_post_request(self, ctx: RequestContext) -> None:
242-
"""Handle a POST request with response processing."""
243+
"""Handle a POST request with response processing."""
243244
headers = self._update_headers_with_session(ctx.headers)
244245
message = ctx.session_message.message
245246
is_initialization = self._is_initialization_request(message)
@@ -300,7 +301,7 @@ async def _handle_sse_response(
300301
try:
301302
event_source = EventSource(response)
302303
async for sse in event_source.aiter_sse():
303-
await self._handle_sse_event(
304+
is_complete = await self._handle_sse_event(
304305
sse,
305306
ctx.read_stream_writer,
306307
resumption_callback=(
@@ -309,6 +310,8 @@ async def _handle_sse_response(
309310
else None
310311
),
311312
)
313+
if is_complete:
314+
break
312315
except Exception as e:
313316
logger.exception("Error reading SSE stream:")
314317
await ctx.read_stream_writer.send(e)
@@ -344,6 +347,7 @@ async def post_writer(
344347
read_stream_writer: StreamWriter,
345348
write_stream: MemoryObjectSendStream[SessionMessage],
346349
start_get_stream: Callable[[], None],
350+
tg: TaskGroup,
347351
) -> None:
348352
"""Handle writing requests to the server."""
349353
try:
@@ -375,10 +379,17 @@ async def post_writer(
375379
sse_read_timeout=self.sse_read_timeout,
376380
)
377381

378-
if is_resumption:
379-
await self._handle_resumption_request(ctx)
382+
async def handle_request_async():
383+
if is_resumption:
384+
await self._handle_resumption_request(ctx)
385+
else:
386+
await self._handle_post_request(ctx)
387+
388+
# If this is a request, start a new task to handle it
389+
if isinstance(message.root, JSONRPCRequest):
390+
tg.start_soon(handle_request_async)
380391
else:
381-
await self._handle_post_request(ctx)
392+
await handle_request_async()
382393

383394
except Exception as exc:
384395
logger.error(f"Error in post_writer: {exc}")
@@ -466,6 +477,7 @@ def start_get_stream() -> None:
466477
read_stream_writer,
467478
write_stream,
468479
start_get_stream,
480+
tg,
469481
)
470482

471483
try:

src/mcp/server/session.py

Lines changed: 7 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -47,7 +47,7 @@ async def handle_list_prompts(ctx: RequestContext) -> list[types.Prompt]:
4747

4848
import mcp.types as types
4949
from mcp.server.models import InitializationOptions
50-
from mcp.shared.message import SessionMessage
50+
from mcp.shared.message import SessionMessage, ServerMessageMetadata
5151
from mcp.shared.session import (
5252
BaseSession,
5353
RequestResponder,
@@ -230,10 +230,11 @@ async def create_message(
230230
stop_sequences: list[str] | None = None,
231231
metadata: dict[str, Any] | None = None,
232232
model_preferences: types.ModelPreferences | None = None,
233+
related_request_id: types.RequestId | None = None,
233234
) -> types.CreateMessageResult:
234235
"""Send a sampling/create_message request."""
235236
return await self.send_request(
236-
types.ServerRequest(
237+
request=types.ServerRequest(
237238
types.CreateMessageRequest(
238239
method="sampling/createMessage",
239240
params=types.CreateMessageRequestParams(
@@ -248,7 +249,10 @@ async def create_message(
248249
),
249250
)
250251
),
251-
types.CreateMessageResult,
252+
result_type=types.CreateMessageResult,
253+
metadata=ServerMessageMetadata(
254+
related_request_id=related_request_id,
255+
),
252256
)
253257

254258
async def list_roots(self) -> types.ListRootsResult:

src/mcp/server/streamable_http.py

Lines changed: 14 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -849,9 +849,15 @@ async def message_router():
849849
# Determine which request stream(s) should receive this message
850850
message = session_message.message
851851
target_request_id = None
852-
if isinstance(
853-
message.root, JSONRPCNotification | JSONRPCRequest
854-
):
852+
# Check if this is a response
853+
if isinstance(message.root, JSONRPCResponse | JSONRPCError):
854+
response_id = str(message.root.id)
855+
# If this response is for an existing request stream,
856+
# send it there
857+
if response_id in self._request_streams:
858+
target_request_id = response_id
859+
860+
else:
855861
# Extract related_request_id from meta if it exists
856862
if (
857863
session_message.metadata is not None
@@ -865,10 +871,12 @@ async def message_router():
865871
target_request_id = str(
866872
session_message.metadata.related_request_id
867873
)
868-
else:
869-
target_request_id = str(message.root.id)
870874

871-
request_stream_id = target_request_id or GET_STREAM_KEY
875+
request_stream_id = (
876+
target_request_id
877+
if target_request_id is not None
878+
else GET_STREAM_KEY
879+
)
872880

873881
# Store the event if we have an event store,
874882
# regardless of whether a client is connected

src/mcp/shared/session.py

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -223,7 +223,6 @@ async def send_request(
223223
Do not use this method to emit notifications! Use send_notification()
224224
instead.
225225
"""
226-
227226
request_id = self._request_id
228227
self._request_id = request_id + 1
229228

tests/shared/test_streamable_http.py

Lines changed: 107 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -8,6 +8,7 @@
88
import socket
99
import time
1010
from collections.abc import Generator
11+
from typing import Any
1112

1213
import anyio
1314
import httpx
@@ -33,6 +34,7 @@
3334
StreamId,
3435
)
3536
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
37+
from mcp.shared.context import RequestContext
3638
from mcp.shared.exceptions import McpError
3739
from mcp.shared.message import (
3840
ClientMessageMetadata,
@@ -139,6 +141,11 @@ async def handle_list_tools() -> list[Tool]:
139141
description="A long-running tool that sends periodic notifications",
140142
inputSchema={"type": "object", "properties": {}},
141143
),
144+
Tool(
145+
name="test_sampling_tool",
146+
description="A tool that triggers server-side sampling",
147+
inputSchema={"type": "object", "properties": {}},
148+
),
142149
]
143150

144151
@self.call_tool()
@@ -174,6 +181,34 @@ async def handle_call_tool(name: str, args: dict) -> list[TextContent]:
174181

175182
return [TextContent(type="text", text="Completed!")]
176183

184+
elif name == "test_sampling_tool":
185+
# Test sampling by requesting the client to sample a message
186+
sampling_result = await ctx.session.create_message(
187+
messages=[
188+
types.SamplingMessage(
189+
role="user",
190+
content=types.TextContent(
191+
type="text", text="Server needs client sampling"
192+
),
193+
)
194+
],
195+
max_tokens=100,
196+
related_request_id=ctx.request_id,
197+
)
198+
199+
# Return the sampling result in the tool response
200+
response = (
201+
sampling_result.content.text
202+
if sampling_result.content.type == "text"
203+
else None
204+
)
205+
return [
206+
TextContent(
207+
type="text",
208+
text=f"Response from sampling: {response}",
209+
)
210+
]
211+
177212
return [TextContent(type="text", text=f"Called {name}")]
178213

179214

@@ -754,7 +789,7 @@ async def test_streamablehttp_client_tool_invocation(initialized_client_session)
754789
"""Test client tool invocation."""
755790
# First list tools
756791
tools = await initialized_client_session.list_tools()
757-
assert len(tools.tools) == 3
792+
assert len(tools.tools) == 4
758793
assert tools.tools[0].name == "test_tool"
759794

760795
# Call the tool
@@ -795,7 +830,7 @@ async def test_streamablehttp_client_session_persistence(
795830

796831
# Make multiple requests to verify session persistence
797832
tools = await session.list_tools()
798-
assert len(tools.tools) == 3
833+
assert len(tools.tools) == 4
799834

800835
# Read a resource
801836
resource = await session.read_resource(uri=AnyUrl("foobar://test-persist"))
@@ -826,7 +861,7 @@ async def test_streamablehttp_client_json_response(
826861

827862
# Check tool listing
828863
tools = await session.list_tools()
829-
assert len(tools.tools) == 3
864+
assert len(tools.tools) == 4
830865

831866
# Call a tool and verify JSON response handling
832867
result = await session.call_tool("test_tool", {})
@@ -905,7 +940,7 @@ async def test_streamablehttp_client_session_termination(
905940

906941
# Make a request to confirm session is working
907942
tools = await session.list_tools()
908-
assert len(tools.tools) == 3
943+
assert len(tools.tools) == 4
909944

910945
headers = {}
911946
if captured_session_id:
@@ -1054,3 +1089,71 @@ async def run_tool():
10541089
assert not any(
10551090
n in captured_notifications_pre for n in captured_notifications
10561091
)
1092+
1093+
1094+
@pytest.mark.anyio
async def test_streamablehttp_server_sampling(basic_server, basic_server_url):
    """Test server-initiated sampling request through streamable HTTP transport.

    The client calls ``test_sampling_tool``; the test server's tool handler
    responds by sending a ``create_message`` (sampling) request back to the
    client. The test verifies that the client's sampling callback receives
    that request and that the tool result embeds the callback's reply.
    """
    # Written by the callback so the test can prove it actually ran and
    # inspect the parameters the server sent.
    sampling_callback_invoked = False
    captured_message_params = None

    async def sampling_callback(
        context: RequestContext[ClientSession, Any],
        params: types.CreateMessageRequestParams,
    ) -> types.CreateMessageResult:
        """Record the incoming sampling request and return a mock response."""
        nonlocal sampling_callback_invoked, captured_message_params
        sampling_callback_invoked = True
        captured_message_params = params
        # Only text content carries a ``.text`` attribute; anything else is
        # echoed back as ``None``.
        message_received = (
            params.messages[0].content.text
            if params.messages[0].content.type == "text"
            else None
        )

        return types.CreateMessageResult(
            role="assistant",
            content=types.TextContent(
                type="text",
                text=f"Received message from server: {message_received}",
            ),
            model="test-model",
            stopReason="endTurn",
        )

    # Create client with sampling callback
    async with streamablehttp_client(f"{basic_server_url}/mcp") as (
        read_stream,
        write_stream,
        _,
    ):
        async with ClientSession(
            read_stream,
            write_stream,
            sampling_callback=sampling_callback,
        ) as session:
            # Initialize the session
            result = await session.initialize()
            assert isinstance(result, InitializeResult)

            # Call the tool that triggers server-side sampling
            tool_result = await session.call_tool("test_sampling_tool", {})

            # Verify the tool result contains the expected content
            assert len(tool_result.content) == 1
            assert tool_result.content[0].type == "text"
            assert (
                "Response from sampling: Received message from server"
                in tool_result.content[0].text
            )

            # Verify sampling callback was invoked with the server's message
            assert sampling_callback_invoked
            assert captured_message_params is not None
            assert len(captured_message_params.messages) == 1
            # Narrow the content type before reading ``.text`` so a non-text
            # payload fails with a clear assertion instead of AttributeError.
            assert captured_message_params.messages[0].content.type == "text"
            assert (
                captured_message_params.messages[0].content.text
                == "Server needs client sampling"
            )

0 commit comments

Comments
 (0)