77import httpx
88from anyio .abc import TaskStatus
99from anyio .streams .memory import MemoryObjectReceiveStream , MemoryObjectSendStream
10- from exceptiongroup import BaseExceptionGroup , catch
1110from httpx_sse import aconnect_sse
1211
1312import mcp .types as types
@@ -20,11 +19,6 @@ def remove_request_params(url: str) -> str:
2019 return urljoin (url , urlparse (url ).path )
2120
2221
23- def handle_exception (exc : BaseExceptionGroup [Exception ]) -> str :
24- """Handle ExceptionGroup and Exceptions for Client transport for SSE"""
25- messages = "; " .join (str (e ) for e in exc .exceptions )
26- raise Exception (f"TaskGroup failed with: { messages } " ) from None
27-
2822@asynccontextmanager
2923async def sse_client (
3024 url : str ,
@@ -47,117 +41,114 @@ async def sse_client(
4741 read_stream_writer , read_stream = anyio .create_memory_object_stream (0 )
4842 write_stream , write_stream_reader = anyio .create_memory_object_stream (0 )
4943
50- with catch ({Exception : handle_exception }):
51- async with anyio .create_task_group () as tg :
52- try :
53- logger .info (f"Connecting to SSE endpoint: { remove_request_params (url )} " )
54- async with httpx .AsyncClient (headers = headers ) as client :
55- async with aconnect_sse (
56- client ,
57- "GET" ,
58- url ,
59- timeout = httpx .Timeout (timeout , read = sse_read_timeout ),
60- ) as event_source :
61- event_source .response .raise_for_status ()
62- logger .debug ("SSE connection established" )
63-
64- async def sse_reader (
65- task_status : TaskStatus [str ] = anyio .TASK_STATUS_IGNORED ,
66- ):
67- try :
68- async for sse in event_source .aiter_sse ():
69- logger .debug (f"Received SSE event: { sse .event } " )
70- match sse .event :
71- case "endpoint" :
72- endpoint_url = urljoin (url , sse .data )
73- logger .info (
74- f"Received endpoint URL: { endpoint_url } "
44+ async with anyio .create_task_group () as tg :
45+ try :
46+ logger .info (f"Connecting to SSE endpoint: { remove_request_params (url )} " )
47+ async with httpx .AsyncClient (headers = headers ) as client :
48+ async with aconnect_sse (
49+ client ,
50+ "GET" ,
51+ url ,
52+ timeout = httpx .Timeout (timeout , read = sse_read_timeout ),
53+ ) as event_source :
54+ event_source .response .raise_for_status ()
55+ logger .debug ("SSE connection established" )
56+
57+ async def sse_reader (
58+ task_status : TaskStatus [str ] = anyio .TASK_STATUS_IGNORED ,
59+ ):
60+ try :
61+ async for sse in event_source .aiter_sse ():
62+ logger .debug (f"Received SSE event: { sse .event } " )
63+ match sse .event :
64+ case "endpoint" :
65+ endpoint_url = urljoin (url , sse .data )
66+ logger .info (
67+ f"Received endpoint URL: { endpoint_url } "
68+ )
69+
70+ url_parsed = urlparse (url )
71+ endpoint_parsed = urlparse (endpoint_url )
72+ if (
73+ url_parsed .netloc != endpoint_parsed .netloc
74+ or url_parsed .scheme
75+ != endpoint_parsed .scheme
76+ ):
77+ error_msg = (
78+ "Endpoint origin does not match "
79+ f"connection origin: { endpoint_url } "
7580 )
81+ logger .error (error_msg )
82+ raise ValueError (error_msg )
83+
84+ task_status .started (endpoint_url )
7685
77- url_parsed = urlparse (url )
78- endpoint_parsed = urlparse (endpoint_url )
79- if (
80- url_parsed .netloc
81- != endpoint_parsed .netloc
82- or url_parsed .scheme
83- != endpoint_parsed .scheme
84- ):
85- error_msg = (
86- "Endpoint origin does not match "
87- f"connection origin: { endpoint_url } "
88- )
89- logger .error (error_msg )
90- raise ValueError (error_msg )
91-
92- task_status .started (endpoint_url )
93-
94- case "message" :
95- try :
96- message = types .JSONRPCMessage .model_validate_json ( # noqa: E501
97- sse .data
98- )
99- logger .debug (
100- "Received server message: "
101- f"{ message } "
102-
103- )
104- except Exception as exc :
105- logger .error (
106- "Error parsing server message: "
107- f"{ exc } "
108- )
109- await read_stream_writer .send (exc )
110- continue
111-
112- session_message = SessionMessage (message )
113- await read_stream_writer .send (
114- session_message
86+ case "message" :
87+ try :
88+ message = types .JSONRPCMessage .model_validate_json ( # noqa: E501
89+ sse .data
11590 )
116- case _:
117- logger .warning (
118- f"Unknown SSE event: { sse .event } "
91+ logger .debug (
92+ f"Received server message: { message } "
11993 )
120- except Exception as exc :
121- logger .error (f"Error in sse_reader: { exc } " )
122- await read_stream_writer .send (exc )
123- finally :
124- await read_stream_writer .aclose ()
125-
126- async def post_writer (endpoint_url : str ):
127- try :
128- async with write_stream_reader :
129- async for session_message in write_stream_reader :
130- logger .debug (
131- f"Sending client message: { session_message } "
132- )
133- response = await client .post (
134- endpoint_url ,
135- json = session_message .message .model_dump (
136- by_alias = True ,
137- mode = "json" ,
138- exclude_none = True ,
139- ),
94+ except Exception as exc :
95+ logger .error (
96+ f"Error parsing server message: { exc } "
97+ )
98+ await read_stream_writer .send (exc )
99+ continue
100+
101+ session_message = SessionMessage (
102+ message = message
140103 )
141- response . raise_for_status ( )
142- logger . debug (
143- "Client message sent successfully: "
144- f"{ response . status_code } "
104+ await read_stream_writer . send ( session_message )
105+ case _:
106+ logger . warning (
107+ f"Unknown SSE event: { sse . event } "
145108 )
146- except Exception as exc :
147- logger .error (f"Error in post_writer: { exc } " )
148- finally :
149- await write_stream .aclose ()
150-
151- endpoint_url = await tg .start (sse_reader )
152- logger .info (
153- f"Starting post writer with endpoint URL: { endpoint_url } "
154- )
155- tg .start_soon (post_writer , endpoint_url )
109+ except Exception as exc :
110+ logger .error (f"Error in sse_reader: { exc } " )
111+ await read_stream_writer .send (exc )
112+ finally :
113+ await read_stream_writer .aclose ()
156114
115+ async def post_writer (endpoint_url : str ):
157116 try :
158- yield read_stream , write_stream
117+ async with write_stream_reader :
118+ async for session_message in write_stream_reader :
119+ logger .debug (
120+ f"Sending client message: { session_message } "
121+ )
122+ response = await client .post (
123+ endpoint_url ,
124+ json = session_message .message .model_dump (
125+ by_alias = True ,
126+ mode = "json" ,
127+ exclude_none = True ,
128+ ),
129+ )
130+ response .raise_for_status ()
131+ logger .debug (
132+ "Client message sent successfully: "
133+ f"{ response .status_code } "
134+ )
135+ except Exception as exc :
136+ logger .error (f"Error in post_writer: { exc } " )
159137 finally :
160- tg .cancel_scope .cancel ()
161- finally :
162- await read_stream_writer .aclose ()
163- await write_stream .aclose ()
138+ await write_stream .aclose ()
139+
140+ endpoint_url = await tg .start (sse_reader )
141+ logger .info (
142+ f"Starting post writer with endpoint URL: { endpoint_url } "
143+ )
144+ tg .start_soon (post_writer , endpoint_url )
145+
146+ try :
147+ yield read_stream , write_stream
148+ finally :
149+ tg .cancel_scope .cancel ()
150+ finally :
151+ await read_stream_writer .aclose ()
152+ await write_stream .aclose ()
153+ await read_stream .aclose ()
154+ await write_stream_reader .aclose ()
0 commit comments