Skip to content
4 changes: 4 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@ test:
pytest -m 'not integration and not benchmark' --cov=databricks --cov-report html tests

integration:
# TODO: Remove dual-run once experimental_is_unified_host flag is removed and unified mode becomes default
@echo "Running integration tests in unified mode..."
DATABRICKS_EXPERIMENTAL_IS_UNIFIED_HOST=true pytest -n auto -m 'integration and not benchmark' --reruns 4 --dist loadgroup --cov=databricks --cov-append --cov-report html tests
@echo "Running integration tests in legacy mode..."
pytest -n auto -m 'integration and not benchmark' --reruns 4 --dist loadgroup --cov=databricks --cov-report html tests

benchmark:
Expand Down
1 change: 1 addition & 0 deletions NEXT_CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
## Release v0.84.0

### New Features and Improvements
* Add support for legacy Profiles in Unified Mode. It is now possible to use any host in Unified Mode.

### Security

Expand Down
1 change: 1 addition & 0 deletions databricks/sdk/__init__.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

104 changes: 72 additions & 32 deletions databricks/sdk/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,10 @@ def __get__(self, cfg: "Config", owner):
return cfg._inner.get(self.name, None)

def __set__(self, cfg: "Config", value: any):
cfg._inner[self.name] = self.transform(value)
if value is None:
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a bug fix. Before, this:

cfg.account_id = None

would actually set

cfg.account_id = "None"

cfg._inner.pop(self.name, None)
else:
cfg._inner[self.name] = self.transform(value)

