Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions NEXT_CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@

### Bug Fixes

* Added a `timeout` to the `requests.post()`/`requests.get()` calls in `oauth.py`. These calls previously had no timeout, so they could hang indefinitely when the OAuth endpoint was unreachable during token refresh ([#1338](https://github.com/databricks/databricks-sdk-py/issues/1338)).

### Documentation

### Internal Changes
Expand Down
5 changes: 3 additions & 2 deletions databricks/sdk/_base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@

logger = logging.getLogger("databricks.sdk")

_DEFAULT_HTTP_TIMEOUT_SECONDS = 60


def _fix_host_if_needed(host: Optional[str]) -> Optional[str]:
if not host:
Expand Down Expand Up @@ -96,8 +98,7 @@ def __init__(
)
self._session.mount("https://", http_adapter)

# Default to 60 seconds
self._http_timeout_seconds = http_timeout_seconds or 60
self._http_timeout_seconds = http_timeout_seconds or _DEFAULT_HTTP_TIMEOUT_SECONDS

self._error_parser = _Parser(
extra_error_customizers=extra_error_customizers,
Expand Down
6 changes: 4 additions & 2 deletions databricks/sdk/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import requests

from . import useragent
from ._base_client import _fix_host_if_needed
from ._base_client import _DEFAULT_HTTP_TIMEOUT_SECONDS, _fix_host_if_needed
from .client_types import ClientType, HostType
from .clock import Clock, RealClock
from .credentials_provider import (CredentialsStrategy, DefaultCredentials,
Expand Down Expand Up @@ -552,7 +552,9 @@ def oidc_endpoints(self) -> Optional[OidcEndpoints]:
if not self.host:
return None
if self.is_azure and self.azure_client_id:
return get_azure_entra_id_workspace_endpoints(self.host)
return get_azure_entra_id_workspace_endpoints(
self.host, timeout=self.http_timeout_seconds or _DEFAULT_HTTP_TIMEOUT_SECONDS
)
return self.databricks_oidc_endpoints

def debug_string(self) -> str:
Expand Down
14 changes: 12 additions & 2 deletions databricks/sdk/credentials_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from databricks.sdk.oauth import get_azure_entra_id_workspace_endpoints

from . import azure, oauth, oidc, oidc_token_supplier
from ._base_client import _DEFAULT_HTTP_TIMEOUT_SECONDS
from .client_types import ClientType

CredentialsProvider = Callable[[], Dict[str, str]]
Expand Down Expand Up @@ -203,6 +204,7 @@ def get_notebook_pat_token() -> Optional[str]:
host=cfg.host,
scopes=cfg.get_scopes_as_string(),
authorization_details=cfg.authorization_details,
http_timeout_seconds=cfg.http_timeout_seconds or _DEFAULT_HTTP_TIMEOUT_SECONDS,
)

def inner() -> Dict[str, str]:
Expand Down Expand Up @@ -232,6 +234,7 @@ def oauth_service_principal(cfg: "Config") -> Optional[CredentialsProvider]:
use_header=True,
disable_async=cfg.disable_async_token_refresh,
authorization_details=cfg.authorization_details,
http_timeout_seconds=cfg.http_timeout_seconds or _DEFAULT_HTTP_TIMEOUT_SECONDS,
)

def inner() -> Dict[str, str]:
Expand All @@ -258,7 +261,9 @@ def external_browser(cfg: "Config") -> Optional[CredentialsProvider]:
elif cfg.azure_client_id:
client_id = cfg.azure_client_id
client_secret = cfg.azure_client_secret
oidc_endpoints = get_azure_entra_id_workspace_endpoints(cfg.host)
oidc_endpoints = get_azure_entra_id_workspace_endpoints(
cfg.host, timeout=cfg.http_timeout_seconds or _DEFAULT_HTTP_TIMEOUT_SECONDS
)
if not client_id:
client_id = "databricks-cli"
oidc_endpoints = cfg.databricks_oidc_endpoints
Expand Down Expand Up @@ -348,6 +353,7 @@ def token_source_for(resource: str) -> oauth.TokenSource:
disable_async=cfg.disable_async_token_refresh,
scopes=cfg.get_scopes_as_string(),
authorization_details=cfg.authorization_details,
http_timeout_seconds=cfg.http_timeout_seconds or _DEFAULT_HTTP_TIMEOUT_SECONDS,
)

_ensure_host_present(cfg, token_source_for)
Expand Down Expand Up @@ -470,6 +476,7 @@ def token_source_for(audience: str) -> oauth.TokenSource:
use_params=True,
disable_async=cfg.disable_async_token_refresh,
authorization_details=cfg.authorization_details,
http_timeout_seconds=cfg.http_timeout_seconds or _DEFAULT_HTTP_TIMEOUT_SECONDS,
)

