Skip to content

Commit 684eafd

Browse files
committed
Fix client_credentials providers to set client_info in _initialize
The base class _initialize() loads client_info from storage, which overwrites any value set in the constructor. Move client_info setup to _initialize override so it's properly set after tokens are loaded. Also update tests to call _initialize() before checking client_info.
1 parent b007b28 commit 684eafd

File tree

2 files changed

+34
-10
lines changed

2 files changed

+34
-10
lines changed

src/mcp/client/auth/extensions/client_credentials.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,8 @@ def __init__(
6666
scope=scopes,
6767
)
6868
super().__init__(server_url, client_metadata, storage, None, None, 300.0)
69-
# Set client_info directly - no need for dynamic registration
70-
self.context.client_info = OAuthClientInformationFull(
69+
# Store client_info to be set during _initialize - no dynamic registration needed
70+
self._fixed_client_info = OAuthClientInformationFull(
7171
redirect_uris=None,
7272
client_id=client_id,
7373
client_secret=client_secret,
@@ -76,6 +76,12 @@ def __init__(
7676
scope=scopes,
7777
)
7878

79+
async def _initialize(self) -> None:
80+
"""Load stored tokens and set pre-configured client_info."""
81+
self.context.current_tokens = await self.context.storage.get_tokens()
82+
self.context.client_info = self._fixed_client_info
83+
self._initialized = True
84+
7985
async def _perform_authorization(self) -> httpx.Request:
8086
"""Perform client_credentials authorization."""
8187
return await self._exchange_token_client_credentials()
@@ -275,15 +281,21 @@ def __init__(
275281
)
276282
super().__init__(server_url, client_metadata, storage, None, None, 300.0)
277283
self._assertion_provider = assertion_provider
278-
# Set client_info directly - no need for dynamic registration
279-
self.context.client_info = OAuthClientInformationFull(
284+
# Store client_info to be set during _initialize - no dynamic registration needed
285+
self._fixed_client_info = OAuthClientInformationFull(
280286
redirect_uris=None,
281287
client_id=client_id,
282288
grant_types=["client_credentials"],
283289
token_endpoint_auth_method="private_key_jwt",
284290
scope=scopes,
285291
)
286292

293+
async def _initialize(self) -> None:
294+
"""Load stored tokens and set pre-configured client_info."""
295+
self.context.current_tokens = await self.context.storage.get_tokens()
296+
self.context.client_info = self._fixed_client_info
297+
self._initialized = True
298+
287299
async def _perform_authorization(self) -> httpx.Request:
288300
"""Perform client_credentials authorization with private_key_jwt."""
289301
return await self._exchange_token_client_credentials()

tests/client/auth/extensions/test_client_credentials.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -177,22 +177,27 @@ async def test_token_exchange_request_jwt(self, rfc7523_oauth_provider: RFC7523O
177177
class TestClientCredentialsOAuthProvider:
178178
"""Test ClientCredentialsOAuthProvider."""
179179

180-
def test_init_sets_client_info(self, mock_storage: MockTokenStorage):
181-
"""Test that constructor sets client_info directly."""
180+
@pytest.mark.anyio
181+
async def test_init_sets_client_info(self, mock_storage: MockTokenStorage):
182+
"""Test that _initialize sets client_info."""
182183
provider = ClientCredentialsOAuthProvider(
183184
server_url="https://api.example.com",
184185
storage=mock_storage,
185186
client_id="test-client-id",
186187
client_secret="test-client-secret",
187188
)
188189

190+
# client_info is set during _initialize
191+
await provider._initialize()
192+
189193
assert provider.context.client_info is not None
190194
assert provider.context.client_info.client_id == "test-client-id"
191195
assert provider.context.client_info.client_secret == "test-client-secret"
192196
assert provider.context.client_info.grant_types == ["client_credentials"]
193197
assert provider.context.client_info.token_endpoint_auth_method == "client_secret_basic"
194198

195-
def test_init_with_scopes(self, mock_storage: MockTokenStorage):
199+
@pytest.mark.anyio
200+
async def test_init_with_scopes(self, mock_storage: MockTokenStorage):
196201
"""Test that constructor accepts scopes."""
197202
provider = ClientCredentialsOAuthProvider(
198203
server_url="https://api.example.com",
@@ -202,9 +207,11 @@ def test_init_with_scopes(self, mock_storage: MockTokenStorage):
202207
scopes="read write",
203208
)
204209

210+
await provider._initialize()
205211
assert provider.context.client_info.scope == "read write"
206212

207-
def test_init_with_client_secret_post(self, mock_storage: MockTokenStorage):
213+
@pytest.mark.anyio
214+
async def test_init_with_client_secret_post(self, mock_storage: MockTokenStorage):
208215
"""Test that constructor accepts client_secret_post auth method."""
209216
provider = ClientCredentialsOAuthProvider(
210217
server_url="https://api.example.com",
@@ -214,6 +221,7 @@ def test_init_with_client_secret_post(self, mock_storage: MockTokenStorage):
214221
token_endpoint_auth_method="client_secret_post",
215222
)
216223

224+
await provider._initialize()
217225
assert provider.context.client_info.token_endpoint_auth_method == "client_secret_post"
218226

219227
@pytest.mark.anyio
@@ -270,8 +278,9 @@ async def test_exchange_token_without_scopes(self, mock_storage: MockTokenStorag
270278
class TestPrivateKeyJWTOAuthProvider:
271279
"""Test PrivateKeyJWTOAuthProvider."""
272280

273-
def test_init_sets_client_info(self, mock_storage: MockTokenStorage):
274-
"""Test that constructor sets client_info directly."""
281+
@pytest.mark.anyio
282+
async def test_init_sets_client_info(self, mock_storage: MockTokenStorage):
283+
"""Test that _initialize sets client_info."""
275284

276285
async def mock_assertion_provider(audience: str) -> str:
277286
return "mock-jwt"
@@ -283,6 +292,9 @@ async def mock_assertion_provider(audience: str) -> str:
283292
assertion_provider=mock_assertion_provider,
284293
)
285294

295+
# client_info is set during _initialize
296+
await provider._initialize()
297+
286298
assert provider.context.client_info is not None
287299
assert provider.context.client_info.client_id == "test-client-id"
288300
assert provider.context.client_info.grant_types == ["client_credentials"]

0 commit comments

Comments
 (0)