Skip to content

Commit 5c6cdcd

Browse files
seanzhougooglecopybara-github
authored andcommitted
feat: Support Oauth2 client credentials grant type
PiperOrigin-RevId: 815813477
1 parent 46d73be commit 5c6cdcd

File tree

5 files changed

+381
-22
lines changed

5 files changed

+381
-22
lines changed

src/google/adk/auth/credential_manager.py

Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,11 @@
2525
from .auth_credential import AuthCredentialTypes
2626
from .auth_schemes import AuthSchemeType
2727
from .auth_schemes import ExtendedOAuth2
28+
from .auth_schemes import OpenIdConnectWithConfig
2829
from .auth_tool import AuthConfig
2930
from .exchanger.base_credential_exchanger import BaseCredentialExchanger
3031
from .exchanger.credential_exchanger_registry import CredentialExchangerRegistry
3132
from .oauth2_discovery import OAuth2DiscoveryManager
32-
from .refresher.base_credential_refresher import BaseCredentialRefresher
3333
from .refresher.credential_refresher_registry import CredentialRefresherRegistry
3434

3535
logger = logging.getLogger("google_adk." + __name__)
@@ -85,8 +85,17 @@ def __init__(
8585

8686
# Register default exchangers and refreshers
8787
# TODO: support service account credential exchanger
88+
from .exchanger.oauth2_credential_exchanger import OAuth2CredentialExchanger
8889
from .refresher.oauth2_credential_refresher import OAuth2CredentialRefresher
8990

91+
oauth2_exchanger = OAuth2CredentialExchanger()
92+
self._exchanger_registry.register(
93+
AuthCredentialTypes.OAUTH2, oauth2_exchanger
94+
)
95+
self._exchanger_registry.register(
96+
AuthCredentialTypes.OPEN_ID_CONNECT, oauth2_exchanger
97+
)
98+
9099
oauth2_refresher = OAuth2CredentialRefresher()
91100
self._refresher_registry.register(
92101
AuthCredentialTypes.OAUTH2, oauth2_refresher
@@ -134,9 +143,14 @@ async def get_auth_credential(
134143
credential = await self._load_from_auth_response(callback_context)
135144
was_from_auth_response = True
136145

137-
# Step 5: If still no credential available, return None
146+
# Step 5: If still no credential available, check if client credentials
138147
if not credential:
139-
return None
148+
# For client credentials flow, use raw credentials directly
149+
if self._is_client_credentials_flow():
150+
credential = self._auth_config.raw_auth_credential
151+
else:
152+
# For authorization code flow, return None to trigger user authorization
153+
return None
140154

141155
# Step 6: Exchange credential if needed (e.g., service account to access token)
142156
credential, was_exchanged = await self._exchange_credential(credential)
@@ -328,3 +342,26 @@ def _missing_oauth_info(self) -> bool:
328342
and not flows.authorizationCode.tokenUrl
329343
)
330344
return False
345+
346+
def _is_client_credentials_flow(self) -> bool:
347+
"""Check if the auth scheme uses client credentials flow.
348+
349+
Supports both OAuth2 and OIDC schemes.
350+
351+
Returns:
352+
True if using client credentials flow, False otherwise.
353+
"""
354+
auth_scheme = self._auth_config.auth_scheme
355+
356+
# Check OAuth2 schemes
357+
if isinstance(auth_scheme, OAuth2) and auth_scheme.flows:
358+
return auth_scheme.flows.clientCredentials is not None
359+
360+
# Check OIDC schemes
361+
if isinstance(auth_scheme, OpenIdConnectWithConfig):
362+
return (
363+
auth_scheme.grant_types_supported is not None
364+
and "client_credentials" in auth_scheme.grant_types_supported
365+
)
366+
367+
return False

src/google/adk/auth/exchanger/oauth2_credential_exchanger.py

Lines changed: 96 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,11 @@
1919
import logging
2020
from typing import Optional
2121

22+
from fastapi.openapi.models import OAuth2
2223
from google.adk.auth.auth_credential import AuthCredential
2324
from google.adk.auth.auth_schemes import AuthScheme
2425
from google.adk.auth.auth_schemes import OAuthGrantType
26+
from google.adk.auth.auth_schemes import OpenIdConnectWithConfig
2527
from google.adk.auth.oauth2_credential_util import create_oauth2_session
2628
from google.adk.auth.oauth2_credential_util import update_credential_with_tokens
2729
from google.adk.utils.feature_decorator import experimental
@@ -81,9 +83,100 @@ async def exchange(
8183
if auth_credential.oauth2 and auth_credential.oauth2.access_token:
8284
return auth_credential
8385

86+
# Determine grant type from auth_scheme
87+
grant_type = self._determine_grant_type(auth_scheme)
88+
89+
if grant_type == OAuthGrantType.CLIENT_CREDENTIALS:
90+
return await self._exchange_client_credentials(
91+
auth_credential, auth_scheme
92+
)
93+
elif grant_type == OAuthGrantType.AUTHORIZATION_CODE:
94+
return await self._exchange_authorization_code(
95+
auth_credential, auth_scheme
96+
)
97+
else:
98+
logger.warning("Unsupported OAuth2 grant type: %s", grant_type)
99+
return auth_credential
100+
101+
def _determine_grant_type(
102+
self, auth_scheme: AuthScheme
103+
) -> Optional[OAuthGrantType]:
104+
"""Determine the OAuth2 grant type from the auth scheme.
105+
106+
Args:
107+
auth_scheme: The OAuth2 authentication scheme.
108+
109+
Returns:
110+
The OAuth2 grant type or None if cannot be determined.
111+
"""
112+
if isinstance(auth_scheme, OAuth2) and auth_scheme.flows:
113+
return OAuthGrantType.from_flow(auth_scheme.flows)
114+
elif isinstance(auth_scheme, OpenIdConnectWithConfig):
115+
# Check supported grant types for OIDC
116+
if (
117+
auth_scheme.grant_types_supported
118+
and "client_credentials" in auth_scheme.grant_types_supported
119+
):
120+
return OAuthGrantType.CLIENT_CREDENTIALS
121+
else:
122+
# Default to authorization code if client credentials not supported
123+
return OAuthGrantType.AUTHORIZATION_CODE
124+
125+
return None
126+
127+
async def _exchange_client_credentials(
128+
self,
129+
auth_credential: AuthCredential,
130+
auth_scheme: AuthScheme,
131+
) -> AuthCredential:
132+
"""Exchange client credentials for access token.
133+
134+
Args:
135+
auth_credential: The OAuth2 credential to exchange.
136+
auth_scheme: The OAuth2 authentication scheme.
137+
138+
Returns:
139+
The credential with access token.
140+
"""
84141
client, token_endpoint = create_oauth2_session(auth_scheme, auth_credential)
85142
if not client:
86-
logger.warning("Could not create OAuth2 session for token exchange")
143+
logger.warning(
144+
"Could not create OAuth2 session for client credentials exchange"
145+
)
146+
return auth_credential
147+
148+
try:
149+
tokens = client.fetch_token(
150+
token_endpoint,
151+
grant_type=OAuthGrantType.CLIENT_CREDENTIALS,
152+
)
153+
update_credential_with_tokens(auth_credential, tokens)
154+
logger.debug("Successfully exchanged client credentials for access token")
155+
except Exception as e:
156+
logger.error("Failed to exchange client credentials: %s", e)
157+
return auth_credential
158+
159+
return auth_credential
160+
161+
async def _exchange_authorization_code(
162+
self,
163+
auth_credential: AuthCredential,
164+
auth_scheme: AuthScheme,
165+
) -> AuthCredential:
166+
"""Exchange authorization code for access token.
167+
168+
Args:
169+
auth_credential: The OAuth2 credential to exchange.
170+
auth_scheme: The OAuth2 authentication scheme.
171+
172+
Returns:
173+
The credential with access token.
174+
"""
175+
client, token_endpoint = create_oauth2_session(auth_scheme, auth_credential)
176+
if not client:
177+
logger.warning(
178+
"Could not create OAuth2 session for authorization code exchange"
179+
)
87180
return auth_credential
88181

89182
try:
@@ -94,11 +187,9 @@ async def exchange(
94187
grant_type=OAuthGrantType.AUTHORIZATION_CODE,
95188
)
96189
update_credential_with_tokens(auth_credential, tokens)
97-
logger.debug("Successfully exchanged OAuth2 tokens")
190+
logger.debug("Successfully exchanged authorization code for access token")
98191
except Exception as e:
99-
# TODO reconsider whether we should raise errors in this case
100-
logger.error("Failed to exchange OAuth2 tokens: %s", e)
101-
# Return original credential on failure
192+
logger.error("Failed to exchange authorization code: %s", e)
102193
return auth_credential
103194

104195
return auth_credential

src/google/adk/auth/oauth2_credential_util.py

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -18,22 +18,15 @@
1818
from typing import Optional
1919
from typing import Tuple
2020

21+
from authlib.integrations.requests_client import OAuth2Session
22+
from authlib.oauth2.rfc6749 import OAuth2Token
2123
from fastapi.openapi.models import OAuth2
2224

2325
from ..utils.feature_decorator import experimental
2426
from .auth_credential import AuthCredential
2527
from .auth_schemes import AuthScheme
2628
from .auth_schemes import OpenIdConnectWithConfig
2729

28-
try:
29-
from authlib.integrations.requests_client import OAuth2Session
30-
from authlib.oauth2.rfc6749 import OAuth2Token
31-
32-
AUTHLIB_AVAILABLE = True
33-
except ImportError:
34-
AUTHLIB_AVAILABLE = False
35-
36-
3730
logger = logging.getLogger("google_adk." + __name__)
3831

3932

@@ -53,18 +46,34 @@ def create_oauth2_session(
5346
"""
5447
if isinstance(auth_scheme, OpenIdConnectWithConfig):
5548
if not hasattr(auth_scheme, "token_endpoint"):
49+
logger.warning("OpenIdConnect scheme missing token_endpoint")
5650
return None, None
5751
token_endpoint = auth_scheme.token_endpoint
58-
scopes = auth_scheme.scopes
52+
scopes = auth_scheme.scopes or []
5953
elif isinstance(auth_scheme, OAuth2):
54+
# Support both authorization code and client credentials flows
6055
if (
61-
not auth_scheme.flows.authorizationCode
62-
or not auth_scheme.flows.authorizationCode.tokenUrl
56+
auth_scheme.flows.authorizationCode
57+
and auth_scheme.flows.authorizationCode.tokenUrl
58+
):
59+
token_endpoint = auth_scheme.flows.authorizationCode.tokenUrl
60+
scopes = list(auth_scheme.flows.authorizationCode.scopes.keys())
61+
elif (
62+
auth_scheme.flows.clientCredentials
63+
and auth_scheme.flows.clientCredentials.tokenUrl
6364
):
65+
token_endpoint = auth_scheme.flows.clientCredentials.tokenUrl
66+
scopes = list(auth_scheme.flows.clientCredentials.scopes.keys())
67+
else:
68+
logger.warning(
69+
"OAuth2 scheme missing required flow configuration. Expected either"
70+
" authorizationCode.tokenUrl or clientCredentials.tokenUrl. Auth"
71+
" scheme: %s",
72+
auth_scheme,
73+
)
6474
return None, None
65-
token_endpoint = auth_scheme.flows.authorizationCode.tokenUrl
66-
scopes = list(auth_scheme.flows.authorizationCode.scopes.keys())
6775
else:
76+
logger.warning(f"Unsupported auth_scheme type: {type(auth_scheme)}")
6877
return None, None
6978

7079
if (

tests/unittests/auth/exchanger/test_oauth2_credential_exchanger.py

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@
1717
from unittest.mock import patch
1818

1919
from authlib.oauth2.rfc6749 import OAuth2Token
20+
from fastapi.openapi.models import OAuth2
21+
from fastapi.openapi.models import OAuthFlowClientCredentials
22+
from fastapi.openapi.models import OAuthFlows
2023
from google.adk.auth.auth_credential import AuthCredential
2124
from google.adk.auth.auth_credential import AuthCredentialTypes
2225
from google.adk.auth.auth_credential import OAuth2Auth
@@ -218,3 +221,116 @@ async def test_exchange_authlib_not_available(self):
218221
# Should return original credential when authlib is not available
219222
assert result == credential
220223
assert result.oauth2.access_token is None
224+
225+
@patch("google.adk.auth.oauth2_credential_util.OAuth2Session")
226+
@pytest.mark.asyncio
227+
async def test_exchange_client_credentials_success(self, mock_oauth2_session):
228+
"""Test successful client credentials exchange."""
229+
# Setup mock
230+
mock_client = Mock()
231+
mock_oauth2_session.return_value = mock_client
232+
mock_tokens = OAuth2Token({
233+
"access_token": "client_access_token",
234+
"expires_at": int(time.time()) + 3600,
235+
"expires_in": 3600,
236+
})
237+
mock_client.fetch_token.return_value = mock_tokens
238+
239+
# Create OAuth2 scheme with client credentials flow
240+
flows = OAuthFlows(
241+
clientCredentials=OAuthFlowClientCredentials(
242+
tokenUrl="https://example.com/token",
243+
scopes={"read": "Read access", "write": "Write access"},
244+
)
245+
)
246+
scheme = OAuth2(flows=flows)
247+
248+
credential = AuthCredential(
249+
auth_type=AuthCredentialTypes.OAUTH2,
250+
oauth2=OAuth2Auth(
251+
client_id="test_client_id",
252+
client_secret="test_client_secret",
253+
),
254+
)
255+
256+
exchanger = OAuth2CredentialExchanger()
257+
result = await exchanger.exchange(credential, scheme)
258+
259+
# Verify client credentials exchange was successful
260+
assert result.oauth2.access_token == "client_access_token"
261+
mock_client.fetch_token.assert_called_once_with(
262+
"https://example.com/token",
263+
grant_type="client_credentials",
264+
)
265+
266+
@patch("google.adk.auth.oauth2_credential_util.OAuth2Session")
267+
@pytest.mark.asyncio
268+
async def test_exchange_client_credentials_failure(self, mock_oauth2_session):
269+
"""Test client credentials exchange failure."""
270+
# Setup mock to raise exception during fetch_token
271+
mock_client = Mock()
272+
mock_oauth2_session.return_value = mock_client
273+
mock_client.fetch_token.side_effect = Exception(
274+
"Client credentials fetch failed"
275+
)
276+
277+
# Create OAuth2 scheme with client credentials flow
278+
flows = OAuthFlows(
279+
clientCredentials=OAuthFlowClientCredentials(
280+
tokenUrl="https://example.com/token", scopes={"read": "Read access"}
281+
)
282+
)
283+
scheme = OAuth2(flows=flows)
284+
285+
credential = AuthCredential(
286+
auth_type=AuthCredentialTypes.OAUTH2,
287+
oauth2=OAuth2Auth(
288+
client_id="test_client_id",
289+
client_secret="test_client_secret",
290+
),
291+
)
292+
293+
exchanger = OAuth2CredentialExchanger()
294+
result = await exchanger.exchange(credential, scheme)
295+
296+
# Should return original credential when client credentials exchange fails
297+
assert result == credential
298+
assert result.oauth2.access_token is None
299+
mock_client.fetch_token.assert_called_once()
300+
301+
@pytest.mark.asyncio
302+
async def test_determine_grant_type_client_credentials(self):
303+
"""Test grant type determination for client credentials."""
304+
flows = OAuthFlows(
305+
clientCredentials=OAuthFlowClientCredentials(
306+
tokenUrl="https://example.com/token", scopes={"read": "Read access"}
307+
)
308+
)
309+
scheme = OAuth2(flows=flows)
310+
311+
exchanger = OAuth2CredentialExchanger()
312+
grant_type = exchanger._determine_grant_type(scheme)
313+
314+
from google.adk.auth.auth_schemes import OAuthGrantType
315+
316+
assert grant_type == OAuthGrantType.CLIENT_CREDENTIALS
317+
318+
@pytest.mark.asyncio
319+
async def test_determine_grant_type_openid_connect(self):
320+
"""Test grant type determination for OpenID Connect (defaults to auth code)."""
321+
scheme = OpenIdConnectWithConfig(
322+
type_="openIdConnect",
323+
openId_connect_url=(
324+
"https://example.com/.well-known/openid_configuration"
325+
),
326+
authorization_endpoint="https://example.com/auth",
327+
token_endpoint="https://example.com/token",
328+
scopes=["openid"],
329+
)
330+
331+
exchanger = OAuth2CredentialExchanger()
332+
grant_type = exchanger._determine_grant_type(scheme)
333+
334+
from google.adk.auth.auth_schemes import OAuthGrantType
335+
336+
assert grant_type == OAuthGrantType.AUTHORIZATION_CODE

0 commit comments

Comments
 (0)