Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions databricks/sdk/credentials_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -594,9 +594,14 @@ def token() -> oauth.Token:
def refreshed_headers() -> Dict[str, str]:
credentials.refresh(request)
headers = {"Authorization": f"Bearer {credentials.token}"}
if cfg.client_type == ClientType.ACCOUNT:
# GCP SA Access token is only required for specific account level operations.
# It is possible that a user does not have persomissions to mint the GCP SA access token,
# but this is not a blocking error at this point.
try:
gcp_credentials.refresh(request)
headers["X-Databricks-GCP-SA-Access-Token"] = gcp_credentials.token
except Exception as e:
logger.warning(f"Failed to refresh GCP credentials: {e}")
return headers

return OAuthCredentialsProvider(refreshed_headers, token)
Expand Down Expand Up @@ -633,9 +638,14 @@ def token() -> oauth.Token:
def refreshed_headers() -> Dict[str, str]:
id_creds.refresh(request)
headers = {"Authorization": f"Bearer {id_creds.token}"}
if cfg.client_type == ClientType.ACCOUNT:
# GCP SA Access token is only required for specific account level operations.
# It is possible that a user does not have persomissions to mint the GCP SA access token,
# but this is not a blocking error at this point.
try:
gcp_impersonated_credentials.refresh(request)
headers["X-Databricks-GCP-SA-Access-Token"] = gcp_impersonated_credentials.token
except Exception as e:
logger.warning(f"Failed to refresh GCP impersonated credentials: {e}")
return headers

return OAuthCredentialsProvider(refreshed_headers, token)
Expand Down
124 changes: 124 additions & 0 deletions tests/test_credentials_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -604,6 +604,63 @@ def test_google_credentials_with_cloud_agnostic_host(self, mocker):
assert "Authorization" in headers
assert headers["Authorization"] == "Bearer test-google-token"

def test_google_credentials_includes_sa_token_on_success(self, mocker):
"""Test that google_credentials includes GCP SA access token when refresh succeeds."""
mock_cfg = Mock()
mock_cfg.host = "https://api.databricks.com"
mock_cfg.google_credentials = '{"type": "service_account", "project_id": "test"}'
mock_cfg.disable_async_token_refresh = True

mock_id_credentials = Mock()
mock_id_credentials.token = "test-id-token"

mock_sa_credentials = Mock()
mock_sa_credentials.token = "test-sa-token"

mocker.patch(
"databricks.sdk.credentials_provider.service_account.IDTokenCredentials.from_service_account_info",
return_value=mock_id_credentials,
)
mocker.patch(
"databricks.sdk.credentials_provider.service_account.Credentials.from_service_account_info",
return_value=mock_sa_credentials,
)

provider = credentials_provider.google_credentials(mock_cfg)
headers = provider()
assert headers["Authorization"] == "Bearer test-id-token"
assert headers["X-Databricks-GCP-SA-Access-Token"] == "test-sa-token"

def test_google_credentials_warns_on_sa_token_failure(self, mocker):
"""Test that google_credentials logs warning and omits SA token when refresh fails."""
mock_cfg = Mock()
mock_cfg.host = "https://api.databricks.com"
mock_cfg.google_credentials = '{"type": "service_account", "project_id": "test"}'
mock_cfg.disable_async_token_refresh = True

mock_id_credentials = Mock()
mock_id_credentials.token = "test-id-token"

mock_sa_credentials = Mock()
mock_sa_credentials.refresh.side_effect = Exception("permission denied")

mocker.patch(
"databricks.sdk.credentials_provider.service_account.IDTokenCredentials.from_service_account_info",
return_value=mock_id_credentials,
)
mocker.patch(
"databricks.sdk.credentials_provider.service_account.Credentials.from_service_account_info",
return_value=mock_sa_credentials,
)

provider = credentials_provider.google_credentials(mock_cfg)
mock_logger = mocker.patch("databricks.sdk.credentials_provider.logger")
headers = provider()

assert headers["Authorization"] == "Bearer test-id-token"
assert "X-Databricks-GCP-SA-Access-Token" not in headers
mock_logger.warning.assert_called_once()

def test_google_id_with_cloud_agnostic_host(self, mocker):
"""Test that google_id works with cloud-agnostic hosts after removing is_gcp check."""
# Mock Config with cloud-agnostic host
Expand Down Expand Up @@ -642,6 +699,73 @@ def test_google_id_with_cloud_agnostic_host(self, mocker):
assert "Authorization" in headers
assert headers["Authorization"] == "Bearer test-google-id-token"

def test_google_id_includes_sa_token_on_success(self, mocker):
"""Test that google_id includes GCP SA access token when refresh succeeds."""
mock_cfg = Mock()
mock_cfg.host = "https://api.databricks.com"
mock_cfg.google_service_account = "test-sa@project.iam.gserviceaccount.com"

mock_source_credentials = Mock()
mocker.patch(
"databricks.sdk.credentials_provider.google.auth.default",
return_value=(mock_source_credentials, "test-project"),
)

mock_id_creds = Mock()
mock_id_creds.token = "test-google-id-token"

mock_gcp_creds = Mock()
mock_gcp_creds.token = "test-gcp-sa-token"

mocker.patch(
"databricks.sdk.credentials_provider.impersonated_credentials.Credentials",
return_value=mock_gcp_creds,
)
mocker.patch(
"databricks.sdk.credentials_provider.impersonated_credentials.IDTokenCredentials",
return_value=mock_id_creds,
)

provider = credentials_provider.google_id(mock_cfg)
headers = provider()
assert headers["Authorization"] == "Bearer test-google-id-token"
assert headers["X-Databricks-GCP-SA-Access-Token"] == "test-gcp-sa-token"

def test_google_id_warns_on_sa_token_failure(self, mocker):
"""Test that google_id logs warning and omits SA token when refresh fails."""
mock_cfg = Mock()
mock_cfg.host = "https://api.databricks.com"
mock_cfg.google_service_account = "test-sa@project.iam.gserviceaccount.com"

mock_source_credentials = Mock()
mocker.patch(
"databricks.sdk.credentials_provider.google.auth.default",
return_value=(mock_source_credentials, "test-project"),
)

mock_id_creds = Mock()
mock_id_creds.token = "test-google-id-token"

mock_gcp_creds = Mock()
mock_gcp_creds.refresh.side_effect = Exception("permission denied")

mocker.patch(
"databricks.sdk.credentials_provider.impersonated_credentials.Credentials",
return_value=mock_gcp_creds,
)
mocker.patch(
"databricks.sdk.credentials_provider.impersonated_credentials.IDTokenCredentials",
return_value=mock_id_creds,
)

provider = credentials_provider.google_id(mock_cfg)
mock_logger = mocker.patch("databricks.sdk.credentials_provider.logger")
headers = provider()

assert headers["Authorization"] == "Bearer test-google-id-token"
assert "X-Databricks-GCP-SA-Access-Token" not in headers
mock_logger.warning.assert_called_once()

def test_github_oidc_azure_with_cloud_agnostic_host(self, mocker):
"""Test that github_oidc_azure works with cloud-agnostic hosts after removing is_azure check."""
# Set up GitHub Actions environment
Expand Down
Loading