Skip to content
Closed
1 change: 1 addition & 0 deletions src/google/adk/auth/auth_credential.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ class OAuth2Auth(BaseModelWithConfig):
refresh_token: Optional[str] = None
expires_at: Optional[int] = None
expires_in: Optional[int] = None
audience: Optional[str] = None


class ServiceAccountCredential(BaseModelWithConfig):
Expand Down
9 changes: 8 additions & 1 deletion src/google/adk/auth/auth_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,9 +188,16 @@ def generate_auth_uri(
scope=" ".join(scopes),
redirect_uri=auth_credential.oauth2.redirect_uri,
)
params = {
"access_type": "offline",
"prompt": "consent",
}
if auth_credential.oauth2.audience:
params["audience"] = auth_credential.oauth2.audience
uri, state = client.create_authorization_url(
url=authorization_endpoint, access_type="offline", prompt="consent"
url=authorization_endpoint, **params
)

exchanged_auth_credential = auth_credential.model_copy(deep=True)
exchanged_auth_credential.oauth2.auth_uri = uri
exchanged_auth_credential.oauth2.state = state
Expand Down
24 changes: 23 additions & 1 deletion tests/unittests/auth/test_auth_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,10 @@ def __init__(
self.state = state

def create_authorization_url(self, url, **kwargs):
return f"{url}?client_id={self.client_id}&scope={self.scope}", "mock_state"
params = f"client_id={self.client_id}&scope={self.scope}"
if kwargs.get("audience"):
params += f"&audience={kwargs.get('audience')}"
return f"{url}?{params}", "mock_state"

def fetch_token(
self,
Expand Down Expand Up @@ -225,8 +228,27 @@ def test_generate_auth_uri_oauth2(self, auth_config):
"https://example.com/oauth2/authorize"
)
assert "client_id=mock_client_id" in result.oauth2.auth_uri
assert "audience" not in result.oauth2.auth_uri
assert result.oauth2.state == "mock_state"

@patch("google.adk.auth.auth_handler.OAuth2Session", MockOAuth2Session)
def test_generate_auth_uri_with_audience_and_prompt(
self, openid_auth_scheme, oauth2_credentials
):
"""Test generating an auth URI with audience and prompt."""
oauth2_credentials.oauth2.audience = "test_audience"
exchanged = oauth2_credentials.model_copy(deep=True)

config = AuthConfig(
auth_scheme=openid_auth_scheme,
raw_auth_credential=oauth2_credentials,
exchanged_auth_credential=exchanged,
)
handler = AuthHandler(config)
result = handler.generate_auth_uri()

assert "audience=test_audience" in result.oauth2.auth_uri

@patch("google.adk.auth.auth_handler.OAuth2Session", MockOAuth2Session)
def test_generate_auth_uri_openid(
self, openid_auth_scheme, oauth2_credentials
Expand Down
Loading