Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/databricks/sql/auth/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def __init__(
retry_stop_after_attempts_duration: Optional[float] = None,
retry_delay_default: Optional[float] = None,
retry_dangerous_codes: Optional[List[int]] = None,
respect_server_retry_after_header: Optional[bool] = None,
proxy_auth_method: Optional[str] = None,
pool_connections: Optional[int] = None,
pool_maxsize: Optional[int] = None,
Expand Down Expand Up @@ -80,6 +81,7 @@ def __init__(
)
self.retry_delay_default = retry_delay_default or 5.0
self.retry_dangerous_codes = retry_dangerous_codes or []
self.respect_server_retry_after_header = bool(respect_server_retry_after_header)
self.proxy_auth_method = proxy_auth_method
self.pool_connections = pool_connections or 10
self.pool_maxsize = pool_maxsize or 20
Expand Down
18 changes: 16 additions & 2 deletions src/databricks/sql/auth/retry.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ def __init__(
stop_after_attempts_duration: float,
delay_default: float,
force_dangerous_codes: List[int],
respect_server_retry_after_header: bool = False,
urllib3_kwargs: dict = {},
):
# These values do not change from one command to the next
Expand All @@ -103,6 +104,7 @@ def __init__(
self.stop_after_attempts_duration = stop_after_attempts_duration
self._delay_default = delay_default
self.force_dangerous_codes = force_dangerous_codes
self.respect_server_retry_after_header = respect_server_retry_after_header

# the urllib3 kwargs are a mix of configuration (some of which we override)
# and counters like `total` or `connect` which may change between successive retries
Expand Down Expand Up @@ -202,6 +204,7 @@ def new(
stop_after_attempts_duration=self.stop_after_attempts_duration,
delay_default=self.delay_default,
force_dangerous_codes=self.force_dangerous_codes,
respect_server_retry_after_header=self.respect_server_retry_after_header,
urllib3_kwargs={},
)

Expand Down Expand Up @@ -323,7 +326,9 @@ def get_backoff_time(self) -> float:

return proposed_backoff

def should_retry(self, method: str, status_code: int) -> Tuple[bool, str]:
def should_retry(
self, method: str, status_code: int, has_retry_after: bool = False
) -> Tuple[bool, str]:
"""This method encapsulates the connector's approach to retries.

We always retry a request unless one of these conditions is met:
Expand Down Expand Up @@ -388,6 +393,15 @@ def should_retry(self, method: str, status_code: int) -> Tuple[bool, str]:
if not self._is_method_retryable(method):
return False, "Only POST requests are retried"

# When respect_server_retry_after_header is enabled, only retry when the
# server explicitly signals it's safe via a Retry-After header. This prevents
# duplicate side effects for non-idempotent operations.
if self.respect_server_retry_after_header and not has_retry_after:
return (
False,
"respect_server_retry_after_header mode: no Retry-After header present",
)

# Request failed, was an ExecuteStatement and the command may have reached the server
if (
self.command_type == CommandType.EXECUTE_STATEMENT
Expand Down Expand Up @@ -430,7 +444,7 @@ def is_retry(
Logs a debug message if the request will be retried
"""

should_retry, msg = self.should_retry(method, status_code)
should_retry, msg = self.should_retry(method, status_code, has_retry_after)

if should_retry:
logger.debug(msg)
Expand Down
4 changes: 4 additions & 0 deletions src/databricks/sql/backend/sea/utils/http_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,9 @@ def __init__(
)
self._retry_delay_default = kwargs.get("_retry_delay_default", 5.0)
self.force_dangerous_codes = kwargs.get("_retry_dangerous_codes", [])
self._respect_server_retry_after_header = kwargs.get(
"_respect_server_retry_after_header", False
)

# Connection pooling settings
self.max_connections = kwargs.get("max_connections", 10)
Expand All @@ -116,6 +119,7 @@ def __init__(
stop_after_attempts_duration=self._retry_stop_after_attempts_duration,
delay_default=self._retry_delay_default,
force_dangerous_codes=self.force_dangerous_codes,
respect_server_retry_after_header=self._respect_server_retry_after_header,
urllib3_kwargs=urllib3_kwargs,
)
else:
Expand Down
4 changes: 4 additions & 0 deletions src/databricks/sql/backend/thrift_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,9 @@ def __init__(
" This behaviour is deprecated and will be removed in a future release."
)
self.force_dangerous_codes = kwargs.get("_retry_dangerous_codes", [])
self._respect_server_retry_after_header = kwargs.get(
"_respect_server_retry_after_header", False
)

additional_transport_args = {}

Expand All @@ -217,6 +220,7 @@ def __init__(
stop_after_attempts_duration=self._retry_stop_after_attempts_duration,
delay_default=self._retry_delay_default,
force_dangerous_codes=self.force_dangerous_codes,
respect_server_retry_after_header=self._respect_server_retry_after_header,
urllib3_kwargs=urllib3_kwargs,
)

Expand Down
1 change: 1 addition & 0 deletions src/databricks/sql/common/unified_http_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ def _setup_pool_managers(self):
stop_after_attempts_duration=self.config.retry_stop_after_attempts_duration,
delay_default=self.config.retry_delay_default,
force_dangerous_codes=self.config.retry_dangerous_codes,
respect_server_retry_after_header=self.config.respect_server_retry_after_header,
)

# Initialize the required attributes that DatabricksRetryPolicy expects
Expand Down
3 changes: 3 additions & 0 deletions src/databricks/sql/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -977,6 +977,9 @@ def build_client_context(server_hostname: str, version: str, **kwargs):
),
retry_delay_default=kwargs.get("_retry_delay_default"),
retry_dangerous_codes=kwargs.get("_retry_dangerous_codes"),
respect_server_retry_after_header=kwargs.get(
"_respect_server_retry_after_header"
),
proxy_auth_method=kwargs.get("_proxy_auth_method"),
pool_connections=kwargs.get("_pool_connections"),
pool_maxsize=kwargs.get("_pool_maxsize"),
Expand Down
89 changes: 86 additions & 3 deletions tests/unit/test_retry.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,21 @@


