Skip to content

Commit 5fec84c

Browse files
committed
Merge branch 'main' into trio-support
2 parents 0ce5b79 + 1bb16fa commit 5fec84c

File tree

7 files changed

+268
-46
lines changed

7 files changed

+268
-46
lines changed

.github/workflows/check-lock.yml

Lines changed: 0 additions & 25 deletions
This file was deleted.

src/mcp/client/auth.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,13 @@
1717
import anyio
1818
import httpx
1919

20-
from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthMetadata, OAuthToken
20+
from mcp.client.streamable_http import MCP_PROTOCOL_VERSION
21+
from mcp.shared.auth import (
22+
OAuthClientInformationFull,
23+
OAuthClientMetadata,
24+
OAuthMetadata,
25+
OAuthToken,
26+
)
2127
from mcp.types import LATEST_PROTOCOL_VERSION
2228

2329
logger = logging.getLogger(__name__)
@@ -121,7 +127,7 @@ async def _discover_oauth_metadata(self, server_url: str) -> OAuthMetadata | Non
121127
# Extract base URL per MCP spec
122128
auth_base_url = self._get_authorization_base_url(server_url)
123129
url = urljoin(auth_base_url, "/.well-known/oauth-authorization-server")
124-
headers = {"MCP-Protocol-Version": LATEST_PROTOCOL_VERSION}
130+
headers = {MCP_PROTOCOL_VERSION: LATEST_PROTOCOL_VERSION}
125131

126132
async with httpx.AsyncClient() as client:
127133
try:

src/mcp/client/streamable_http.py