def __repr__(self) -> str:
return f"<ConfigAttribute '{self.name}' {self.transform.__name__}>"
Expand Down Expand Up @@ -264,14 +267,20 @@ def __init__(
self.databricks_environment = kwargs["databricks_environment"]
del kwargs["databricks_environment"]
self._clock = clock if clock is not None else RealClock()

try:
self._set_inner_config(kwargs)
self._load_from_env()
self._known_file_config_loader()
self._fix_host_if_needed()
# Resolve the legacy profile based on configuration just before validation.
self._resolve_legacy_profile()
self._validate()
self.init_auth()
self._init_product(product, product_version)
# Extract the workspace ID for legacy profiles. This is extracted from an API call.
if not self.workspace_id and not self.account_id and self.experimental_is_unified_host:
self.workspace_id = self._fetch_workspace_id()
except ValueError as e:
message = self.wrap_debug_info(str(e))
raise ValueError(message) from e
Expand Down Expand Up @@ -335,33 +344,25 @@ def _get_azure_environment_name(self) -> str:
@property
def environment(self) -> DatabricksEnvironment:
"""Returns the environment based on configuration."""
if self.databricks_environment:
return self.databricks_environment
if not self.host and self.azure_workspace_resource_id:
azure_env = self._get_azure_environment_name()
for environment in ALL_ENVS:
if environment.cloud != Cloud.AZURE:
continue
if environment.azure_environment.name != azure_env:
continue
if environment.dns_zone.startswith(".dev") or environment.dns_zone.startswith(".staging"):
continue
return environment
return get_environment_for_hostname(self.host)
if not self.experimental_is_unified_host:
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We now resolve this once during new Config, not lazily. This ensures that we don't read the host after Config is created.

# Preserve old behavior by default.
# TODO: Remove this when making the unified mode the default.
self._resolve_environment()
return self.databricks_environment

@property
def is_azure(self) -> bool:
if self.azure_workspace_resource_id:
return True
return self.environment.cloud == Cloud.AZURE
return self.environment is not None and self.environment.cloud == Cloud.AZURE
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unified Hosts will return false on all 3 clouds.


@property
def is_gcp(self) -> bool:
return self.environment.cloud == Cloud.GCP
return self.environment is not None and self.environment.cloud == Cloud.GCP

@property
def is_aws(self) -> bool:
return self.environment.cloud == Cloud.AWS
return self.environment is not None and self.environment.cloud == Cloud.AWS

@property
def host_type(self) -> HostType:
Expand All @@ -384,7 +385,9 @@ def host_type(self) -> HostType:

@property
def client_type(self) -> ClientType:
"""Determine the type of client configuration.
"""
[Deprecated] Deprecated. Use host_type instead. Some hosts can support both account and workspace clients.
Determine the type of client configuration.

This is separate from host_type. For example, a unified host can support both
workspace and account client types.
Expand All @@ -403,25 +406,24 @@ def client_type(self) -> ClientType:
return ClientType.WORKSPACE

if host_type == HostType.UNIFIED:
if not self.account_id:
raise ValueError("Unified host requires account_id to be set")
if self.workspace_id:
return ClientType.WORKSPACE
return ClientType.ACCOUNT
if self.account_id:
return ClientType.ACCOUNT
# Legacy workspace hosts don't have a workspace_id until AFTER the auth is resolved.
return ClientType.WORKSPACE

# Default to workspace for backward compatibility
return ClientType.WORKSPACE

@property
def is_account_client(self) -> bool:
"""[Deprecated] Use host_type or client_type instead.
"""[Deprecated] Use host_type instead.

Determines if this is an account client based on the host URL.
Determines if this config is compatible with an account client based on the host URL and account_id.
"""
if self.experimental_is_unified_host:
raise ValueError(
"is_account_client cannot be used with unified hosts; use host_type or client_type instead"
)
return self.account_id
if not self.host:
return False
return self.host.startswith("https://accounts.") or self.host.startswith("https://accounts-dod.")
Expand Down Expand Up @@ -480,7 +482,7 @@ def oidc_endpoints(self) -> Optional[OidcEndpoints]:
# Handle unified hosts
if self.host_type == HostType.UNIFIED:
if not self.account_id:
raise ValueError("Unified host requires account_id to be set for OAuth endpoints")
return get_workspace_endpoints(self.host)
return get_unified_endpoints(self.host, self.account_id)

# Handle traditional account hosts
Expand Down Expand Up @@ -525,12 +527,9 @@ def sql_http_path(self) -> Optional[str]:
return None
if self.cluster_id and self.warehouse_id:
raise ValueError("cannot have both cluster_id and warehouse_id")
headers = self.authenticate()
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Extracted to a separate func

headers["User-Agent"] = f"{self.user_agent} sdk-feature/sql-http-path"
if self.cluster_id:
response = requests.get(f"{self.host}/api/2.0/preview/scim/v2/Me", headers=headers)
# get workspace ID from the response header
workspace_id = response.headers.get("x-databricks-org-id")
# Reuse cached workspace_id or fetch it
workspace_id = self.workspace_id or self._fetch_workspace_id()
return f"sql/protocolv1/o/{workspace_id}/{self.cluster_id}"
if self.warehouse_id:
return f"/sql/1.0/warehouses/{self.warehouse_id}"
Expand Down Expand Up @@ -721,3 +720,44 @@ def copy(self):
def deep_copy(self):
"""Creates a deep copy of the config object."""
return copy.deepcopy(self)

# The code below is used to support legacy hosts.
def _resolve_environment(self):
"""Resolve the environment based on configuration."""
if self.databricks_environment:
return
if not self.host and self.azure_workspace_resource_id:
azure_env = self._get_azure_environment_name()
for environment in ALL_ENVS:
if environment.cloud != Cloud.AZURE:
continue
if environment.azure_environment.name != azure_env:
continue
if environment.dns_zone.startswith(".dev") or environment.dns_zone.startswith(".staging"):
continue
self.databricks_environment = environment
return
self.databricks_environment = get_environment_for_hostname(self.host)

def _resolve_legacy_profile(self):
"""Resolve the legacy profile based on configuration."""

# This only applies to the unified mode.
# TODO: Remove this when making the unified mode the default.
if not self.experimental_is_unified_host:
return
# New Profiles always have an account ID.
if not self.account_id:
self._resolve_environment()

if self.host and (self.host.startswith("https://accounts.") or self.host.startswith("https://accounts-dod.")):
self._resolve_environment()

def _fetch_workspace_id(self) -> Optional[str]:
"""Fetch the workspace ID from the host."""
headers = self.authenticate()
headers["User-Agent"] = f"{self.user_agent} sdk-feature/sql-http-path"
response = requests.get(f"{self.host}/api/2.0/preview/scim/v2/Me", headers=headers)
response.raise_for_status()
# get workspace ID from the response header
return response.headers.get("x-databricks-org-id")
8 changes: 4 additions & 4 deletions databricks/sdk/credentials_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,9 +431,9 @@ def _oidc_credentials_provider(

# Determine the audience for token exchange
audience = cfg.token_audience
if audience is None and cfg.client_type == ClientType.ACCOUNT:
if audience is None and cfg.account_id:
audience = cfg.account_id
if audience is None and cfg.client_type != ClientType.ACCOUNT:
if audience is None and not cfg.account_id:
audience = cfg.oidc_endpoints.token_endpoint

# Try to get an OIDC token. If no supplier returns a token, we cannot use this authentication mode.
Expand Down Expand Up @@ -590,7 +590,7 @@ def token() -> oauth.Token:
def refreshed_headers() -> Dict[str, str]:
credentials.refresh(request)
headers = {"Authorization": f"Bearer {credentials.token}"}
if cfg.client_type == ClientType.ACCOUNT:
if cfg.account_id:
gcp_credentials.refresh(request)
headers["X-Databricks-GCP-SA-Access-Token"] = gcp_credentials.token
return headers
Expand Down Expand Up @@ -631,7 +631,7 @@ def token() -> oauth.Token:
def refreshed_headers() -> Dict[str, str]:
id_creds.refresh(request)
headers = {"Authorization": f"Bearer {id_creds.token}"}
if cfg.client_type == ClientType.ACCOUNT:
if cfg.account_id:
gcp_impersonated_credentials.refresh(request)
headers["X-Databricks-GCP-SA-Access-Token"] = gcp_impersonated_credentials.token
return headers
Expand Down
4 changes: 2 additions & 2 deletions tests/integration/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def a(env_or_skip) -> AccountClient:
_load_debug_env_if_runs_from_ide("account")
env_or_skip("CLOUD_ENV")
account_client = AccountClient()
if not account_client.config.is_account_client:
if not account_client.config.account_id:
pytest.skip("not Databricks Account client")
return account_client

Expand All @@ -75,7 +75,7 @@ def ucacct(env_or_skip) -> AccountClient:
_load_debug_env_if_runs_from_ide("ucacct")
env_or_skip("CLOUD_ENV")
account_client = AccountClient()
if not account_client.config.is_account_client:
if not account_client.config.account_id:
pytest.skip("not Databricks Account client")
if "TEST_METASTORE_ID" not in os.environ:
pytest.skip("not in Unity Catalog Workspace test env")
Expand Down
4 changes: 4 additions & 0 deletions tests/integration/test_auth.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import base64
import io
import json
import os
import re
import shutil
import subprocess
Expand Down Expand Up @@ -261,6 +262,9 @@ def test_wif_workspace(ucacct, env_or_skip, random):
permissions=[iam.WorkspacePermission.ADMIN],
)

# Clean env var
os.environ.pop("DATABRICKS_ACCOUNT_ID", None)

ws = WorkspaceClient(
host=workspace_url,
client_id=sp.application_id,
Expand Down
Loading
Loading