Skip to content

Commit e9a688d

Browse files
Merge branch 'main' into feat/tasks
2 parents 802d080 + 5489e8b commit e9a688d

File tree

4 files changed

+268
-12
lines changed

4 files changed

+268
-12
lines changed

src/mcp/client/auth/oauth2.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -506,7 +506,7 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
506506

507507
# Step 3: Apply scope selection strategy
508508
self.context.client_metadata.scope = get_client_metadata_scopes(
509-
www_auth_resource_metadata_url,
509+
extract_scope_from_www_auth(response),
510510
self.context.protected_resource_metadata,
511511
self.context.oauth_metadata,
512512
)

src/mcp/client/session_group.py

Lines changed: 79 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,20 +11,23 @@
1111
import contextlib
1212
import logging
1313
from collections.abc import Callable
14+
from dataclasses import dataclass
1415
from datetime import timedelta
1516
from types import TracebackType
16-
from typing import Any, TypeAlias
17+
from typing import Any, TypeAlias, overload
1718

1819
import anyio
1920
from pydantic import BaseModel
20-
from typing_extensions import Self
21+
from typing_extensions import Self, deprecated
2122

2223
import mcp
2324
from mcp import types
25+
from mcp.client.session import ElicitationFnT, ListRootsFnT, LoggingFnT, MessageHandlerFnT, SamplingFnT
2426
from mcp.client.sse import sse_client
2527
from mcp.client.stdio import StdioServerParameters
2628
from mcp.client.streamable_http import streamablehttp_client
2729
from mcp.shared.exceptions import McpError
30+
from mcp.shared.session import ProgressFnT
2831

2932

3033
class SseServerParameters(BaseModel):
@@ -65,6 +68,21 @@ class StreamableHttpParameters(BaseModel):
6568
ServerParameters: TypeAlias = StdioServerParameters | SseServerParameters | StreamableHttpParameters
6669

6770

