Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions databricks/sdk/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,10 @@ class Config:
# is hit first will stop the retry loop.
experimental_files_ext_cloud_api_max_retries: int = 3

# Whether to enable the storage proxy for file operations.
# When enabled, the SDK will probe the storage proxy and use it if available.
experimental_files_ext_enable_storage_proxy: bool = False

def __init__(
self,
*,
Expand Down
246 changes: 197 additions & 49 deletions databricks/sdk/mixins/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -819,13 +819,106 @@ def build_resumable_upload_url(self, path: str, session_token: str) -> _Presigne
headers[h["name"]] = h["value"]
return _PresignedUrl(url=url, headers=headers)

def build_abort_url(self, path: str, session_token: str, expire_time: str) -> _PresignedUrl:
    """Fetches a presigned URL for aborting a multipart upload.

    Calls the create-abort-upload-url coordination API and folds the
    server-required headers into the returned request descriptor.
    """
    response = self._api.do(
        "POST",
        url=f"{self._hostname}/api/2.0/fs/create-abort-upload-url",
        headers={"Content-Type": "application/json"},
        body={
            "path": path,
            "session_token": session_token,
            "expire_time": expire_time,
        },
    )
    abort_node = response["abort_upload_url"]
    # Default to an octet-stream body; any headers the server mandates for
    # the presigned request take precedence over the default.
    merged_headers: dict[str, str] = {"Content-Type": "application/octet-stream"}
    merged_headers.update({h["name"]: h["value"] for h in abort_node.get("headers", [])})
    return _PresignedUrl(url=abort_node["url"], headers=merged_headers)


class _StorageProxyRequestBuilder:
    """Builds direct upload requests to the storage proxy.

    Skips the presigned URL coordination APIs entirely. Instead, constructs
    URLs that point directly at the proxy endpoint with query parameters for
    session token, upload type, and part number. The proxy handles cloud
    storage interaction internally.
    """

    def __init__(self, hostname: str):
        self._hostname = hostname

    def _file_url(self, path: str, query: dict) -> str:
        # Every proxy operation targets the same per-file endpoint; only the
        # query string distinguishes them. Query key order is preserved.
        escaped_path = _escape_multi_segment_path_parameter(path)
        return f"{self._hostname}/api/2.0/fs/files{escaped_path}?{parse.urlencode(query)}"

    def build_upload_part_urls(
        self, path: str, session_token: str, start_part_number: int, count: int, expire_time: str
    ) -> list[_PresignedUrl]:
        """Builds URLs for uploading multipart parts directly to the storage proxy.

        NOTE(review): expire_time is accepted for interface parity with the
        presigned-URL builder but is not used — proxy URLs carry no expiry.
        """
        return [
            _PresignedUrl(
                url=self._file_url(
                    path,
                    {
                        "session_token": session_token,
                        "upload_type": "multipart",
                        "part_number": part_number,
                    },
                ),
                headers={"Content-Type": "application/octet-stream"},
                part_number=part_number,
            )
            for part_number in range(start_part_number, start_part_number + count)
        ]

    def build_resumable_upload_url(self, path: str, session_token: str) -> _PresignedUrl:
        """Builds a URL for resumable upload directly to the storage proxy."""
        url = self._file_url(
            path,
            {
                "session_token": session_token,
                "upload_type": "resumable",
            },
        )
        return _PresignedUrl(url=url, headers={"Content-Type": "application/octet-stream"})

    def build_abort_url(self, path: str, session_token: str, expire_time: str) -> _PresignedUrl:
        """Builds a URL for aborting an upload directly on the storage proxy.

        NOTE(review): expire_time is accepted for interface parity with the
        presigned-URL builder but is not used by the proxy endpoint.
        """
        url = self._file_url(
            path,
            {
                "action": "abort-upload",
                "session_token": session_token,
            },
        )
        return _PresignedUrl(url=url, headers={"Content-Type": "application/json"})


class FilesExt(files.FilesAPI):
__doc__ = files.FilesAPI.__doc__

# Status codes treated as transient failures; note that these error codes
# are retryable only for idempotent operations.
_RETRYABLE_STATUS_CODES: list[int] = [408, 429, 502, 503, 504]

# Storage-proxy hostname for data plane file operations.
# NOTE(review): this is a plain-HTTP scheme while requests to it carry SDK
# auth headers — confirm that TLS is terminated elsewhere (or traffic stays
# on a trusted network) before credentials are sent over this URL.
_STORAGE_PROXY_HOSTNAME: str = "http://storage-proxy.databricks.com"

# Timeout in seconds for the storage-proxy health check probe.
_STORAGE_PROXY_PROBE_TIMEOUT: float = 3.0

@dataclass(frozen=True)
class _UploadContext:
target_path: str
Expand All @@ -851,15 +944,90 @@ def __init__(self, api_client, config: Config):
self._multipart_upload_read_ahead_bytes = 1
self._cached_cloud_provider_session: Optional[requests.Session] = None
self._cloud_provider_session_lock = Lock()
self._dp_hostname_available: Optional[bool] = None
self._cached_storage_proxy_session: Optional[requests.Session] = None
self._storage_proxy_session_lock = Lock()

def _get_hostname(self) -> str:
    """Returns the hostname to use for file operations.

    When the storage-proxy feature flag is enabled, probes the proxy on the
    first call and caches the result in self._dp_hostname_available; returns
    the proxy hostname if the probe succeeded. Otherwise (flag disabled, or
    probe failed) returns the workspace hostname from the config.

    NOTE(review): the probe-and-cache is not synchronized, so concurrent
    first calls may each issue a probe; the cached result converges after.
    """
    if self._config.experimental_files_ext_enable_storage_proxy:
        if self._dp_hostname_available is None:
            self._dp_hostname_available = self._probe_storage_proxy()
            # Log only on the first (probing) call to avoid emitting an INFO
            # line for every subsequent file operation.
            if self._dp_hostname_available:
                _LOG.info("Storage proxy is available, will use it for file operations.")
            else:
                _LOG.info("Storage proxy is not available, will use presigned URLs.")
        if self._dp_hostname_available:
            return self._STORAGE_PROXY_HOSTNAME
    return self._config.host

def _create_request_builder(self):
    """Creates the appropriate request builder for file operations.

    Returns a storage proxy builder when the proxy is enabled and known to
    be available, otherwise a presigned URL builder.
    """
    hostname = self._get_hostname()
    proxy_in_use = (
        self._config.experimental_files_ext_enable_storage_proxy and self._dp_hostname_available
    )
    if not proxy_in_use:
        return _PresignedUrlRequestBuilder(self._api, hostname)
    return _StorageProxyRequestBuilder(hostname)

def _probe_storage_proxy(self) -> bool:
    """Probes the storage proxy to check whether it is reachable.

    Makes a GET request to the probe endpoint with SDK auth headers and
    returns True only on HTTP 200. Any failure (auth, network, timeout) is
    treated as "proxy unavailable" — this is a deliberate best-effort check;
    the caller (_get_hostname) caches the result and falls back to presigned
    URLs.
    """
    probe_url = f"{self._STORAGE_PROXY_HOSTNAME}/api/2.0/fs/files/DatabricksInternal/Probes/ping"
    try:
        headers = self._config.authenticate()
        session = self._cloud_provider_session()
        response = session.request(
            "GET",
            probe_url,
            headers=headers,
            timeout=self._STORAGE_PROXY_PROBE_TIMEOUT,
        )
        return response.status_code == 200
    except Exception as e:
        # Keep the best-effort semantics, but leave a trace so operators can
        # tell why the proxy was skipped instead of failing silently.
        _LOG.debug("Storage proxy probe failed: %s", e)
        return False

def _create_storage_proxy_session(self) -> requests.Session:
    """Returns an HTTP session with SDK auth for storage-proxy requests.

    Unlike _cloud_provider_session (which has no auth), this session
    includes the SDK authentication callback so that every request to the
    storage proxy carries valid credentials. The session is created on
    first call and cached for reuse.
    """
    with self._storage_proxy_session_lock:
        cached = self._cached_storage_proxy_session
        if cached is not None:
            return cached

        cfg = self._config

        def _sign(prepared: requests.PreparedRequest) -> requests.PreparedRequest:
            # requests invokes the auth callable per request, so auth headers
            # are recomputed each time rather than captured once.
            prepared.headers.update(cfg.authenticate())
            return prepared

        new_session = requests.Session()
        new_session.auth = _sign
        adapter = requests.adapters.HTTPAdapter(
            cfg.max_connection_pools or 20,
            cfg.max_connections_per_pool or 20,
            pool_block=True,
        )
        for scheme in ("https://", "http://"):
            new_session.mount(scheme, adapter)
        self._cached_storage_proxy_session = new_session
        return new_session

def download(
self,
file_path: str,
Expand Down Expand Up @@ -1749,7 +1917,7 @@ def _do_upload_one_part(
part_content: BinaryIO,
is_first_part: bool = False,
) -> str:
builder = _PresignedUrlRequestBuilder(self._api, self._get_hostname())
builder = self._create_request_builder()
retry_count = 0

# Try to upload the part, retrying if the upload URL expires.
Expand Down Expand Up @@ -1819,7 +1987,7 @@ def _perform_multipart_upload(
Performs multipart upload using presigned URLs on AWS and Azure:
https://docs.aws.amazon.com/AmazonS3/latest/userguide/mpuoverview.html
"""
builder = _PresignedUrlRequestBuilder(self._api, self._get_hostname())
builder = self._create_request_builder()
current_part_number = 1
etags: dict = {}

Expand Down Expand Up @@ -2015,8 +2183,6 @@ def _perform_resumable_upload(
"""
Performs resumable upload on GCP: https://cloud.google.com/storage/docs/performing-resumable-uploads
"""
hostname = self._get_hostname()

# Session URI we're using expires after a week

# Why are we buffering the current chunk?
Expand All @@ -2040,7 +2206,7 @@ def _perform_resumable_upload(
# On the contrary, in multipart upload we can decide to complete upload *after*
# last chunk has been sent.

builder = _PresignedUrlRequestBuilder(self._api, hostname)
builder = self._create_request_builder()

try:
presigned = builder.build_resumable_upload_url(ctx.target_path, session_token)
Expand Down Expand Up @@ -2238,33 +2404,14 @@ def _get_download_url_expire_time(self) -> str:

def _abort_multipart_upload(self, ctx: _UploadContext, session_token: str) -> None:
"""Aborts ongoing multipart upload session to clean up incomplete file."""
hostname = self._get_hostname()
body: dict = {
"path": ctx.target_path,
"session_token": session_token,
"expire_time": self._get_upload_url_expire_time(),
}

headers = {"Content-Type": "application/json"}

# Method _api.do() takes care of retrying and will raise an exception in case of failure.
abort_url_response = self._api.do(
"POST", url=f"{hostname}/api/2.0/fs/create-abort-upload-url", headers=headers, body=body
)

abort_upload_url_node = abort_url_response["abort_upload_url"]
abort_url = abort_upload_url_node["url"]
required_headers = abort_upload_url_node.get("headers", [])

headers: dict = {"Content-Type": "application/octet-stream"}
for h in required_headers:
headers[h["name"]] = h["value"]
builder = self._create_request_builder()
abort_info = builder.build_abort_url(ctx.target_path, session_token, self._get_upload_url_expire_time())

def perform() -> requests.Response:
return self._cloud_provider_session().request(
"DELETE",
abort_url,
headers=headers,
abort_info.url,
headers=abort_info.headers,
data=b"",
timeout=self._config.files_ext_network_transfer_inactivity_timeout_seconds,
)
Expand Down Expand Up @@ -2292,28 +2439,29 @@ def perform() -> requests.Response:
raise ValueError(abort_response)

def _cloud_provider_session(self) -> requests.Session:
    """Returns a session for cloud storage operations.

    When the storage proxy is in use, returns an authenticated session
    because proxy uploads require SDK credentials. Otherwise returns an
    unauthenticated session since presigned URLs already contain auth.
    The session is created on first call and cached for reuse.
    """
    if self._config.experimental_files_ext_enable_storage_proxy and self._dp_hostname_available:
        return self._create_storage_proxy_session()
    # Fast path: once created, the cached session is never invalidated, so it
    # can be returned without taking the lock. This avoids serializing every
    # chunk upload on _cloud_provider_session_lock.
    if self._cached_cloud_provider_session is not None:
        return self._cached_cloud_provider_session
    with self._cloud_provider_session_lock:
        # Double-check after acquiring the lock to avoid creating duplicate sessions.
        if self._cached_cloud_provider_session is None:
            session = requests.Session()
            # Following session config in _BaseClient.
            http_adapter = requests.adapters.HTTPAdapter(
                self._config.max_connection_pools or 20,
                self._config.max_connections_per_pool or 20,
                pool_block=True,
            )
            session.mount("https://", http_adapter)
            # Presigned URL for storage proxy can use plain HTTP.
            session.mount("http://", http_adapter)
            self._cached_cloud_provider_session = session
        return self._cached_cloud_provider_session

def _retry_cloud_idempotent_operation(
self, operation: Callable[[], requests.Response], before_retry: Optional[Callable] = None
Expand Down
Loading
Loading