Skip to content

Commit fed9e70

Browse files
committed
add test for keep_alive expiry
1 parent 234ed4c commit fed9e70

File tree

2 files changed

+210
-6
lines changed

2 files changed

+210
-6
lines changed

src/mcp/server/lowlevel/result_cache.py

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1+
import time
12
from collections.abc import Awaitable, Callable
23
from concurrent.futures import Future
34
from dataclasses import dataclass, field
45
from logging import getLogger
5-
from time import time
66
from types import TracebackType
77
from typing import Any
88
from uuid import uuid4
@@ -23,12 +23,21 @@
2323
@dataclass
2424
class InProgress:
2525
token: str
26+
timer: Callable[[], float]
2627
user: AuthenticatedUser | None = None
2728
future: Future[types.CallToolResult] | None = None
2829
sessions: dict[int, ServerSession] = field(default_factory=lambda: {})
2930
session_progress: dict[int, types.ProgressToken | None] = field(
3031
default_factory=lambda: {}
3132
)
33+
keep_alive: int | None = None
34+
keep_alive_start: int | None = None
35+
36+
def is_expired(self):
37+
if self.keep_alive_start is None or self.keep_alive is None:
38+
return False
39+
else:
40+
return int(self.timer()) > self.keep_alive_start + self.keep_alive
3241

3342

3443
class ResultCache:
@@ -54,11 +63,17 @@ class ResultCache:
5463
_session_lookup: dict[int, types.AsyncToken]
5564
_portal: BlockingPortal
5665

57-
def __init__(self, max_size: int, max_keep_alive: int):
66+
def __init__(
67+
self,
68+
max_size: int,
69+
max_keep_alive: int,
70+
timer: Callable[[], float] = time.monotonic,
71+
):
5872
self._max_size = max_size
5973
self._max_keep_alive = max_keep_alive
6074
self._in_progress = {}
6175
self._session_lookup = {}
76+
self._timer = timer
6277
self._portal_provider = BlockingPortalProvider()
6378

6479
async def __aenter__(self):
@@ -105,6 +120,7 @@ async def call_tool():
105120
in_progress.user = user_context.get()
106121
session_id = id(ctx.session)
107122
in_progress.sessions[session_id] = ctx.session
123+
in_progress.keep_alive = timeout
108124
if req.params.meta is not None:
109125
progress_token = req.params.meta.progressToken
110126
else:
@@ -114,7 +130,7 @@ async def call_tool():
114130
in_progress.future = self._portal.start_task_soon(call_tool)
115131
result = types.CallToolAsyncResult(
116132
token=in_progress.token,
117-
recieved=round(time()),
133+
recieved=round(self._timer()),
118134
keepAlive=timeout,
119135
accepted=True,
120136
)
@@ -176,6 +192,15 @@ async def get_result(self, req: types.GetToolAsyncResultRequest):
176192
if in_progress.user == user_context.get():
177193
assert in_progress.future is not None
178194
# TODO add timeout to get async result
195+
if in_progress.is_expired():
196+
self._portal.start_task_soon(self._expire)
197+
return types.CallToolResult(
198+
content=[
199+
types.TextContent(type="text", text="Unknown async token")
200+
],
201+
isError=True,
202+
)
203+
179204
try:
180205
result = in_progress.future.result(1)
181206
logger.debug(f"Found result {result}")
@@ -235,7 +260,12 @@ async def session_close_hook(self, session: ServerSession):
235260
if found is None:
236261
logger.warning("No session found")
237262
if len(in_progress.sessions) == 0:
238-
self._in_progress.pop(dropped, None)
263+
in_progress.keep_alive_start = int(self._timer())
264+
265+
async def _expire(self):
266+
for in_progress in self._in_progress.values():
267+
if in_progress.is_expired():
268+
self._in_progress.pop(in_progress.token, None)
239269
assert in_progress.future is not None
240270
logger.debug("Cancelled in progress future")
241271
in_progress.future.cancel()
@@ -251,6 +281,6 @@ async def _new_in_progress(self) -> InProgress:
251281
# for context
252282
token = str(uuid4())
253283
if token not in self._in_progress:
254-
new_in_progress = InProgress(token)
284+
new_in_progress = InProgress(token, self._timer)
255285
self._in_progress[token] = new_in_progress
256286
return new_in_progress

tests/server/lowlevel/test_result_cache.py

Lines changed: 175 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,6 @@ async def test_call(call: types.CallToolRequest) -> types.ServerResult:
6868
mock_context_2 = Mock()
6969

7070
mock_context_2.session = mock_session_2
71-
mock_session_2.send_progress_notification.result = None
7271

