2323from mcp .shared .message import ClientMessageMetadata , SessionMessage
2424from mcp .types import (
2525 ErrorData ,
26+ InitializeResult ,
2627 JSONRPCError ,
2728 JSONRPCMessage ,
2829 JSONRPCNotification ,
4041GetSessionIdCallback = Callable [[], str | None ]
4142
4243MCP_SESSION_ID = "mcp-session-id"
44+ MCP_PROTOCOL_VERSION = "mcp-protocol-version"
4345LAST_EVENT_ID = "last-event-id"
4446CONTENT_TYPE = "content-type"
4547ACCEPT = "Accept"
@@ -98,17 +100,20 @@ def __init__(
98100 )
99101 self .auth = auth
100102 self .session_id = None
103+ self .protocol_version = None
101104 self .request_headers = {
102105 ACCEPT : f"{ JSON } , { SSE } " ,
103106 CONTENT_TYPE : JSON ,
104107 ** self .headers ,
105108 }
106109
107- def _update_headers_with_session (self , base_headers : dict [str , str ]) -> dict [str , str ]:
108- """Update headers with session ID if available."""
110+ def _prepare_request_headers (self , base_headers : dict [str , str ]) -> dict [str , str ]:
111+ """Update headers with session ID and protocol version if available."""
109112 headers = base_headers .copy ()
110113 if self .session_id :
111114 headers [MCP_SESSION_ID ] = self .session_id
115+ if self .protocol_version :
116+ headers [MCP_PROTOCOL_VERSION ] = self .protocol_version
112117 return headers
113118
114119 def _is_initialization_request (self , message : JSONRPCMessage ) -> bool :
@@ -129,19 +134,39 @@ def _maybe_extract_session_id_from_response(
129134 self .session_id = new_session_id
130135 logger .info (f"Received session ID: { self .session_id } " )
131136
137+ def _maybe_extract_protocol_version_from_message (
138+ self ,
139+ message : JSONRPCMessage ,
140+ ) -> None :
141+ """Extract protocol version from initialization response message."""
142+ if isinstance (message .root , JSONRPCResponse ) and message .root .result :
143+ try :
144+ # Parse the result as InitializeResult for type safety
145+ init_result = InitializeResult .model_validate (message .root .result )
146+ self .protocol_version = str (init_result .protocolVersion )
147+ logger .info (f"Negotiated protocol version: { self .protocol_version } " )
148+ except Exception as exc :
149+ logger .warning (f"Failed to parse initialization response as InitializeResult: { exc } " )
150+ logger .warning (f"Raw result: { message .root .result } " )
151+
132152 async def _handle_sse_event (
133153 self ,
134154 sse : ServerSentEvent ,
135155 read_stream_writer : StreamWriter ,
136156 original_request_id : RequestId | None = None ,
137157 resumption_callback : Callable [[str ], Awaitable [None ]] | None = None ,
158+ is_initialization : bool = False ,
138159 ) -> bool :
139160 """Handle an SSE event, returning True if the response is complete."""
140161 if sse .event == "message" :
141162 try :
142163 message = JSONRPCMessage .model_validate_json (sse .data )
143164 logger .debug (f"SSE message: { message } " )
144165
166+ # Extract protocol version from initialization response
167+ if is_initialization :
168+ self ._maybe_extract_protocol_version_from_message (message )
169+
145170 # If this is a response and we have original_request_id, replace it
146171 if original_request_id is not None and isinstance (message .root , JSONRPCResponse | JSONRPCError ):
147172 message .root .id = original_request_id
@@ -175,7 +200,7 @@ async def handle_get_stream(
175200 if not self .session_id :
176201 return
177202
178- headers = self ._update_headers_with_session (self .request_headers )
203+ headers = self ._prepare_request_headers (self .request_headers )
179204
180205 async with aconnect_sse (
181206 client ,
@@ -195,7 +220,7 @@ async def handle_get_stream(
195220
196221 async def _handle_resumption_request (self , ctx : RequestContext ) -> None :
197222 """Handle a resumption request using GET with SSE."""
198- headers = self ._update_headers_with_session (ctx .headers )
223+ headers = self ._prepare_request_headers (ctx .headers )
199224 if ctx .metadata and ctx .metadata .resumption_token :
200225 headers [LAST_EVENT_ID ] = ctx .metadata .resumption_token
201226 else :
@@ -228,7 +253,7 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None:
228253
229254 async def _handle_post_request (self , ctx : RequestContext ) -> None :
230255 """Handle a POST request with response processing."""
231- headers = self ._update_headers_with_session (ctx .headers )
256+ headers = self ._prepare_request_headers (ctx .headers )
232257 message = ctx .session_message .message
233258 is_initialization = self ._is_initialization_request (message )
234259
@@ -257,9 +282,9 @@ async def _handle_post_request(self, ctx: RequestContext) -> None:
257282 content_type = response .headers .get (CONTENT_TYPE , "" ).lower ()
258283
259284 if content_type .startswith (JSON ):
260- await self ._handle_json_response (response , ctx .read_stream_writer )
285+ await self ._handle_json_response (response , ctx .read_stream_writer , is_initialization )
261286 elif content_type .startswith (SSE ):
262- await self ._handle_sse_response (response , ctx )
287+ await self ._handle_sse_response (response , ctx , is_initialization )
263288 else :
264289 await self ._handle_unexpected_content_type (
265290 content_type ,
@@ -270,18 +295,29 @@ async def _handle_json_response(
270295 self ,
271296 response : httpx .Response ,
272297 read_stream_writer : StreamWriter ,
298+ is_initialization : bool = False ,
273299 ) -> None :
274300 """Handle JSON response from the server."""
275301 try :
276302 content = await response .aread ()
277303 message = JSONRPCMessage .model_validate_json (content )
304+
305+ # Extract protocol version from initialization response
306+ if is_initialization :
307+ self ._maybe_extract_protocol_version_from_message (message )
308+
278309 session_message = SessionMessage (message )
279310 await read_stream_writer .send (session_message )
280311 except Exception as exc :
281312 logger .error (f"Error parsing JSON response: { exc } " )
282313 await read_stream_writer .send (exc )
283314
284- async def _handle_sse_response (self , response : httpx .Response , ctx : RequestContext ) -> None :
315+ async def _handle_sse_response (
316+ self ,
317+ response : httpx .Response ,
318+ ctx : RequestContext ,
319+ is_initialization : bool = False ,
320+ ) -> None :
285321 """Handle SSE response from the server."""
286322 try :
287323 event_source = EventSource (response )
@@ -292,6 +328,7 @@ async def _handle_sse_response(self, response: httpx.Response, ctx: RequestConte
292328 sse ,
293329 ctx .read_stream_writer ,
294330 resumption_callback = (ctx .metadata .on_resumption_token_update if ctx .metadata else None ),
331+ is_initialization = is_initialization ,
295332 )
296333 # If the SSE event indicates completion, like returning respose/error
297334 # break the loop
@@ -388,7 +425,7 @@ async def terminate_session(self, client: httpx.AsyncClient) -> None:
388425 return
389426
390427 try :
391- headers = self ._update_headers_with_session (self .request_headers )
428+ headers = self ._prepare_request_headers (self .request_headers )
392429 response = await client .delete (self .url , headers = headers )
393430
394431 if response .status_code == 405 :
0 commit comments