@@ -77,6 +77,7 @@ async def main():
7777
7878import anyio
7979import jsonschema
80+ from anyio .abc import TaskGroup
8081from anyio .streams .memory import MemoryObjectReceiveStream , MemoryObjectSendStream
8182from pydantic import AnyUrl
8283from typing_extensions import TypeVar
@@ -252,7 +253,7 @@ def decorator(
252253
253254 wrapper = create_call_wrapper (func , types .ListPromptsRequest )
254255
255- async def handler (req : types .ListPromptsRequest ):
256+ async def handler (req : types .ListPromptsRequest , _ : Any = None ):
256257 result = await wrapper (req )
257258 # Handle both old style (list[Prompt]) and new style (ListPromptsResult)
258259 if isinstance (result , types .ListPromptsResult ):
@@ -272,7 +273,7 @@ def decorator(
272273 ):
273274 logger .debug ("Registering handler for GetPromptRequest" )
274275
275- async def handler (req : types .GetPromptRequest ):
276+ async def handler (req : types .GetPromptRequest , _ : Any = None ):
276277 prompt_get = await func (req .params .name , req .params .arguments )
277278 return types .ServerResult (prompt_get )
278279
@@ -290,7 +291,7 @@ def decorator(
290291
291292 wrapper = create_call_wrapper (func , types .ListResourcesRequest )
292293
293- async def handler (req : types .ListResourcesRequest ):
294+ async def handler (req : types .ListResourcesRequest , _ : Any = None ):
294295 result = await wrapper (req )
295296 # Handle both old style (list[Resource]) and new style (ListResourcesResult)
296297 if isinstance (result , types .ListResourcesResult ):
@@ -308,7 +309,7 @@ def list_resource_templates(self):
308309 def decorator (func : Callable [[], Awaitable [list [types .ResourceTemplate ]]]):
309310 logger .debug ("Registering handler for ListResourceTemplatesRequest" )
310311
311- async def handler (_ : Any ):
312+ async def handler (_1 : Any , _2 : Any = None ):
312313 templates = await func ()
313314 return types .ServerResult (types .ListResourceTemplatesResult (resourceTemplates = templates ))
314315
@@ -323,7 +324,7 @@ def decorator(
323324 ):
324325 logger .debug ("Registering handler for ReadResourceRequest" )
325326
326- async def handler (req : types .ReadResourceRequest ):
327+ async def handler (req : types .ReadResourceRequest , _ : Any = None ):
327328 result = await func (req .params .uri )
328329
329330 def create_content (data : str | bytes , mime_type : str | None ):
@@ -379,7 +380,7 @@ def set_logging_level(self):
379380 def decorator (func : Callable [[types .LoggingLevel ], Awaitable [None ]]):
380381 logger .debug ("Registering handler for SetLevelRequest" )
381382
382- async def handler (req : types .SetLevelRequest ):
383+ async def handler (req : types .SetLevelRequest , _ : Any = None ):
383384 await func (req .params .level )
384385 return types .ServerResult (types .EmptyResult ())
385386
@@ -392,7 +393,7 @@ def subscribe_resource(self):
392393 def decorator (func : Callable [[AnyUrl ], Awaitable [None ]]):
393394 logger .debug ("Registering handler for SubscribeRequest" )
394395
395- async def handler (req : types .SubscribeRequest ):
396+ async def handler (req : types .SubscribeRequest , _ : Any = None ):
396397 await func (req .params .uri )
397398 return types .ServerResult (types .EmptyResult ())
398399
@@ -405,7 +406,7 @@ def unsubscribe_resource(self):
405406 def decorator (func : Callable [[AnyUrl ], Awaitable [None ]]):
406407 logger .debug ("Registering handler for UnsubscribeRequest" )
407408
408- async def handler (req : types .UnsubscribeRequest ):
409+ async def handler (req : types .UnsubscribeRequest , _ : Any = None ):
409410 await func (req .params .uri )
410411 return types .ServerResult (types .EmptyResult ())
411412
@@ -423,7 +424,7 @@ def decorator(
423424
424425 wrapper = create_call_wrapper (func , types .ListToolsRequest )
425426
426- async def handler (req : types .ListToolsRequest ):
427+ async def handler (req : types .ListToolsRequest , _ : Any = None ):
427428 result = await wrapper (req )
428429
429430 # Handle both old style (list[Tool]) and new style (ListToolsResult)
@@ -493,7 +494,7 @@ def decorator(
493494 ):
494495 logger .debug ("Registering handler for CallToolRequest" )
495496
496- async def handler (req : types .CallToolRequest ):
497+ async def handler (req : types .CallToolRequest , server_scope : TaskGroup ):
497498 try :
498499 tool_name = req .params .name
499500 arguments = req .params .arguments or {}
@@ -563,20 +564,20 @@ async def execute_async():
563564 logger .exception (f"Async execution failed for { tool_name } " )
564565 self .async_operations .fail_operation (operation .token , str (e ))
565566
566- async with anyio .create_task_group () as tg :
567- tg .start_soon (execute_async )
568-
569- # Return operation result with immediate content
570- logger .info (f"Returning async operation result for { tool_name } " )
571- return types .ServerResult (
572- types .CallToolResult (
573- content = immediate_content ,
574- operation = types .AsyncResultProperties (
575- token = operation .token ,
576- keepAlive = operation .keep_alive ,
577- ),
578- )
567+ # Dispatch in concurrency scope of the server to run between requests
568+ server_scope .start_soon (execute_async )
569+
570+ # Return operation result with immediate content
571+ logger .info (f"Returning async operation result for { tool_name } " )
572+ return types .ServerResult (
573+ types .CallToolResult (
574+ content = immediate_content ,
575+ operation = types .AsyncResultProperties (
576+ token = operation .token ,
577+ keepAlive = operation .keep_alive ,
578+ ),
579579 )
580+ )
580581
581582 # tool call
582583 results = await func (tool_name , arguments )
@@ -690,7 +691,7 @@ def decorator(
690691 ):
691692 logger .debug ("Registering handler for ProgressNotification" )
692693
693- async def handler (req : types .ProgressNotification ):
694+ async def handler (req : types .ProgressNotification , _ : Any = None ):
694695 await func (
695696 req .params .progressToken ,
696697 req .params .progress ,
@@ -718,7 +719,7 @@ def decorator(
718719 ):
719720 logger .debug ("Registering handler for CompleteRequest" )
720721
721- async def handler (req : types .CompleteRequest ):
722+ async def handler (req : types .CompleteRequest , _ : Any = None ):
722723 completion = await func (req .params .ref , req .params .argument , req .params .context )
723724 return types .ServerResult (
724725 types .CompleteResult (
@@ -754,7 +755,7 @@ def get_operation_status(self):
754755 def decorator (func : Callable [[str ], Awaitable [types .GetOperationStatusResult ]]):
755756 logger .debug ("Registering handler for GetOperationStatusRequest" )
756757
757- async def handler (req : types .GetOperationStatusRequest ):
758+ async def handler (req : types .GetOperationStatusRequest , _ : Any = None ):
758759 # Validate token and get operation
759760 operation = self ._validate_operation_token (req .params .token )
760761
@@ -776,7 +777,7 @@ def get_operation_result(self):
776777 def decorator (func : Callable [[str ], Awaitable [types .GetOperationPayloadResult ]]):
777778 logger .debug ("Registering handler for GetOperationPayloadRequest" )
778779
779- async def handler (req : types .GetOperationPayloadRequest ):
780+ async def handler (req : types .GetOperationPayloadRequest , _ : Any = None ):
780781 # Validate token and get operation
781782 operation = self ._validate_operation_token (req .params .token )
782783
@@ -878,6 +879,7 @@ async def run(
878879 session ,
879880 lifespan_context ,
880881 raise_exceptions ,
882+ tg ,
881883 )
882884 finally :
883885 # Cancel session operations and stop cleanup task
@@ -892,13 +894,16 @@ async def _handle_message(
892894 session : ServerSession ,
893895 lifespan_context : LifespanResultT ,
894896 raise_exceptions : bool = False ,
897+ server_scope : TaskGroup | None = None ,
895898 ):
896899 with warnings .catch_warnings (record = True ) as w :
897900 # TODO(Marcelo): We should be checking if message is Exception here.
898901 match message : # type: ignore[reportMatchNotExhaustive]
899902 case RequestResponder (request = types .ClientRequest (root = req )) as responder :
900903 with responder :
901- await self ._handle_request (message , req , session , lifespan_context , raise_exceptions )
904+ await self ._handle_request (
905+ message , req , session , lifespan_context , raise_exceptions , server_scope
906+ )
902907 case types .ClientNotification (root = notify ):
903908 await self ._handle_notification (notify )
904909
@@ -912,6 +917,7 @@ async def _handle_request(
912917 session : ServerSession ,
913918 lifespan_context : LifespanResultT ,
914919 raise_exceptions : bool ,
920+ server_scope : TaskGroup | None = None ,
915921 ):
916922 logger .info ("Processing request of type %s" , type (req ).__name__ )
917923 if handler := self .request_handlers .get (type (req )): # type: ignore
@@ -936,7 +942,7 @@ async def _handle_request(
936942 request = request_data ,
937943 )
938944 )
939- response = await handler (req )
945+ response = await handler (req , server_scope )
940946
941947 # Track async operations for cancellation
942948 if isinstance (req , types .CallToolRequest ):
@@ -985,5 +991,5 @@ async def _handle_notification(self, notify: Any):
985991 logger .exception ("Uncaught exception in notification handler" )
986992
987993
988- async def _ping_handler (request : types .PingRequest ) -> types .ServerResult :
994+ async def _ping_handler (request : types .PingRequest , _ : Any = None ) -> types .ServerResult :
989995 return types .ServerResult (types .EmptyResult ())
0 commit comments