7372
result_cache = ResultCache(max_size=1, max_keep_alive=1)
7473
async with AsyncExitStack() as stack:
@@ -123,3 +122,178 @@ async def test_call(call: types.CallToolRequest) -> types.ServerResult:
123122
message=None,
124123
resource_uri=None,
125124
)
125+
126+
127+
@pytest.mark.anyio
128+
async def test_async_call_keep_alive():
129+
"""Tests async call keep alive"""
130+
131+
async def test_call(call: types.CallToolRequest) -> types.ServerResult:
132+
return types.ServerResult(
133+
types.CallToolResult(content=[types.TextContent(type="text", text="test")])
134+
)
135+
136+
async_call = types.CallToolAsyncRequest(
137+
method="tools/async/call", params=types.CallToolAsyncRequestParams(name="test")
138+
)
139+
140+
mock_session_1 = AsyncMock()
141+
mock_context_1 = Mock()
142+
mock_context_1.session = mock_session_1
143+
144+
mock_session_2 = AsyncMock()
145+
mock_context_2 = Mock()
146+
147+
mock_context_2.session = mock_session_2
148+
149+
result_cache = ResultCache(max_size=1, max_keep_alive=10)
150+
async with AsyncExitStack() as stack:
151+
await stack.enter_async_context(result_cache)
152+
async_call_ref = await result_cache.start_call(
153+
test_call, async_call, mock_context_1
154+
)
155+
assert async_call_ref.token is not None
156+
157+
await result_cache.session_close_hook(mock_session_1)
158+
159+
await result_cache.join_call(
160+
req=types.JoinCallToolAsyncRequest(
161+
method="tools/async/join",
162+
params=types.JoinCallToolRequestParams(
163+
token=async_call_ref.token,
164+
_meta=types.RequestParams.Meta(progressToken="test"),
165+
),
166+
),
167+
ctx=mock_context_2,
168+
)
169+
assert async_call_ref.token is not None
170+
await result_cache.notification_hook(
171+
session=mock_session_1,
172+
notification=types.ServerNotification(
173+
types.ProgressNotification(
174+
method="notifications/progress",
175+
params=types.ProgressNotificationParams(
176+
progressToken="test", progress=1
177+
),
178+
)
179+
),
180+
)
181+
182+
result = await result_cache.get_result(
183+
types.GetToolAsyncResultRequest(
184+
method="tools/async/get",
185+
params=types.GetToolAsyncResultRequestParams(
186+
token=async_call_ref.token
187+
),
188+
)
189+
)
190+
191+
assert not result.isError, str(result)
192+
assert not result.isPending
193+
assert len(result.content) == 1
194+
assert type(result.content[0]) is types.TextContent
195+
assert result.content[0].text == "test"
196+
197+
198+
@pytest.mark.anyio
199+
async def test_async_call_keep_alive_expired():
200+
"""Tests async call keep alive expiry"""
201+
202+
async def test_call(call: types.CallToolRequest) -> types.ServerResult:
203+
return types.ServerResult(
204+
types.CallToolResult(content=[types.TextContent(type="text", text="test")])
205+
)
206+
207+
async_call = types.CallToolAsyncRequest(
208+
method="tools/async/call", params=types.CallToolAsyncRequestParams(name="test")
209+
)
210+
211+
mock_session_1 = AsyncMock()
212+
mock_context_1 = Mock()
213+
mock_context_1.session = mock_session_1
214+
215+
mock_session_2 = AsyncMock()
216+
mock_context_2 = Mock()
217+
mock_context_2.session = mock_session_2
218+
219+
mock_session_3 = AsyncMock()
220+
mock_context_3 = Mock()
221+
mock_context_3.session = mock_session_3
222+
223+
time = 0.0
224+
225+
def test_timer():
226+
return time
227+
228+
result_cache = ResultCache(max_size=1, max_keep_alive=1, timer=test_timer)
229+
async with AsyncExitStack() as stack:
230+
await stack.enter_async_context(result_cache)
231+
async_call_ref = await result_cache.start_call(
232+
test_call, async_call, mock_context_1
233+
)
234+
assert async_call_ref.token is not None
235+
236+
# lose the connection
237+
await result_cache.session_close_hook(mock_session_1)
238+
239+
# reconnect before keep_alive_timeout
240+
time = 0.5
241+
await result_cache.join_call(
242+
req=types.JoinCallToolAsyncRequest(
243+
method="tools/async/join",
244+
params=types.JoinCallToolRequestParams(
245+
token=async_call_ref.token,
246+
_meta=types.RequestParams.Meta(progressToken="test"),
247+
),
248+
),
249+
ctx=mock_context_2,
250+
)
251+
252+
result = await result_cache.get_result(
253+
types.GetToolAsyncResultRequest(
254+
method="tools/async/get",
255+
params=types.GetToolAsyncResultRequestParams(
256+
token=async_call_ref.token
257+
),
258+
)
259+
)
260+
261+
# should successfully read data
262+
assert not result.isError, str(result)
263+
assert len(result.content) == 1
264+
assert type(result.content[0]) is types.TextContent
265+
assert result.content[0].text == "test"
266+
267+
# lose connection a second time
268+
269+
await result_cache.session_close_hook(mock_session_2)
270+
271+
time = 2
272+
273+
# reconnect after the keep_alive_timeout
274+
275+
await result_cache.join_call(
276+
req=types.JoinCallToolAsyncRequest(
277+
method="tools/async/join",
278+
params=types.JoinCallToolRequestParams(
279+
token=async_call_ref.token,
280+
_meta=types.RequestParams.Meta(progressToken="test"),
281+
),
282+
),
283+
ctx=mock_context_3,
284+
)
285+
286+
result = await result_cache.get_result(
287+
types.GetToolAsyncResultRequest(
288+
method="tools/async/get",
289+
params=types.GetToolAsyncResultRequestParams(
290+
token=async_call_ref.token
291+
),
292+
)
293+
)
294+
295+
# now token should be expired
296+
assert result.isError, str(result)
297+
assert len(result.content) == 1
298+
assert type(result.content[0]) is types.TextContent
299+
assert result.content[0].text == "Unknown async token"

0 commit comments

Comments
 (0)