def refreshed_headers() -> Dict[str, str]:
Expand Down Expand Up @@ -532,7 +539,9 @@ def github_oidc_azure(cfg: "Config") -> Optional[CredentialsProvider]:
aad_endpoint = cfg.arm_environment.active_directory_endpoint
if not cfg.azure_tenant_id:
# detect Azure AD Tenant ID if it's not specified directly
token_endpoint = get_azure_entra_id_workspace_endpoints(cfg.host).token_endpoint
token_endpoint = get_azure_entra_id_workspace_endpoints(
cfg.host, timeout=cfg.http_timeout_seconds or _DEFAULT_HTTP_TIMEOUT_SECONDS
).token_endpoint
cfg.azure_tenant_id = token_endpoint.replace(aad_endpoint, "").split("/")[0]

inner = oauth.ClientCredentials(
Expand All @@ -548,6 +557,7 @@ def github_oidc_azure(cfg: "Config") -> Optional[CredentialsProvider]:
disable_async=cfg.disable_async_token_refresh,
scopes=cfg.get_scopes_as_string(),
authorization_details=cfg.authorization_details,
http_timeout_seconds=cfg.http_timeout_seconds or _DEFAULT_HTTP_TIMEOUT_SECONDS,
)

def refreshed_headers() -> Dict[str, str]:
Expand Down
14 changes: 10 additions & 4 deletions databricks/sdk/oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import requests
import requests.auth

from ._base_client import _BaseClient, _fix_host_if_needed
from ._base_client import _BaseClient, _DEFAULT_HTTP_TIMEOUT_SECONDS, _fix_host_if_needed
from .environments import Cloud

# Error code for PKCE flow in Azure Active Directory, that gets additional retry.
Expand Down Expand Up @@ -194,6 +194,7 @@ def retrieve_token(
use_params=False,
use_header=False,
headers=None,
timeout=_DEFAULT_HTTP_TIMEOUT_SECONDS,
) -> Token:
logger.debug(f"Retrieving token for {client_id}")
if use_params:
Expand All @@ -206,7 +207,7 @@ def retrieve_token(
auth = requests.auth.HTTPBasicAuth(client_id, client_secret)
else:
auth = IgnoreNetrcAuth()
resp = requests.post(token_url, params, auth=auth, headers=headers)
resp = requests.post(token_url, params, auth=auth, headers=headers, timeout=timeout)
if not resp.ok:
if resp.headers["Content-Type"].startswith("application/json"):
err = resp.json()
Expand Down Expand Up @@ -533,16 +534,18 @@ def get_unified_endpoints(host: str, account_id: str, client: _BaseClient = _Bas

def get_azure_entra_id_workspace_endpoints(
host: str,
timeout: int = _DEFAULT_HTTP_TIMEOUT_SECONDS,
) -> Optional[OidcEndpoints]:
"""
Get the Azure Entra ID endpoints for a given workspace. Can only be used when authenticating to Azure Databricks
using an application registered in Azure Entra ID.
:param host: The Databricks workspace host.
:param timeout: HTTP request timeout in seconds.
:return: The OIDC endpoints for the workspace's Azure Entra ID tenant.
"""
# In Azure, this workspace endpoint redirects to the Entra ID authorization endpoint
host = _fix_host_if_needed(host)
res = requests.get(f"{host}/oidc/oauth2/v2.0/authorize", allow_redirects=False)
res = requests.get(f"{host}/oidc/oauth2/v2.0/authorize", allow_redirects=False, timeout=timeout)
real_auth_url = res.headers.get("location")
if not real_auth_url:
return None
Expand Down Expand Up @@ -848,6 +851,7 @@ class ClientCredentials(Refreshable):
use_header: bool = False
disable_async: bool = True
authorization_details: str = None
http_timeout_seconds: int = _DEFAULT_HTTP_TIMEOUT_SECONDS

def __post_init__(self):
super().__init__(disable_async=self.disable_async)
Expand All @@ -868,6 +872,7 @@ def refresh(self) -> Token:
params,
use_params=self.use_params,
use_header=self.use_header,
timeout=self.http_timeout_seconds,
)


Expand All @@ -894,6 +899,7 @@ class PATOAuthTokenExchange(Refreshable):
scopes: str
authorization_details: str = None
disable_async: bool = True
http_timeout_seconds: int = _DEFAULT_HTTP_TIMEOUT_SECONDS

def __post_init__(self):
super().__init__(disable_async=self.disable_async)
Expand All @@ -910,7 +916,7 @@ def refresh(self) -> Token:
if self.authorization_details:
params["authorization_details"] = self.authorization_details

resp = requests.post(token_exchange_url, params)
resp = requests.post(token_exchange_url, params, timeout=self.http_timeout_seconds)
if not resp.ok:
if resp.headers["Content-Type"].startswith("application/json"):
err = resp.json()
Expand Down
Loading