class TestRetry:
@pytest.fixture()
def retry_policy(self) -> DatabricksRetryPolicy:
return DatabricksRetryPolicy(
def _make_retry_policy(self, **overrides) -> DatabricksRetryPolicy:
defaults = dict(
delay_min=1,
delay_max=30,
stop_after_attempts_count=3,
stop_after_attempts_duration=900,
delay_default=2,
force_dangerous_codes=[],
)
defaults.update(overrides)
return DatabricksRetryPolicy(**defaults)

@pytest.fixture()
def retry_policy(self) -> DatabricksRetryPolicy:
return self._make_retry_policy()

@pytest.fixture()
def error_history(self) -> RequestHistory:
Expand Down Expand Up @@ -84,6 +89,84 @@ def test_excessive_retry_attempts_error(self, t_mock, retry_policy):
# Internally urllib3 calls the increment function generating a new instance for every retry
retry_policy = retry_policy.increment()

def test_respect_server_retry_after__retries_with_retry_after(self):
"""429 + Retry-After header → should retry"""
policy = self._make_retry_policy(respect_server_retry_after_header=True)
policy._retry_start_time = time.time()
policy.command_type = CommandType.OTHER
should_retry, msg = policy.should_retry("POST", 429, has_retry_after=True)
assert should_retry is True

def test_respect_server_retry_after__no_retry_without_retry_after(self):
"""429 without Retry-After header → no retry"""
policy = self._make_retry_policy(respect_server_retry_after_header=True)
policy._retry_start_time = time.time()
policy.command_type = CommandType.OTHER
should_retry, msg = policy.should_retry("POST", 429, has_retry_after=False)
assert should_retry is False
assert "respect_server_retry_after_header" in msg

def test_respect_server_retry_after__no_retry_503_without_header(self):
"""503 without Retry-After header → no retry"""
policy = self._make_retry_policy(respect_server_retry_after_header=True)
policy._retry_start_time = time.time()
policy.command_type = CommandType.OTHER
should_retry, msg = policy.should_retry("POST", 503, has_retry_after=False)
assert should_retry is False
assert "respect_server_retry_after_header" in msg

def test_respect_server_retry_after__overrides_dangerous_codes(self):
"""force_dangerous_codes=[500] + no Retry-After → no retry in respect_server_retry_after_header mode"""
policy = self._make_retry_policy(
force_dangerous_codes=[500], respect_server_retry_after_header=True
)
policy._retry_start_time = time.time()
policy.command_type = CommandType.EXECUTE_STATEMENT
should_retry, msg = policy.should_retry("POST", 500, has_retry_after=False)
assert should_retry is False
assert "respect_server_retry_after_header" in msg

def test_respect_server_retry_after__non_retryable_codes_unaffected(self):
"""401/403/501 still don't retry even with Retry-After header"""
policy = self._make_retry_policy(respect_server_retry_after_header=True)
policy._retry_start_time = time.time()
policy.command_type = CommandType.OTHER
for code in [401, 403, 501]:
should_retry, msg = policy.should_retry(
"POST", code, has_retry_after=True
)
assert should_retry is False, f"Code {code} should never retry"

def test_default_mode_unchanged(self, retry_policy):
"""respect_server_retry_after_header=False preserves existing behavior — 429 retries without header"""
retry_policy._retry_start_time = time.time()
retry_policy.command_type = CommandType.OTHER
should_retry, msg = retry_policy.should_retry(
"POST", 429, has_retry_after=False
)
assert should_retry is True

def test_respect_server_retry_after__survives_new(self):
"""urllib3 calls .new() between retries to create a fresh policy instance.
Verify that respect_server_retry_after_header is carried over and still enforced."""
policy = self._make_retry_policy(respect_server_retry_after_header=True)
policy._retry_start_time = time.time()
policy.command_type = CommandType.OTHER
new_policy = policy.new()
assert new_policy.respect_server_retry_after_header is True
# The new instance should still block retries without Retry-After
should_retry, msg = new_policy.should_retry("POST", 429, has_retry_after=False)
assert should_retry is False
assert "respect_server_retry_after_header" in msg

def test_respect_server_retry_after__execute_statement_with_retry_after(self):
"""EXECUTE_STATEMENT + 429 + Retry-After header → retry"""
policy = self._make_retry_policy(respect_server_retry_after_header=True)
policy._retry_start_time = time.time()
policy.command_type = CommandType.EXECUTE_STATEMENT
should_retry, msg = policy.should_retry("POST", 429, has_retry_after=True)
assert should_retry is True

def test_404_does_not_retry_for_any_command_type(self, retry_policy):
"""Test that 404 never retries for any CommandType"""
retry_policy._retry_start_time = time.time()
Expand Down
41 changes: 27 additions & 14 deletions tests/unit/test_unified_http_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def client_context(self):
context.retry_stop_after_attempts_duration = 300.0
context.retry_delay_default = 5.0
context.retry_dangerous_codes = []
context.respect_server_retry_after_header = False
context.proxy_auth_method = None
context.pool_connections = 10
context.pool_maxsize = 20
Expand All @@ -48,16 +49,19 @@ def http_client(self, client_context):
"""Create UnifiedHttpClient instance."""
return UnifiedHttpClient(client_context)

@pytest.mark.parametrize("status_code,path", [
(429, "reason.response"),
(503, "reason.response"),
(500, "direct_response"),
])
@pytest.mark.parametrize(
"status_code,path",
[
(429, "reason.response"),
(503, "reason.response"),
(500, "direct_response"),
],
)
def test_max_retry_error_with_status_codes(self, http_client, status_code, path):
"""Test MaxRetryError with various status codes and response paths."""
mock_pool = Mock()
max_retry_error = MaxRetryError(pool=mock_pool, url="http://test.com")

if path == "reason.response":
max_retry_error.reason = Mock()
max_retry_error.reason.response = Mock()
Expand All @@ -79,12 +83,21 @@ def test_max_retry_error_with_status_codes(self, http_client, status_code, path)
assert "http-code" in error.context
assert error.context["http-code"] == status_code

@pytest.mark.parametrize("setup_func", [
lambda e: None, # No setup - error with no attributes
lambda e: setattr(e, "reason", None), # reason=None
lambda e: (setattr(e, "reason", Mock()), setattr(e.reason, "response", None)), # reason.response=None
lambda e: (setattr(e, "reason", Mock()), setattr(e.reason, "response", Mock(spec=[]))), # No status attr
])
@pytest.mark.parametrize(
"setup_func",
[
lambda e: None, # No setup - error with no attributes
lambda e: setattr(e, "reason", None), # reason=None
lambda e: (
setattr(e, "reason", Mock()),
setattr(e.reason, "response", None),
), # reason.response=None
lambda e: (
setattr(e, "reason", Mock()),
setattr(e.reason, "response", Mock(spec=[])),
), # No status attr
],
)
def test_max_retry_error_missing_status(self, http_client, setup_func):
"""Test MaxRetryError without status code (no crash, empty context)."""
mock_pool = Mock()
Expand All @@ -104,12 +117,12 @@ def test_max_retry_error_prefers_reason_response(self, http_client):
"""Test that e.reason.response.status is preferred over e.response.status."""
mock_pool = Mock()
max_retry_error = MaxRetryError(pool=mock_pool, url="http://test.com")

# Set both structures with different status codes
max_retry_error.reason = Mock()
max_retry_error.reason.response = Mock()
max_retry_error.reason.response.status = 429 # Should use this

max_retry_error.response = Mock()
max_retry_error.response.status = 500 # Should be ignored

Expand Down
Loading