77import anyio
88from pydantic import BaseModel
99
10+ from mcp .shared .exceptions import McpError
1011from mcp .shared .task import is_terminal
11- from mcp .types import GetTaskResult
12+ from mcp .types import INVALID_PARAMS , GetTaskResult
1213
1314if TYPE_CHECKING :
1415 from mcp .shared .session import BaseSession
@@ -110,9 +111,15 @@ async def wrapper(task: Callable[[], Awaitable[ReceiveResultT]]):
110111
111112 async def _wait_for_result_task () -> ReceiveResultT :
112113 assert self .task_id
113- return await self ._task_handler (self .task_id , on_task_created , on_task_status )
114+ return await self ._task_handler (self .task_id , on_task_status )
115+
116+ async def _wait_for_task_creation () -> None :
117+ # Wait for task creation notification
118+ await self .task_created_handle
119+ await on_task_created ()
114120
115121 async with anyio .create_task_group () as tg :
122+ tg .start_soon (_wait_for_task_creation )
116123 tg .start_soon (wrapper , _wait_for_result_task )
117124 tg .start_soon (wrapper , self ._wait_for_result )
118125
@@ -141,15 +148,13 @@ async def _wait_for_result(self) -> ReceiveResultT:
141148 async def _task_handler (
142149 self ,
143150 task_id : str ,
144- on_task_created : Callable [[], Awaitable [None ]],
145151 on_task_status : Callable [[GetTaskResult ], Awaitable [None ]],
146152 ) -> ReceiveResultT :
147153 """
148154 Encapsulate polling for a result, calling on_task_status after querying the task.
149155
150156 Args:
151157 task_id: The task ID to poll
152- on_task_created: Callback invoked when task is created
153158 on_task_status: Callback invoked on each status poll
154159
155160 Returns:
@@ -158,14 +163,16 @@ async def _task_handler(
158163 Raises:
159164 Exception: If task polling or result retrieval fails
160165 """
161- # Wait for task creation notification
162- await self .task_created_handle
163- await on_task_created ()
164-
165166 # Poll for completion
166167 task : GetTaskResult
167168 while True :
168- task = await self .session .get_task (task_id )
169+ try :
170+ task = await self .session .get_task (task_id )
171+ except McpError as e :
172+ if e .error .code == INVALID_PARAMS :
173+ # Task may not exist yet
174+ continue
175+ raise
169176 await on_task_status (task )
170177
171178 if is_terminal (task .status ):
0 commit comments