71+
# Use dataclass instead of pydantic BaseModel
72+
# because pydantic BaseModel cannot handle Protocol fields.
73+
@dataclass
74+
class ClientSessionParameters:
75+
"""Parameters for establishing a client session to an MCP server."""
76+
77+
read_timeout_seconds: timedelta | None = None
78+
sampling_callback: SamplingFnT | None = None
79+
elicitation_callback: ElicitationFnT | None = None
80+
list_roots_callback: ListRootsFnT | None = None
81+
logging_callback: LoggingFnT | None = None
82+
message_handler: MessageHandlerFnT | None = None
83+
client_info: types.Implementation | None = None
84+
85+
6886
class ClientSessionGroup:
6987
"""Client for managing connections to multiple MCP servers.
7088
@@ -172,11 +190,49 @@ def tools(self) -> dict[str, types.Tool]:
172190
"""Returns the tools as a dictionary of names to tools."""
173191
return self._tools
174192

175-
async def call_tool(self, name: str, args: dict[str, Any]) -> types.CallToolResult:
193+
@overload
194+
async def call_tool(
195+
self,
196+
name: str,
197+
arguments: dict[str, Any],
198+
read_timeout_seconds: timedelta | None = None,
199+
progress_callback: ProgressFnT | None = None,
200+
*,
201+
meta: dict[str, Any] | None = None,
202+
) -> types.CallToolResult: ...
203+
204+
@overload
205+
@deprecated("The 'args' parameter is deprecated. Use 'arguments' instead.")
206+
async def call_tool(
207+
self,
208+
name: str,
209+
*,
210+
args: dict[str, Any],
211+
read_timeout_seconds: timedelta | None = None,
212+
progress_callback: ProgressFnT | None = None,
213+
meta: dict[str, Any] | None = None,
214+
) -> types.CallToolResult: ...
215+
216+
async def call_tool(
217+
self,
218+
name: str,
219+
arguments: dict[str, Any] | None = None,
220+
read_timeout_seconds: timedelta | None = None,
221+
progress_callback: ProgressFnT | None = None,
222+
*,
223+
meta: dict[str, Any] | None = None,
224+
args: dict[str, Any] | None = None,
225+
) -> types.CallToolResult:
176226
"""Executes a tool given its name and arguments."""
177227
session = self._tool_to_session[name]
178228
session_tool_name = self.tools[name].name
179-
return await session.call_tool(session_tool_name, args)
229+
return await session.call_tool(
230+
session_tool_name,
231+
arguments if args is None else args,
232+
read_timeout_seconds=read_timeout_seconds,
233+
progress_callback=progress_callback,
234+
meta=meta,
235+
)
180236

181237
async def disconnect_from_server(self, session: mcp.ClientSession) -> None:
182238
"""Disconnects from a single MCP server."""
@@ -225,13 +281,16 @@ async def connect_with_session(
225281
async def connect_to_server(
226282
self,
227283
server_params: ServerParameters,
284+
session_params: ClientSessionParameters | None = None,
228285
) -> mcp.ClientSession:
229286
"""Connects to a single MCP server."""
230-
server_info, session = await self._establish_session(server_params)
287+
server_info, session = await self._establish_session(server_params, session_params or ClientSessionParameters())
231288
return await self.connect_with_session(server_info, session)
232289

233290
async def _establish_session(
234-
self, server_params: ServerParameters
291+
self,
292+
server_params: ServerParameters,
293+
session_params: ClientSessionParameters,
235294
) -> tuple[types.Implementation, mcp.ClientSession]:
236295
"""Establish a client session to an MCP server."""
237296

@@ -259,7 +318,20 @@ async def _establish_session(
259318
)
260319
read, write, _ = await session_stack.enter_async_context(client)
261320

262-
session = await session_stack.enter_async_context(mcp.ClientSession(read, write))
321+
session = await session_stack.enter_async_context(
322+
mcp.ClientSession(
323+
read,
324+
write,
325+
read_timeout_seconds=session_params.read_timeout_seconds,
326+
sampling_callback=session_params.sampling_callback,
327+
elicitation_callback=session_params.elicitation_callback,
328+
list_roots_callback=session_params.list_roots_callback,
329+
logging_callback=session_params.logging_callback,
330+
message_handler=session_params.message_handler,
331+
client_info=session_params.client_info,
332+
)
333+
)
334+
263335
result = await session.initialize()
264336

265337
# Session successfully initialized.
Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
"""
2+
Regression test for issue #1630: OAuth2 scope incorrectly set to resource_metadata URL.
3+
4+
This test verifies that when a 401 response contains both resource_metadata and scope
5+
in the WWW-Authenticate header, the actual scope is used (not the resource_metadata URL).
6+
"""
7+
8+
from unittest import mock
9+
10+
import httpx
11+
import pytest
12+
from pydantic import AnyUrl
13+
14+
from mcp.client.auth import OAuthClientProvider
15+
from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken
16+
17+
18+
class MockTokenStorage:
19+
"""Mock token storage for testing."""
20+
21+
def __init__(self) -> None:
22+
self._tokens: OAuthToken | None = None
23+
self._client_info: OAuthClientInformationFull | None = None
24+
25+
async def get_tokens(self) -> OAuthToken | None:
26+
return self._tokens # pragma: no cover
27+
28+
async def set_tokens(self, tokens: OAuthToken) -> None:
29+
self._tokens = tokens
30+
31+
async def get_client_info(self) -> OAuthClientInformationFull | None:
32+
return self._client_info # pragma: no cover
33+
34+
async def set_client_info(self, client_info: OAuthClientInformationFull) -> None:
35+
self._client_info = client_info # pragma: no cover
36+
37+
38+
@pytest.mark.anyio
39+
async def test_401_uses_www_auth_scope_not_resource_metadata_url():
40+
"""
41+
Regression test for #1630: Ensure scope is extracted from WWW-Authenticate header,
42+
not the resource_metadata URL.
43+
44+
When a 401 response contains:
45+
WWW-Authenticate: Bearer resource_metadata="https://...", scope="read write"
46+
47+
The client should use "read write" as the scope, NOT the resource_metadata URL.
48+
"""
49+
50+
async def redirect_handler(url: str) -> None:
51+
pass # pragma: no cover
52+
53+
async def callback_handler() -> tuple[str, str | None]:
54+
return "test_auth_code", "test_state" # pragma: no cover
55+
56+
client_metadata = OAuthClientMetadata(
57+
redirect_uris=[AnyUrl("http://localhost:3030/callback")],
58+
client_name="Test Client",
59+
)
60+
61+
provider = OAuthClientProvider(
62+
server_url="https://api.example.com/mcp",
63+
client_metadata=client_metadata,
64+
storage=MockTokenStorage(),
65+
redirect_handler=redirect_handler,
66+
callback_handler=callback_handler,
67+
)
68+
69+
provider.context.current_tokens = None
70+
provider.context.token_expiry_time = None
71+
provider._initialized = True
72+
73+
# Pre-set client info to skip DCR
74+
provider.context.client_info = OAuthClientInformationFull(
75+
client_id="test_client",
76+
redirect_uris=[AnyUrl("http://localhost:3030/callback")],
77+
)
78+
79+
test_request = httpx.Request("GET", "https://api.example.com/mcp")
80+
auth_flow = provider.async_auth_flow(test_request)
81+
82+
# First request (no auth header yet)
83+
await auth_flow.__anext__()
84+
85+
# 401 response with BOTH resource_metadata URL and scope in WWW-Authenticate
86+
# This is the key: the bug would use the URL as scope instead of "read write"
87+
resource_metadata_url = "https://api.example.com/.well-known/oauth-protected-resource"
88+
expected_scope = "read write"
89+
90+
response_401 = httpx.Response(
91+
401,
92+
headers={"WWW-Authenticate": (f'Bearer resource_metadata="{resource_metadata_url}", scope="{expected_scope}"')},
93+
request=test_request,
94+
)
95+
96+
# Send 401, expect PRM discovery request
97+
prm_request = await auth_flow.asend(response_401)
98+
assert ".well-known/oauth-protected-resource" in str(prm_request.url)
99+
100+
# PRM response with scopes_supported (these should be overridden by WWW-Auth scope)
101+
prm_response = httpx.Response(
102+
200,
103+
content=(
104+
b'{"resource": "https://api.example.com/mcp", '
105+
b'"authorization_servers": ["https://auth.example.com"], '
106+
b'"scopes_supported": ["fallback:scope1", "fallback:scope2"]}'
107+
),
108+
request=prm_request,
109+
)
110+
111+
# Send PRM response, expect OAuth metadata discovery
112+
oauth_metadata_request = await auth_flow.asend(prm_response)
113+
assert ".well-known/oauth-authorization-server" in str(oauth_metadata_request.url)
114+
115+
# OAuth metadata response
116+
oauth_metadata_response = httpx.Response(
117+
200,
118+
content=(
119+
b'{"issuer": "https://auth.example.com", '
120+
b'"authorization_endpoint": "https://auth.example.com/authorize", '
121+
b'"token_endpoint": "https://auth.example.com/token"}'
122+
),
123+
request=oauth_metadata_request,
124+
)
125+
126+
# Mock authorization to skip interactive flow
127+
provider._perform_authorization_code_grant = mock.AsyncMock(return_value=("test_auth_code", "test_code_verifier"))
128+
129+
# Send OAuth metadata response, expect token request
130+
token_request = await auth_flow.asend(oauth_metadata_response)
131+
assert "token" in str(token_request.url)
132+
133+
# NOW CHECK: The scope should be the WWW-Authenticate scope, NOT the URL
134+
# This is where the bug manifested - scope was set to resource_metadata_url
135+
actual_scope = provider.context.client_metadata.scope
136+
137+
# This assertion would FAIL on main (scope would be the URL)
138+
# but PASS on the fix branch (scope is "read write")
139+
assert actual_scope == expected_scope, (
140+
f"Expected scope to be '{expected_scope}' from WWW-Authenticate header, "
141+
f"but got '{actual_scope}'. "
142+
f"If scope is '{resource_metadata_url}', the bug from #1630 is present."
143+
)
144+
145+
# Verify it's definitely not the URL (explicit check for the bug)
146+
assert actual_scope != resource_metadata_url, (
147+
f"BUG #1630: Scope was incorrectly set to resource_metadata URL '{resource_metadata_url}' "
148+
f"instead of the actual scope '{expected_scope}'"
149+
)
150+
151+
# Complete the flow to properly release the lock
152+
token_response = httpx.Response(
153+
200,
154+
content=b'{"access_token": "test_token", "token_type": "Bearer", "expires_in": 3600}',
155+
request=token_request,
156+
)
157+
158+
final_request = await auth_flow.asend(token_response)
159+
assert final_request.headers["Authorization"] == "Bearer test_token"
160+
161+
# Finish the flow
162+
final_response = httpx.Response(200, request=final_request)
163+
try:
164+
await auth_flow.asend(final_response)
165+
except StopAsyncIteration:
166+
pass

tests/client/test_session_group.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,12 @@
55

66
import mcp
77
from mcp import types
8-
from mcp.client.session_group import ClientSessionGroup, SseServerParameters, StreamableHttpParameters
8+
from mcp.client.session_group import (
9+
ClientSessionGroup,
10+
ClientSessionParameters,
11+
SseServerParameters,
12+
StreamableHttpParameters,
13+
)
914
from mcp.client.stdio import StdioServerParameters
1015
from mcp.shared.exceptions import McpError
1116

@@ -62,7 +67,7 @@ def hook(name: str, server_info: types.Implementation) -> str: # pragma: no cov
6267
# --- Test Execution ---
6368
result = await mcp_session_group.call_tool(
6469
name="server1-my_tool",
65-
args={
70+
arguments={
6671
"name": "value1",
6772
"args": {},
6873
},
@@ -73,6 +78,9 @@ def hook(name: str, server_info: types.Implementation) -> str: # pragma: no cov
7378
mock_session.call_tool.assert_called_once_with(
7479
"my_tool",
7580
{"name": "value1", "args": {}},
81+
read_timeout_seconds=None,
82+
progress_callback=None,
83+
meta=None,
7684
)
7785

7886
async def test_connect_to_server(self, mock_exit_stack: contextlib.AsyncExitStack):
@@ -329,7 +337,7 @@ async def test_establish_session_parameterized(
329337
(
330338
returned_server_info,
331339
returned_session,
332-
) = await group._establish_session(server_params_instance)
340+
) = await group._establish_session(server_params_instance, ClientSessionParameters())
333341

334342
# --- Assertions ---
335343
# 1. Assert the correct specific client function was called
@@ -357,7 +365,17 @@ async def test_establish_session_parameterized(
357365
mock_client_cm_instance.__aenter__.assert_awaited_once()
358366

359367
# 2. Assert ClientSession was called correctly
360-
mock_ClientSession_class.assert_called_once_with(mock_read_stream, mock_write_stream)
368+
mock_ClientSession_class.assert_called_once_with(
369+
mock_read_stream,
370+
mock_write_stream,
371+
read_timeout_seconds=None,
372+
sampling_callback=None,
373+
elicitation_callback=None,
374+
list_roots_callback=None,
375+
logging_callback=None,
376+
message_handler=None,
377+
client_info=None,
378+
)
361379
mock_raw_session_cm.__aenter__.assert_awaited_once()
362380
mock_entered_session.initialize.assert_awaited_once()
363381

0 commit comments

Comments
 (0)