Skip to content

Commit c24d59e

Browse files
committed
fix flaky test
1 parent a99711d commit c24d59e

File tree

2 files changed

+97
-58
lines changed

2 files changed

+97
-58
lines changed

src/mcp/server/streamable_http.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -837,9 +837,7 @@ async def message_router():
837837
response_id = str(message.root.id)
838838
# If this response is for an existing request stream,
839839
# send it there
840-
if response_id in self._request_streams:
841-
target_request_id = response_id
842-
840+
target_request_id = response_id
843841
else:
844842
# Extract related_request_id from meta if it exists
845843
if (

tests/shared/test_streamable_http.py

Lines changed: 96 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -98,32 +98,36 @@ async def replay_events_after(
9898
send_callback: EventCallback,
9999
) -> StreamId | None:
100100
"""Replay events after the specified ID."""
101-
# Find the index of the last event ID
102-
start_index = None
103-
for i, (_, event_id, _) in enumerate(self._events):
101+
# Find the stream ID of the last event
102+
target_stream_id = None
103+
for stream_id, event_id, _ in self._events:
104104
if event_id == last_event_id:
105-
start_index = i + 1
105+
target_stream_id = stream_id
106106
break
107107

108-
if start_index is None:
109-
# If event ID not found, start from beginning
110-
start_index = 0
108+
if target_stream_id is None:
109+
# If event ID not found, return None
110+
return None
111111

112-
stream_id = None
113-
# Replay events
114-
for _, event_id, message in self._events[start_index:]:
115-
await send_callback(EventMessage(message, event_id))
116-
# Capture the stream ID from the first replayed event
117-
if stream_id is None and len(self._events) > start_index:
118-
stream_id = self._events[start_index][0]
112+
# Convert last_event_id to int for comparison
113+
last_event_id_int = int(last_event_id)
119114

120-
return stream_id
115+
# Replay only events from the same stream with ID > last_event_id
116+
for stream_id, event_id, message in self._events:
117+
if stream_id == target_stream_id and int(event_id) > last_event_id_int:
118+
await send_callback(EventMessage(message, event_id))
119+
120+
return target_stream_id
121121

122122

123123
# Test server implementation that follows MCP protocol
124124
class ServerTest(Server):
125125
def __init__(self):
126126
super().__init__(SERVER_NAME)
127+
self._lock = anyio.Event()
128+
# Reset the lock for each new server instance
129+
self._lock.set()
130+
self._lock = anyio.Event()
127131

128132
@self.read_resource()
129133
async def handle_read_resource(uri: AnyUrl) -> str | bytes:
@@ -159,6 +163,16 @@ async def handle_list_tools() -> list[Tool]:
159163
description="A tool that triggers server-side sampling",
160164
inputSchema={"type": "object", "properties": {}},
161165
),
166+
Tool(
167+
name="wait_for_lock_with_notification",
168+
description="A tool that sends a notification and waits for lock",
169+
inputSchema={"type": "object", "properties": {}},
170+
),
171+
Tool(
172+
name="release_lock",
173+
description="A tool that releases the lock",
174+
inputSchema={"type": "object", "properties": {}},
175+
),
162176
]
163177

164178
@self.call_tool()
@@ -214,6 +228,33 @@ async def handle_call_tool(name: str, args: dict) -> list[TextContent]:
214228
)
215229
]
216230

231+
elif name == "wait_for_lock_with_notification":
232+
# First send a notification
233+
await ctx.session.send_log_message(
234+
level="info",
235+
data="First notification before lock",
236+
logger="lock_tool",
237+
related_request_id=ctx.request_id,
238+
)
239+
240+
# Now wait for the lock to be released
241+
await self._lock.wait()
242+
243+
# Send second notification after lock is released
244+
await ctx.session.send_log_message(
245+
level="info",
246+
data="Second notification after lock",
247+
logger="lock_tool",
248+
related_request_id=ctx.request_id,
249+
)
250+
251+
return [TextContent(type="text", text="Completed")]
252+
253+
elif name == "release_lock":
254+
# Release the lock
255+
self._lock.set()
256+
return [TextContent(type="text", text="Lock released")]
257+
217258
return [TextContent(type="text", text=f"Called {name}")]
218259

