Skip to content

Commit 88a7f26

Browse files
committed
Allow task polling before creation notification arrives
1 parent 340fcad commit 88a7f26

File tree

1 file changed

+16
-9
lines changed

1 file changed

+16
-9
lines changed

src/mcp/shared/request.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,9 @@
77
import anyio
88
from pydantic import BaseModel
99

10+
from mcp.shared.exceptions import McpError
1011
from mcp.shared.task import is_terminal
11-
from mcp.types import GetTaskResult
12+
from mcp.types import INVALID_PARAMS, GetTaskResult
1213

1314
if TYPE_CHECKING:
1415
from mcp.shared.session import BaseSession
@@ -110,9 +111,15 @@ async def wrapper(task: Callable[[], Awaitable[ReceiveResultT]]):
110111

111112
async def _wait_for_result_task() -> ReceiveResultT:
112113
assert self.task_id
113-
return await self._task_handler(self.task_id, on_task_created, on_task_status)
114+
return await self._task_handler(self.task_id, on_task_status)
115+
116+
async def _wait_for_task_creation() -> None:
117+
# Wait for task creation notification
118+
await self.task_created_handle
119+
await on_task_created()
114120

115121
async with anyio.create_task_group() as tg:
122+
tg.start_soon(_wait_for_task_creation)
116123
tg.start_soon(wrapper, _wait_for_result_task)
117124
tg.start_soon(wrapper, self._wait_for_result)
118125

@@ -141,15 +148,13 @@ async def _wait_for_result(self) -> ReceiveResultT:
141148
async def _task_handler(
142149
self,
143150
task_id: str,
144-
on_task_created: Callable[[], Awaitable[None]],
145151
on_task_status: Callable[[GetTaskResult], Awaitable[None]],
146152
) -> ReceiveResultT:
147153
"""
148154
Encapsulate polling for a result, calling on_task_status after querying the task.
149155
150156
Args:
151157
task_id: The task ID to poll
152-
on_task_created: Callback invoked when task is created
153158
on_task_status: Callback invoked on each status poll
154159
155160
Returns:
@@ -158,14 +163,16 @@ async def _task_handler(
158163
Raises:
159164
Exception: If task polling or result retrieval fails
160165
"""
161-
# Wait for task creation notification
162-
await self.task_created_handle
163-
await on_task_created()
164-
165166
# Poll for completion
166167
task: GetTaskResult
167168
while True:
168-
task = await self.session.get_task(task_id)
169+
try:
170+
task = await self.session.get_task(task_id)
171+
except McpError as e:
172+
if e.error.code == INVALID_PARAMS:
173+
# Task may not exist yet
174+
continue
175+
raise
169176
await on_task_status(task)
170177

171178
if is_terminal(task.status):

0 commit comments

Comments
 (0)