Lines changed: 46 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from mcp.shared.message import ClientMessageMetadata, SessionMessage
2424
from mcp.types import (
2525
ErrorData,
26+
InitializeResult,
2627
JSONRPCError,
2728
JSONRPCMessage,
2829
JSONRPCNotification,
@@ -40,6 +41,7 @@
4041
GetSessionIdCallback = Callable[[], str | None]
4142

4243
MCP_SESSION_ID = "mcp-session-id"
44+
MCP_PROTOCOL_VERSION = "mcp-protocol-version"
4345
LAST_EVENT_ID = "last-event-id"
4446
CONTENT_TYPE = "content-type"
4547
ACCEPT = "Accept"
@@ -98,17 +100,20 @@ def __init__(
98100
)
99101
self.auth = auth
100102
self.session_id = None
103+
self.protocol_version = None
101104
self.request_headers = {
102105
ACCEPT: f"{JSON}, {SSE}",
103106
CONTENT_TYPE: JSON,
104107
**self.headers,
105108
}
106109

107-
def _update_headers_with_session(self, base_headers: dict[str, str]) -> dict[str, str]:
108-
"""Update headers with session ID if available."""
110+
def _prepare_request_headers(self, base_headers: dict[str, str]) -> dict[str, str]:
111+
"""Update headers with session ID and protocol version if available."""
109112
headers = base_headers.copy()
110113
if self.session_id:
111114
headers[MCP_SESSION_ID] = self.session_id
115+
if self.protocol_version:
116+
headers[MCP_PROTOCOL_VERSION] = self.protocol_version
112117
return headers
113118

114119
def _is_initialization_request(self, message: JSONRPCMessage) -> bool:
@@ -129,19 +134,39 @@ def _maybe_extract_session_id_from_response(
129134
self.session_id = new_session_id
130135
logger.info(f"Received session ID: {self.session_id}")
131136

137+
def _maybe_extract_protocol_version_from_message(
138+
self,
139+
message: JSONRPCMessage,
140+
) -> None:
141+
"""Extract protocol version from initialization response message."""
142+
if isinstance(message.root, JSONRPCResponse) and message.root.result:
143+
try:
144+
# Parse the result as InitializeResult for type safety
145+
init_result = InitializeResult.model_validate(message.root.result)
146+
self.protocol_version = str(init_result.protocolVersion)
147+
logger.info(f"Negotiated protocol version: {self.protocol_version}")
148+
except Exception as exc:
149+
logger.warning(f"Failed to parse initialization response as InitializeResult: {exc}")
150+
logger.warning(f"Raw result: {message.root.result}")
151+
132152
async def _handle_sse_event(
133153
self,
134154
sse: ServerSentEvent,
135155
read_stream_writer: StreamWriter,
136156
original_request_id: RequestId | None = None,
137157
resumption_callback: Callable[[str], Awaitable[None]] | None = None,
158+
is_initialization: bool = False,
138159
) -> bool:
139160
"""Handle an SSE event, returning True if the response is complete."""
140161
if sse.event == "message":
141162
try:
142163
message = JSONRPCMessage.model_validate_json(sse.data)
143164
logger.debug(f"SSE message: {message}")
144165

166+
# Extract protocol version from initialization response
167+
if is_initialization:
168+
self._maybe_extract_protocol_version_from_message(message)
169+
145170
# If this is a response and we have original_request_id, replace it
146171
if original_request_id is not None and isinstance(message.root, JSONRPCResponse | JSONRPCError):
147172
message.root.id = original_request_id
@@ -175,7 +200,7 @@ async def handle_get_stream(
175200
if not self.session_id:
176201
return
177202

178-
headers = self._update_headers_with_session(self.request_headers)
203+
headers = self._prepare_request_headers(self.request_headers)
179204

180205
async with aconnect_sse(
181206
client,
@@ -195,7 +220,7 @@ async def handle_get_stream(
195220

196221
async def _handle_resumption_request(self, ctx: RequestContext) -> None:
197222
"""Handle a resumption request using GET with SSE."""
198-
headers = self._update_headers_with_session(ctx.headers)
223+
headers = self._prepare_request_headers(ctx.headers)
199224
if ctx.metadata and ctx.metadata.resumption_token:
200225
headers[LAST_EVENT_ID] = ctx.metadata.resumption_token
201226
else:
@@ -228,7 +253,7 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None:
228253

229254
async def _handle_post_request(self, ctx: RequestContext) -> None:
230255
"""Handle a POST request with response processing."""
231-
headers = self._update_headers_with_session(ctx.headers)
256+
headers = self._prepare_request_headers(ctx.headers)
232257
message = ctx.session_message.message
233258
is_initialization = self._is_initialization_request(message)
234259

@@ -257,9 +282,9 @@ async def _handle_post_request(self, ctx: RequestContext) -> None:
257282
content_type = response.headers.get(CONTENT_TYPE, "").lower()
258283

259284
if content_type.startswith(JSON):
260-
await self._handle_json_response(response, ctx.read_stream_writer)
285+
await self._handle_json_response(response, ctx.read_stream_writer, is_initialization)
261286
elif content_type.startswith(SSE):
262-
await self._handle_sse_response(response, ctx)
287+
await self._handle_sse_response(response, ctx, is_initialization)
263288
else:
264289
await self._handle_unexpected_content_type(
265290
content_type,
@@ -270,18 +295,29 @@ async def _handle_json_response(
270295
self,
271296
response: httpx.Response,
272297
read_stream_writer: StreamWriter,
298+
is_initialization: bool = False,
273299
) -> None:
274300
"""Handle JSON response from the server."""
275301
try:
276302
content = await response.aread()
277303
message = JSONRPCMessage.model_validate_json(content)
304+
305+
# Extract protocol version from initialization response
306+
if is_initialization:
307+
self._maybe_extract_protocol_version_from_message(message)
308+
278309
session_message = SessionMessage(message)
279310
await read_stream_writer.send(session_message)
280311
except Exception as exc:
281312
logger.error(f"Error parsing JSON response: {exc}")
282313
await read_stream_writer.send(exc)
283314

284-
async def _handle_sse_response(self, response: httpx.Response, ctx: RequestContext) -> None:
315+
async def _handle_sse_response(
316+
self,
317+
response: httpx.Response,
318+
ctx: RequestContext,
319+
is_initialization: bool = False,
320+
) -> None:
285321
"""Handle SSE response from the server."""
286322
try:
287323
event_source = EventSource(response)
@@ -292,6 +328,7 @@ async def _handle_sse_response(self, response: httpx.Response, ctx: RequestConte
292328
sse,
293329
ctx.read_stream_writer,
294330
resumption_callback=(ctx.metadata.on_resumption_token_update if ctx.metadata else None),
331+
is_initialization=is_initialization,
295332
)
296333
# If the SSE event indicates completion, like returning respose/error
297334
# break the loop
@@ -388,7 +425,7 @@ async def terminate_session(self, client: httpx.AsyncClient) -> None:
388425
return
389426

390427
try:
391-
headers = self._update_headers_with_session(self.request_headers)
428+
headers = self._prepare_request_headers(self.request_headers)
392429
response = await client.delete(self.url, headers=headers)
393430

394431
if response.status_code == 405:

src/mcp/server/auth/routes.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from mcp.server.auth.middleware.client_auth import ClientAuthenticator
1717
from mcp.server.auth.provider import OAuthAuthorizationServerProvider
1818
from mcp.server.auth.settings import ClientRegistrationOptions, RevocationOptions
19+
from mcp.server.streamable_http import MCP_PROTOCOL_VERSION_HEADER
1920
from mcp.shared.auth import OAuthMetadata
2021

2122

@@ -55,7 +56,7 @@ def cors_middleware(
5556
app=request_response(handler),
5657
allow_origins="*",
5758
allow_methods=allow_methods,
58-
allow_headers=["mcp-protocol-version"],
59+
allow_headers=[MCP_PROTOCOL_VERSION_HEADER],
5960
)
6061
return cors_app
6162

src/mcp/server/streamable_http.py

Lines changed: 37 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,9 @@
2525
from starlette.types import Receive, Scope, Send
2626

2727
from mcp.shared.message import ServerMessageMetadata, SessionMessage
28+
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS
2829
from mcp.types import (
30+
DEFAULT_NEGOTIATED_VERSION,
2931
INTERNAL_ERROR,
3032
INVALID_PARAMS,
3133
INVALID_REQUEST,
@@ -45,6 +47,7 @@
4547

4648
# Header names
4749
MCP_SESSION_ID_HEADER = "mcp-session-id"
50+
MCP_PROTOCOL_VERSION_HEADER = "mcp-protocol-version"
4851
LAST_EVENT_ID_HEADER = "last-event-id"
4952

5053
# Content types
@@ -293,7 +296,7 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re
293296
has_json, has_sse = self._check_accept_headers(request)
294297
if not (has_json and has_sse):
295298
response = self._create_error_response(
296-
("Not Acceptable: Client must accept both application/json and " "text/event-stream"),
299+
("Not Acceptable: Client must accept both application/json and text/event-stream"),
297300
HTTPStatus.NOT_ACCEPTABLE,
298301
)
299302
await response(scope, receive, send)
@@ -353,8 +356,7 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re
353356
)
354357
await response(scope, receive, send)
355358
return
356-
# For non-initialization requests, validate the session
357-
elif not await self._validate_session(request, send):
359+
elif not await self._validate_request_headers(request, send):
358360
return
359361

360362
# For notifications and responses only, return 202 Accepted
@@ -513,8 +515,9 @@ async def _handle_get_request(self, request: Request, send: Send) -> None:
513515
await response(request.scope, request.receive, send)
514516
return
515517

516-
if not await self._validate_session(request, send):
518+
if not await self._validate_request_headers(request, send):
517519
return
520+
518521
# Handle resumability: check for Last-Event-ID header
519522
if last_event_id := request.headers.get(LAST_EVENT_ID_HEADER):
520523
await self._replay_events(last_event_id, request, send)
@@ -593,7 +596,7 @@ async def _handle_delete_request(self, request: Request, send: Send) -> None:
593596
await response(request.scope, request.receive, send)
594597
return
595598

596-
if not await self._validate_session(request, send):
599+
if not await self._validate_request_headers(request, send):
597600
return
598601

599602
await self._terminate_session()
@@ -653,6 +656,13 @@ async def _handle_unsupported_request(self, request: Request, send: Send) -> Non
653656
)
654657
await response(request.scope, request.receive, send)
655658

659+
async def _validate_request_headers(self, request: Request, send: Send) -> bool:
660+
if not await self._validate_session(request, send):
661+
return False
662+
if not await self._validate_protocol_version(request, send):
663+
return False
664+
return True
665+
656666
async def _validate_session(self, request: Request, send: Send) -> bool:
657667
"""Validate the session ID in the request."""
658668
if not self.mcp_session_id:
@@ -682,6 +692,28 @@ async def _validate_session(self, request: Request, send: Send) -> bool:
682692

683693
return True
684694

695+
async def _validate_protocol_version(self, request: Request, send: Send) -> bool:
696+
"""Validate the protocol version header in the request."""
697+
# Get the protocol version from the request headers
698+
protocol_version = request.headers.get(MCP_PROTOCOL_VERSION_HEADER)
699+
700+
# If no protocol version provided, assume default version
701+
if protocol_version is None:
702+
protocol_version = DEFAULT_NEGOTIATED_VERSION
703+
704+
# Check if the protocol version is supported
705+
if protocol_version not in SUPPORTED_PROTOCOL_VERSIONS:
706+
supported_versions = ", ".join(SUPPORTED_PROTOCOL_VERSIONS)
707+
response = self._create_error_response(
708+
f"Bad Request: Unsupported protocol version: {protocol_version}. "
709+
+ f"Supported versions: {supported_versions}",
710+
HTTPStatus.BAD_REQUEST,
711+
)
712+
await response(request.scope, request.receive, send)
713+
return False
714+
715+
return True
716+
685717
async def _replay_events(self, last_event_id: str, request: Request, send: Send) -> None:
686718
"""
687719
Replays events that would have been sent after the specified event ID.

src/mcp/types.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,14 @@
2424

2525
LATEST_PROTOCOL_VERSION = "2025-03-26"
2626

27+
"""
28+
The default negotiated version of the Model Context Protocol when no version is specified.
29+
We need this to satisfy the MCP specification, which requires the server to assume a
30+
specific version if none is provided by the client. See section "Protocol Version Header" at
31+
https://modelcontextprotocol.io/specification
32+
"""
33+
DEFAULT_NEGOTIATED_VERSION = "2025-03-26"
34+
2735
ProgressToken = str | int
2836
Cursor = str
2937
Role = Literal["user", "assistant"]

0 commit comments

Comments
 (0)