219260

@@ -825,7 +866,7 @@ async def test_streamablehttp_client_tool_invocation(initialized_client_session)
825866
"""Test client tool invocation."""
826867
# First list tools
827868
tools = await initialized_client_session.list_tools()
828-
assert len(tools.tools) == 4
869+
assert len(tools.tools) == 6
829870
assert tools.tools[0].name == "test_tool"
830871

831872
# Call the tool
@@ -862,7 +903,7 @@ async def test_streamablehttp_client_session_persistence(basic_server, basic_ser
862903

863904
# Make multiple requests to verify session persistence
864905
tools = await session.list_tools()
865-
assert len(tools.tools) == 4
906+
assert len(tools.tools) == 6
866907

867908
# Read a resource
868909
resource = await session.read_resource(uri=AnyUrl("foobar://test-persist"))
@@ -891,7 +932,7 @@ async def test_streamablehttp_client_json_response(json_response_server, json_se
891932

892933
# Check tool listing
893934
tools = await session.list_tools()
894-
assert len(tools.tools) == 4
935+
assert len(tools.tools) == 6
895936

896937
# Call a tool and verify JSON response handling
897938
result = await session.call_tool("test_tool", {})
@@ -962,7 +1003,7 @@ async def test_streamablehttp_client_session_termination(basic_server, basic_ser
9621003

9631004
# Make a request to confirm session is working
9641005
tools = await session.list_tools()
965-
assert len(tools.tools) == 4
1006+
assert len(tools.tools) == 6
9661007

9671008
headers = {}
9681009
if captured_session_id:
@@ -1026,7 +1067,7 @@ async def mock_delete(self, *args, **kwargs):
10261067

10271068
# Make a request to confirm session is working
10281069
tools = await session.list_tools()
1029-
assert len(tools.tools) == 4
1070+
assert len(tools.tools) == 6
10301071

10311072
headers = {}
10321073
if captured_session_id:
@@ -1048,32 +1089,32 @@ async def mock_delete(self, *args, **kwargs):
10481089

10491090
@pytest.mark.anyio
10501091
async def test_streamablehttp_client_resumption(event_server):
1051-
"""Test client session to resume a long running tool."""
1092+
"""Test client session resumption using sync primitives for reliable coordination."""
10521093
_, server_url = event_server
10531094

10541095
# Variables to track the state
10551096
captured_session_id = None
10561097
captured_resumption_token = None
10571098
captured_notifications = []
1058-
tool_started = False
10591099
captured_protocol_version = None
1100+
first_notification_received = False
10601101

10611102
async def message_handler(
10621103
message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
10631104
) -> None:
10641105
if isinstance(message, types.ServerNotification):
10651106
captured_notifications.append(message)
1066-
# Look for our special notification that indicates the tool is running
1107+
# Look for our first notification
10671108
if isinstance(message.root, types.LoggingMessageNotification):
1068-
if message.root.params.data == "Tool started":
1069-
nonlocal tool_started
1070-
tool_started = True
1109+
if message.root.params.data == "First notification before lock":
1110+
nonlocal first_notification_received
1111+
first_notification_received = True
10711112

10721113
async def on_resumption_token_update(token: str) -> None:
10731114
nonlocal captured_resumption_token
10741115
captured_resumption_token = token
10751116

1076-
# First, start the client session and begin the long-running tool
1117+
# First, start the client session and begin the tool that waits on lock
10771118
async with streamablehttp_client(f"{server_url}/mcp", terminate_on_close=False) as (
10781119
read_stream,
10791120
write_stream,
@@ -1088,7 +1129,7 @@ async def on_resumption_token_update(token: str) -> None:
10881129
# Capture the negotiated protocol version
10891130
captured_protocol_version = result.protocolVersion
10901131

1091-
# Start a long-running tool in a task
1132+
# Start the tool that will wait on lock in a task
10921133
async with anyio.create_task_group() as tg:
10931134

10941135
async def run_tool():
@@ -1099,7 +1140,9 @@ async def run_tool():
10991140
types.ClientRequest(
11001141
types.CallToolRequest(
11011142
method="tools/call",
1102-
params=types.CallToolRequestParams(name="long_running_with_checkpoints", arguments={}),
1143+
params=types.CallToolRequestParams(
1144+
name="wait_for_lock_with_notification", arguments={}
1145+
),
11031146
)
11041147
),
11051148
types.CallToolResult,
@@ -1108,15 +1151,19 @@ async def run_tool():
11081151

11091152
tg.start_soon(run_tool)
11101153

1111-
# Wait for the tool to start and at least one notification
1112-
# and then kill the task group
1113-
while not tool_started or not captured_resumption_token:
1154+
# Wait for the first notification and resumption token
1155+
while not first_notification_received or not captured_resumption_token:
11141156
await anyio.sleep(0.1)
1157+
1158+
# Kill the client session while tool is waiting on lock
11151159
tg.cancel_scope.cancel()
11161160

1117-
# Store pre notifications and clear the captured notifications
1118-
# for the post-resumption check
1119-
captured_notifications_pre = captured_notifications.copy()
1161+
# Verify we received exactly one notification
1162+
assert len(captured_notifications) == 1
1163+
assert isinstance(captured_notifications[0].root, types.LoggingMessageNotification)
1164+
assert captured_notifications[0].root.params.data == "First notification before lock"
1165+
1166+
# Clear notifications for the second phase
11201167
captured_notifications = []
11211168

11221169
# Now resume the session with the same mcp-session-id and protocol version
@@ -1125,54 +1172,48 @@ async def run_tool():
11251172
headers[MCP_SESSION_ID_HEADER] = captured_session_id
11261173
if captured_protocol_version:
11271174
headers[MCP_PROTOCOL_VERSION_HEADER] = captured_protocol_version
1128-
11291175
async with streamablehttp_client(f"{server_url}/mcp", headers=headers) as (
11301176
read_stream,
11311177
write_stream,
11321178
_,
11331179
):
11341180
async with ClientSession(read_stream, write_stream, message_handler=message_handler) as session:
1135-
# Don't initialize - just use the existing session
1136-
1137-
# Resume the tool with the resumption token
1138-
assert captured_resumption_token is not None
1139-
1181+
result = await session.send_request(
1182+
types.ClientRequest(
1183+
types.CallToolRequest(
1184+
method="tools/call",
1185+
params=types.CallToolRequestParams(name="release_lock", arguments={}),
1186+
)
1187+
),
1188+
types.CallToolResult,
1189+
)
11401190
metadata = ClientMessageMetadata(
11411191
resumption_token=captured_resumption_token,
11421192
)
1193+
11431194
result = await session.send_request(
11441195
types.ClientRequest(
11451196
types.CallToolRequest(
11461197
method="tools/call",
1147-
params=types.CallToolRequestParams(name="long_running_with_checkpoints", arguments={}),
1198+
params=types.CallToolRequestParams(name="wait_for_lock_with_notification", arguments={}),
11481199
)
11491200
),
11501201
types.CallToolResult,
11511202
metadata=metadata,
11521203
)
1153-
1154-
# We should get a complete result
11551204
assert len(result.content) == 1
11561205
assert result.content[0].type == "text"
1157-
assert "Completed" in result.content[0].text
1206+
assert result.content[0].text == "Completed"
11581207

11591208
# We should have received the remaining notifications
1160-
assert len(captured_notifications) > 0
1209+
assert len(captured_notifications) == 1
11611210

1162-
# Should not have the first notification
1163-
# Check that "Tool started" notification isn't repeated when resuming
1164-
assert not any(
1165-
isinstance(n.root, types.LoggingMessageNotification) and n.root.params.data == "Tool started"
1166-
for n in captured_notifications
1167-
)
1168-
# there is no intersection between pre and post notifications
1169-
assert not any(n in captured_notifications_pre for n in captured_notifications)
1211+
assert captured_notifications[0].root.params.data == "Second notification after lock"
11701212

11711213

11721214
@pytest.mark.anyio
11731215
async def test_streamablehttp_server_sampling(basic_server, basic_server_url):
11741216
"""Test server-initiated sampling request through streamable HTTP transport."""
1175-
print("Testing server sampling...")
11761217
# Variable to track if sampling callback was invoked
11771218
sampling_callback_invoked = False
11781219
captured_message_params = None

0 commit comments

Comments
 (0)