@@ -98,32 +98,36 @@ async def replay_events_after(
9898 send_callback : EventCallback ,
9999 ) -> StreamId | None :
100100 """Replay events after the specified ID."""
101- # Find the index of the last event ID
102- start_index = None
103- for i , ( _ , event_id , _ ) in enumerate ( self ._events ) :
101+ # Find the stream ID of the last event
102+ target_stream_id = None
103+ for stream_id , event_id , _ in self ._events :
104104 if event_id == last_event_id :
105- start_index = i + 1
105+ target_stream_id = stream_id
106106 break
107107
108- if start_index is None :
109- # If event ID not found, start from beginning
110- start_index = 0
108+ if target_stream_id is None :
109+ # If event ID not found, return None
110+ return None
111111
112- stream_id = None
113- # Replay events
114- for _ , event_id , message in self ._events [start_index :]:
115- await send_callback (EventMessage (message , event_id ))
116- # Capture the stream ID from the first replayed event
117- if stream_id is None and len (self ._events ) > start_index :
118- stream_id = self ._events [start_index ][0 ]
112+ # Convert last_event_id to int for comparison
113+ last_event_id_int = int (last_event_id )
119114
120- return stream_id
115+ # Replay only events from the same stream with ID > last_event_id
116+ for stream_id , event_id , message in self ._events :
117+ if stream_id == target_stream_id and int (event_id ) > last_event_id_int :
118+ await send_callback (EventMessage (message , event_id ))
119+
120+ return target_stream_id
121121
122122
123123# Test server implementation that follows MCP protocol
124124class ServerTest (Server ):
125125 def __init__ (self ):
126126 super ().__init__ (SERVER_NAME )
127+ self ._lock = anyio .Event ()
128+ # Reset the lock for each new server instance
129+ self ._lock .set ()
130+ self ._lock = anyio .Event ()
127131
128132 @self .read_resource ()
129133 async def handle_read_resource (uri : AnyUrl ) -> str | bytes :
@@ -159,6 +163,16 @@ async def handle_list_tools() -> list[Tool]:
159163 description = "A tool that triggers server-side sampling" ,
160164 inputSchema = {"type" : "object" , "properties" : {}},
161165 ),
166+ Tool (
167+ name = "wait_for_lock_with_notification" ,
168+ description = "A tool that sends a notification and waits for lock" ,
169+ inputSchema = {"type" : "object" , "properties" : {}},
170+ ),
171+ Tool (
172+ name = "release_lock" ,
173+ description = "A tool that releases the lock" ,
174+ inputSchema = {"type" : "object" , "properties" : {}},
175+ ),
162176 ]
163177
164178 @self .call_tool ()
@@ -214,6 +228,33 @@ async def handle_call_tool(name: str, args: dict) -> list[TextContent]:
214228 )
215229 ]
216230
231+ elif name == "wait_for_lock_with_notification" :
232+ # First send a notification
233+ await ctx .session .send_log_message (
234+ level = "info" ,
235+ data = "First notification before lock" ,
236+ logger = "lock_tool" ,
237+ related_request_id = ctx .request_id ,
238+ )
239+
240+ # Now wait for the lock to be released
241+ await self ._lock .wait ()
242+
243+ # Send second notification after lock is released
244+ await ctx .session .send_log_message (
245+ level = "info" ,
246+ data = "Second notification after lock" ,
247+ logger = "lock_tool" ,
248+ related_request_id = ctx .request_id ,
249+ )
250+
251+ return [TextContent (type = "text" , text = "Completed" )]
252+
253+ elif name == "release_lock" :
254+ # Release the lock
255+ self ._lock .set ()
256+ return [TextContent (type = "text" , text = "Lock released" )]
257+
217258 return [TextContent (type = "text" , text = f"Called { name } " )]
218259
219260
@@ -825,7 +866,7 @@ async def test_streamablehttp_client_tool_invocation(initialized_client_session)
825866 """Test client tool invocation."""
826867 # First list tools
827868 tools = await initialized_client_session .list_tools ()
828- assert len (tools .tools ) == 4
869+ assert len (tools .tools ) == 6
829870 assert tools .tools [0 ].name == "test_tool"
830871
831872 # Call the tool
@@ -862,7 +903,7 @@ async def test_streamablehttp_client_session_persistence(basic_server, basic_ser
862903
863904 # Make multiple requests to verify session persistence
864905 tools = await session .list_tools ()
865- assert len (tools .tools ) == 4
906+ assert len (tools .tools ) == 6
866907
867908 # Read a resource
868909 resource = await session .read_resource (uri = AnyUrl ("foobar://test-persist" ))
@@ -891,7 +932,7 @@ async def test_streamablehttp_client_json_response(json_response_server, json_se
891932
892933 # Check tool listing
893934 tools = await session .list_tools ()
894- assert len (tools .tools ) == 4
935+ assert len (tools .tools ) == 6
895936
896937 # Call a tool and verify JSON response handling
897938 result = await session .call_tool ("test_tool" , {})
@@ -962,7 +1003,7 @@ async def test_streamablehttp_client_session_termination(basic_server, basic_ser
9621003
9631004 # Make a request to confirm session is working
9641005 tools = await session .list_tools ()
965- assert len (tools .tools ) == 4
1006+ assert len (tools .tools ) == 6
9661007
9671008 headers = {}
9681009 if captured_session_id :
@@ -1026,7 +1067,7 @@ async def mock_delete(self, *args, **kwargs):
10261067
10271068 # Make a request to confirm session is working
10281069 tools = await session .list_tools ()
1029- assert len (tools .tools ) == 4
1070+ assert len (tools .tools ) == 6
10301071
10311072 headers = {}
10321073 if captured_session_id :
@@ -1048,32 +1089,32 @@ async def mock_delete(self, *args, **kwargs):
10481089
10491090@pytest .mark .anyio
10501091async def test_streamablehttp_client_resumption (event_server ):
1051- """Test client session to resume a long running tool ."""
1092+ """Test client session resumption using sync primitives for reliable coordination ."""
10521093 _ , server_url = event_server
10531094
10541095 # Variables to track the state
10551096 captured_session_id = None
10561097 captured_resumption_token = None
10571098 captured_notifications = []
1058- tool_started = False
10591099 captured_protocol_version = None
1100+ first_notification_received = False
10601101
10611102 async def message_handler (
10621103 message : RequestResponder [types .ServerRequest , types .ClientResult ] | types .ServerNotification | Exception ,
10631104 ) -> None :
10641105 if isinstance (message , types .ServerNotification ):
10651106 captured_notifications .append (message )
1066- # Look for our special notification that indicates the tool is running
1107+ # Look for our first notification
10671108 if isinstance (message .root , types .LoggingMessageNotification ):
1068- if message .root .params .data == "Tool started " :
1069- nonlocal tool_started
1070- tool_started = True
1109+ if message .root .params .data == "First notification before lock " :
1110+ nonlocal first_notification_received
1111+ first_notification_received = True
10711112
10721113 async def on_resumption_token_update (token : str ) -> None :
10731114 nonlocal captured_resumption_token
10741115 captured_resumption_token = token
10751116
1076- # First, start the client session and begin the long-running tool
1117+ # First, start the client session and begin the tool that waits on lock
10771118 async with streamablehttp_client (f"{ server_url } /mcp" , terminate_on_close = False ) as (
10781119 read_stream ,
10791120 write_stream ,
@@ -1088,7 +1129,7 @@ async def on_resumption_token_update(token: str) -> None:
10881129 # Capture the negotiated protocol version
10891130 captured_protocol_version = result .protocolVersion
10901131
1091- # Start a long-running tool in a task
1132+ # Start the tool that will wait on lock in a task
10921133 async with anyio .create_task_group () as tg :
10931134
10941135 async def run_tool ():
@@ -1099,7 +1140,9 @@ async def run_tool():
10991140 types .ClientRequest (
11001141 types .CallToolRequest (
11011142 method = "tools/call" ,
1102- params = types .CallToolRequestParams (name = "long_running_with_checkpoints" , arguments = {}),
1143+ params = types .CallToolRequestParams (
1144+ name = "wait_for_lock_with_notification" , arguments = {}
1145+ ),
11031146 )
11041147 ),
11051148 types .CallToolResult ,
@@ -1108,15 +1151,19 @@ async def run_tool():
11081151
11091152 tg .start_soon (run_tool )
11101153
1111- # Wait for the tool to start and at least one notification
1112- # and then kill the task group
1113- while not tool_started or not captured_resumption_token :
1154+ # Wait for the first notification and resumption token
1155+ while not first_notification_received or not captured_resumption_token :
11141156 await anyio .sleep (0.1 )
1157+
1158+ # Kill the client session while tool is waiting on lock
11151159 tg .cancel_scope .cancel ()
11161160
1117- # Store pre notifications and clear the captured notifications
1118- # for the post-resumption check
1119- captured_notifications_pre = captured_notifications .copy ()
1161+ # Verify we received exactly one notification
1162+ assert len (captured_notifications ) == 1
1163+ assert isinstance (captured_notifications [0 ].root , types .LoggingMessageNotification )
1164+ assert captured_notifications [0 ].root .params .data == "First notification before lock"
1165+
1166+ # Clear notifications for the second phase
11201167 captured_notifications = []
11211168
11221169 # Now resume the session with the same mcp-session-id and protocol version
@@ -1125,54 +1172,48 @@ async def run_tool():
11251172 headers [MCP_SESSION_ID_HEADER ] = captured_session_id
11261173 if captured_protocol_version :
11271174 headers [MCP_PROTOCOL_VERSION_HEADER ] = captured_protocol_version
1128-
11291175 async with streamablehttp_client (f"{ server_url } /mcp" , headers = headers ) as (
11301176 read_stream ,
11311177 write_stream ,
11321178 _ ,
11331179 ):
11341180 async with ClientSession (read_stream , write_stream , message_handler = message_handler ) as session :
1135- # Don't initialize - just use the existing session
1136-
1137- # Resume the tool with the resumption token
1138- assert captured_resumption_token is not None
1139-
1181+ result = await session .send_request (
1182+ types .ClientRequest (
1183+ types .CallToolRequest (
1184+ method = "tools/call" ,
1185+ params = types .CallToolRequestParams (name = "release_lock" , arguments = {}),
1186+ )
1187+ ),
1188+ types .CallToolResult ,
1189+ )
11401190 metadata = ClientMessageMetadata (
11411191 resumption_token = captured_resumption_token ,
11421192 )
1193+
11431194 result = await session .send_request (
11441195 types .ClientRequest (
11451196 types .CallToolRequest (
11461197 method = "tools/call" ,
1147- params = types .CallToolRequestParams (name = "long_running_with_checkpoints " , arguments = {}),
1198+ params = types .CallToolRequestParams (name = "wait_for_lock_with_notification " , arguments = {}),
11481199 )
11491200 ),
11501201 types .CallToolResult ,
11511202 metadata = metadata ,
11521203 )
1153-
1154- # We should get a complete result
11551204 assert len (result .content ) == 1
11561205 assert result .content [0 ].type == "text"
1157- assert "Completed" in result .content [0 ].text
1206+ assert result .content [0 ].text == "Completed"
11581207
11591208 # We should have received the remaining notifications
1160- assert len (captured_notifications ) > 0
1209+ assert len (captured_notifications ) == 1
11611210
1162- # Should not have the first notification
1163- # Check that "Tool started" notification isn't repeated when resuming
1164- assert not any (
1165- isinstance (n .root , types .LoggingMessageNotification ) and n .root .params .data == "Tool started"
1166- for n in captured_notifications
1167- )
1168- # there is no intersection between pre and post notifications
1169- assert not any (n in captured_notifications_pre for n in captured_notifications )
1211+ assert captured_notifications [0 ].root .params .data == "Second notification after lock"
11701212
11711213
11721214@pytest .mark .anyio
11731215async def test_streamablehttp_server_sampling (basic_server , basic_server_url ):
11741216 """Test server-initiated sampling request through streamable HTTP transport."""
1175- print ("Testing server sampling..." )
11761217 # Variable to track if sampling callback was invoked
11771218 sampling_callback_invoked = False
11781219 captured_message_params = None
0 commit comments