@@ -68,7 +68,6 @@ async def test_call(call: types.CallToolRequest) -> types.ServerResult:
6868 mock_context_2 = Mock ()
6969
7070 mock_context_2 .session = mock_session_2
71- mock_session_2 .send_progress_notification .result = None
7271
7372 result_cache = ResultCache (max_size = 1 , max_keep_alive = 1 )
7473 async with AsyncExitStack () as stack :
@@ -123,3 +122,178 @@ async def test_call(call: types.CallToolRequest) -> types.ServerResult:
123122 message = None ,
124123 resource_uri = None ,
125124 )
125+
126+
127+ @pytest .mark .anyio
128+ async def test_async_call_keep_alive ():
129+ """Tests async call keep alive"""
130+
131+ async def test_call (call : types .CallToolRequest ) -> types .ServerResult :
132+ return types .ServerResult (
133+ types .CallToolResult (content = [types .TextContent (type = "text" , text = "test" )])
134+ )
135+
136+ async_call = types .CallToolAsyncRequest (
137+ method = "tools/async/call" , params = types .CallToolAsyncRequestParams (name = "test" )
138+ )
139+
140+ mock_session_1 = AsyncMock ()
141+ mock_context_1 = Mock ()
142+ mock_context_1 .session = mock_session_1
143+
144+ mock_session_2 = AsyncMock ()
145+ mock_context_2 = Mock ()
146+
147+ mock_context_2 .session = mock_session_2
148+
149+ result_cache = ResultCache (max_size = 1 , max_keep_alive = 10 )
150+ async with AsyncExitStack () as stack :
151+ await stack .enter_async_context (result_cache )
152+ async_call_ref = await result_cache .start_call (
153+ test_call , async_call , mock_context_1
154+ )
155+ assert async_call_ref .token is not None
156+
157+ await result_cache .session_close_hook (mock_session_1 )
158+
159+ await result_cache .join_call (
160+ req = types .JoinCallToolAsyncRequest (
161+ method = "tools/async/join" ,
162+ params = types .JoinCallToolRequestParams (
163+ token = async_call_ref .token ,
164+ _meta = types .RequestParams .Meta (progressToken = "test" ),
165+ ),
166+ ),
167+ ctx = mock_context_2 ,
168+ )
169+ assert async_call_ref .token is not None
170+ await result_cache .notification_hook (
171+ session = mock_session_1 ,
172+ notification = types .ServerNotification (
173+ types .ProgressNotification (
174+ method = "notifications/progress" ,
175+ params = types .ProgressNotificationParams (
176+ progressToken = "test" , progress = 1
177+ ),
178+ )
179+ ),
180+ )
181+
182+ result = await result_cache .get_result (
183+ types .GetToolAsyncResultRequest (
184+ method = "tools/async/get" ,
185+ params = types .GetToolAsyncResultRequestParams (
186+ token = async_call_ref .token
187+ ),
188+ )
189+ )
190+
191+ assert not result .isError , str (result )
192+ assert not result .isPending
193+ assert len (result .content ) == 1
194+ assert type (result .content [0 ]) is types .TextContent
195+ assert result .content [0 ].text == "test"
196+
197+
198+ @pytest .mark .anyio
199+ async def test_async_call_keep_alive_expired ():
200+ """Tests async call keep alive expiry"""
201+
202+ async def test_call (call : types .CallToolRequest ) -> types .ServerResult :
203+ return types .ServerResult (
204+ types .CallToolResult (content = [types .TextContent (type = "text" , text = "test" )])
205+ )
206+
207+ async_call = types .CallToolAsyncRequest (
208+ method = "tools/async/call" , params = types .CallToolAsyncRequestParams (name = "test" )
209+ )
210+
211+ mock_session_1 = AsyncMock ()
212+ mock_context_1 = Mock ()
213+ mock_context_1 .session = mock_session_1
214+
215+ mock_session_2 = AsyncMock ()
216+ mock_context_2 = Mock ()
217+ mock_context_2 .session = mock_session_2
218+
219+ mock_session_3 = AsyncMock ()
220+ mock_context_3 = Mock ()
221+ mock_context_3 .session = mock_session_3
222+
223+ time = 0.0
224+
225+ def test_timer ():
226+ return time
227+
228+ result_cache = ResultCache (max_size = 1 , max_keep_alive = 1 , timer = test_timer )
229+ async with AsyncExitStack () as stack :
230+ await stack .enter_async_context (result_cache )
231+ async_call_ref = await result_cache .start_call (
232+ test_call , async_call , mock_context_1
233+ )
234+ assert async_call_ref .token is not None
235+
236+ # lose the connection
237+ await result_cache .session_close_hook (mock_session_1 )
238+
239+ # reconnect before keep_alive_timeout
240+ time = 0.5
241+ await result_cache .join_call (
242+ req = types .JoinCallToolAsyncRequest (
243+ method = "tools/async/join" ,
244+ params = types .JoinCallToolRequestParams (
245+ token = async_call_ref .token ,
246+ _meta = types .RequestParams .Meta (progressToken = "test" ),
247+ ),
248+ ),
249+ ctx = mock_context_2 ,
250+ )
251+
252+ result = await result_cache .get_result (
253+ types .GetToolAsyncResultRequest (
254+ method = "tools/async/get" ,
255+ params = types .GetToolAsyncResultRequestParams (
256+ token = async_call_ref .token
257+ ),
258+ )
259+ )
260+
261+ # should successfully read data
262+ assert not result .isError , str (result )
263+ assert len (result .content ) == 1
264+ assert type (result .content [0 ]) is types .TextContent
265+ assert result .content [0 ].text == "test"
266+
267+ # lose connection a second time
268+
269+ await result_cache .session_close_hook (mock_session_2 )
270+
271+ time = 2
272+
273+ # reconnect after the keep_alive_timeout
274+
275+ await result_cache .join_call (
276+ req = types .JoinCallToolAsyncRequest (
277+ method = "tools/async/join" ,
278+ params = types .JoinCallToolRequestParams (
279+ token = async_call_ref .token ,
280+ _meta = types .RequestParams .Meta (progressToken = "test" ),
281+ ),
282+ ),
283+ ctx = mock_context_3 ,
284+ )
285+
286+ result = await result_cache .get_result (
287+ types .GetToolAsyncResultRequest (
288+ method = "tools/async/get" ,
289+ params = types .GetToolAsyncResultRequestParams (
290+ token = async_call_ref .token
291+ ),
292+ )
293+ )
294+
295+ # now token should be expired
296+ assert result .isError , str (result )
297+ assert len (result .content ) == 1
298+ assert type (result .content [0 ]) is types .TextContent
299+ assert result .content [0 ].text == "Unknown async token"
0 commit comments