Skip to content

Commit 179140a

Browse files
committed
Add _retry_server_directed_only mode for Retry-After header compliance
When enabled, the connector only retries on 429/503 if the server includes a Retry-After header in the response. This prevents duplicate side effects for non-idempotent ExecuteStatement operations where the server has not explicitly signaled that retry is safe. The new opt-in parameter `_retry_server_directed_only` threads through ClientContext, all three DatabricksRetryPolicy construction sites (Thrift, SEA, UnifiedHttpClient), and the retry policy's should_retry/is_retry methods. Default behavior (retry without requiring the header) is unchanged. Signed-off-by: Shubham Dhal <shubham.dhal@databricks.com>
1 parent ca4d7bc commit 179140a

File tree

8 files changed

+160
-16
lines changed

8 files changed

+160
-16
lines changed

src/databricks/sql/auth/common.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ def __init__(
4747
retry_stop_after_attempts_duration: Optional[float] = None,
4848
retry_delay_default: Optional[float] = None,
4949
retry_dangerous_codes: Optional[List[int]] = None,
50+
retry_server_directed_only: Optional[bool] = None,
5051
proxy_auth_method: Optional[str] = None,
5152
pool_connections: Optional[int] = None,
5253
pool_maxsize: Optional[int] = None,
@@ -80,6 +81,7 @@ def __init__(
8081
)
8182
self.retry_delay_default = retry_delay_default or 5.0
8283
self.retry_dangerous_codes = retry_dangerous_codes or []
84+
self.retry_server_directed_only = bool(retry_server_directed_only)
8385
self.proxy_auth_method = proxy_auth_method
8486
self.pool_connections = pool_connections or 10
8587
self.pool_maxsize = pool_maxsize or 20

src/databricks/sql/auth/retry.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ def __init__(
9494
stop_after_attempts_duration: float,
9595
delay_default: float,
9696
force_dangerous_codes: List[int],
97+
server_directed_only: bool = False,
9798
urllib3_kwargs: dict = {},
9899
):
99100
# These values do not change from one command to the next
@@ -103,6 +104,7 @@ def __init__(
103104
self.stop_after_attempts_duration = stop_after_attempts_duration
104105
self._delay_default = delay_default
105106
self.force_dangerous_codes = force_dangerous_codes
107+
self.server_directed_only = server_directed_only
106108

107109
# the urllib3 kwargs are a mix of configuration (some of which we override)
108110
# and counters like `total` or `connect` which may change between successive retries
@@ -202,6 +204,7 @@ def new(
202204
stop_after_attempts_duration=self.stop_after_attempts_duration,
203205
delay_default=self.delay_default,
204206
force_dangerous_codes=self.force_dangerous_codes,
207+
server_directed_only=self.server_directed_only,
205208
urllib3_kwargs={},
206209
)
207210

@@ -323,7 +326,9 @@ def get_backoff_time(self) -> float:
323326

324327
return proposed_backoff
325328

326-
def should_retry(self, method: str, status_code: int) -> Tuple[bool, str]:
329+
def should_retry(
330+
self, method: str, status_code: int, has_retry_after: bool = False
331+
) -> Tuple[bool, str]:
327332
"""This method encapsulates the connector's approach to retries.
328333
329334
We always retry a request unless one of these conditions is met:
@@ -388,6 +393,12 @@ def should_retry(self, method: str, status_code: int) -> Tuple[bool, str]:
388393
if not self._is_method_retryable(method):
389394
return False, "Only POST requests are retried"
390395

396+
# In server_directed_only mode, only retry when the server explicitly signals
397+
# it's safe via a Retry-After header. This prevents duplicate side effects for
398+
# non-idempotent operations.
399+
if self.server_directed_only and not has_retry_after:
400+
return (False, "server_directed_only mode: no Retry-After header present")
401+
391402
# Request failed, was an ExecuteStatement and the command may have reached the server
392403
if (
393404
self.command_type == CommandType.EXECUTE_STATEMENT
@@ -430,7 +441,7 @@ def is_retry(
430441
Logs a debug message if the request will be retried
431442
"""
432443

433-
should_retry, msg = self.should_retry(method, status_code)
444+
should_retry, msg = self.should_retry(method, status_code, has_retry_after)
434445

435446
if should_retry:
436447
logger.debug(msg)

src/databricks/sql/backend/sea/utils/http_client.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,9 @@ def __init__(
9292
)
9393
self._retry_delay_default = kwargs.get("_retry_delay_default", 5.0)
9494
self.force_dangerous_codes = kwargs.get("_retry_dangerous_codes", [])
95+
self._retry_server_directed_only = kwargs.get(
96+
"_retry_server_directed_only", False
97+
)
9598

9699
# Connection pooling settings
97100
self.max_connections = kwargs.get("max_connections", 10)
@@ -116,6 +119,7 @@ def __init__(
116119
stop_after_attempts_duration=self._retry_stop_after_attempts_duration,
117120
delay_default=self._retry_delay_default,
118121
force_dangerous_codes=self.force_dangerous_codes,
122+
server_directed_only=self._retry_server_directed_only,
119123
urllib3_kwargs=urllib3_kwargs,
120124
)
121125
else:

src/databricks/sql/backend/thrift_backend.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,9 @@ def __init__(
191191
" This behaviour is deprecated and will be removed in a future release."
192192
)
193193
self.force_dangerous_codes = kwargs.get("_retry_dangerous_codes", [])
194+
self._retry_server_directed_only = kwargs.get(
195+
"_retry_server_directed_only", False
196+
)
194197

195198
additional_transport_args = {}
196199

@@ -217,6 +220,7 @@ def __init__(
217220
stop_after_attempts_duration=self._retry_stop_after_attempts_duration,
218221
delay_default=self._retry_delay_default,
219222
force_dangerous_codes=self.force_dangerous_codes,
223+
server_directed_only=self._retry_server_directed_only,
220224
urllib3_kwargs=urllib3_kwargs,
221225
)
222226

src/databricks/sql/common/unified_http_client.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,7 @@ def _setup_pool_managers(self):
135135
stop_after_attempts_duration=self.config.retry_stop_after_attempts_duration,
136136
delay_default=self.config.retry_delay_default,
137137
force_dangerous_codes=self.config.retry_dangerous_codes,
138+
server_directed_only=self.config.retry_server_directed_only,
138139
)
139140

140141
# Initialize the required attributes that DatabricksRetryPolicy expects

src/databricks/sql/utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -977,6 +977,7 @@ def build_client_context(server_hostname: str, version: str, **kwargs):
977977
),
978978
retry_delay_default=kwargs.get("_retry_delay_default"),
979979
retry_dangerous_codes=kwargs.get("_retry_dangerous_codes"),
980+
retry_server_directed_only=kwargs.get("_retry_server_directed_only"),
980981
proxy_auth_method=kwargs.get("_proxy_auth_method"),
981982
pool_connections=kwargs.get("_pool_connections"),
982983
pool_maxsize=kwargs.get("_pool_maxsize"),

tests/unit/test_retry.py

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,114 @@ def test_excessive_retry_attempts_error(self, t_mock, retry_policy):
8484
# Internally urllib3 calls the increment function generating a new instance for every retry
8585
retry_policy = retry_policy.increment()
8686

87+
@pytest.fixture()
88+
def server_directed_retry_policy(self) -> DatabricksRetryPolicy:
89+
return DatabricksRetryPolicy(
90+
delay_min=1,
91+
delay_max=30,
92+
stop_after_attempts_count=3,
93+
stop_after_attempts_duration=900,
94+
delay_default=2,
95+
force_dangerous_codes=[],
96+
server_directed_only=True,
97+
)
98+
99+
def test_server_directed_only__retries_with_retry_after(
100+
self, server_directed_retry_policy
101+
):
102+
"""429 + Retry-After header → should retry"""
103+
server_directed_retry_policy._retry_start_time = time.time()
104+
server_directed_retry_policy.command_type = CommandType.OTHER
105+
should_retry, msg = server_directed_retry_policy.should_retry(
106+
"POST", 429, has_retry_after=True
107+
)
108+
assert should_retry is True
109+
110+
def test_server_directed_only__no_retry_without_retry_after(
111+
self, server_directed_retry_policy
112+
):
113+
"""429 without Retry-After header → no retry"""
114+
server_directed_retry_policy._retry_start_time = time.time()
115+
server_directed_retry_policy.command_type = CommandType.OTHER
116+
should_retry, msg = server_directed_retry_policy.should_retry(
117+
"POST", 429, has_retry_after=False
118+
)
119+
assert should_retry is False
120+
assert "server_directed_only" in msg
121+
122+
def test_server_directed_only__no_retry_503_without_header(
123+
self, server_directed_retry_policy
124+
):
125+
"""503 without Retry-After header → no retry"""
126+
server_directed_retry_policy._retry_start_time = time.time()
127+
server_directed_retry_policy.command_type = CommandType.OTHER
128+
should_retry, msg = server_directed_retry_policy.should_retry(
129+
"POST", 503, has_retry_after=False
130+
)
131+
assert should_retry is False
132+
assert "server_directed_only" in msg
133+
134+
def test_server_directed_only__overrides_dangerous_codes(self):
135+
"""force_dangerous_codes=[500] + no Retry-After → no retry in server_directed_only mode"""
136+
policy = DatabricksRetryPolicy(
137+
delay_min=1,
138+
delay_max=30,
139+
stop_after_attempts_count=3,
140+
stop_after_attempts_duration=900,
141+
delay_default=2,
142+
force_dangerous_codes=[500],
143+
server_directed_only=True,
144+
)
145+
policy._retry_start_time = time.time()
146+
policy.command_type = CommandType.EXECUTE_STATEMENT
147+
should_retry, msg = policy.should_retry("POST", 500, has_retry_after=False)
148+
assert should_retry is False
149+
assert "server_directed_only" in msg
150+
151+
def test_server_directed_only__non_retryable_codes_unaffected(
152+
self, server_directed_retry_policy
153+
):
154+
"""401/403/501 still don't retry even with Retry-After header"""
155+
server_directed_retry_policy._retry_start_time = time.time()
156+
server_directed_retry_policy.command_type = CommandType.OTHER
157+
for code in [401, 403, 501]:
158+
should_retry, msg = server_directed_retry_policy.should_retry(
159+
"POST", code, has_retry_after=True
160+
)
161+
assert should_retry is False, f"Code {code} should never retry"
162+
163+
def test_default_mode_unchanged(self, retry_policy):
164+
"""server_directed_only=False preserves existing behavior — 429 retries without header"""
165+
retry_policy._retry_start_time = time.time()
166+
retry_policy.command_type = CommandType.OTHER
167+
should_retry, msg = retry_policy.should_retry(
168+
"POST", 429, has_retry_after=False
169+
)
170+
assert should_retry is True
171+
172+
def test_server_directed_only__survives_new(self, server_directed_retry_policy):
173+
"""urllib3 calls .new() between retries to create a fresh policy instance.
174+
Verify that server_directed_only is carried over and still enforced."""
175+
server_directed_retry_policy._retry_start_time = time.time()
176+
server_directed_retry_policy.command_type = CommandType.OTHER
177+
new_policy = server_directed_retry_policy.new()
178+
assert new_policy.server_directed_only is True
179+
# The new instance should still block retries without Retry-After
180+
should_retry, msg = new_policy.should_retry("POST", 429, has_retry_after=False)
181+
assert should_retry is False
182+
assert "server_directed_only" in msg
183+
184+
def test_server_directed_only__execute_statement_with_retry_after(
185+
self, server_directed_retry_policy
186+
):
187+
"""EXECUTE_STATEMENT + 429 + Retry-After header → retry"""
188+
server_directed_retry_policy._retry_start_time = time.time()
189+
server_directed_retry_policy.command_type = CommandType.EXECUTE_STATEMENT
190+
should_retry, msg = server_directed_retry_policy.should_retry(
191+
"POST", 429, has_retry_after=True
192+
)
193+
assert should_retry is True
194+
87195
def test_404_does_not_retry_for_any_command_type(self, retry_policy):
88196
"""Test that 404 never retries for any CommandType"""
89197
retry_policy._retry_start_time = time.time()

tests/unit/test_unified_http_client.py

Lines changed: 27 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ def client_context(self):
3737
context.retry_stop_after_attempts_duration = 300.0
3838
context.retry_delay_default = 5.0
3939
context.retry_dangerous_codes = []
40+
context.retry_server_directed_only = False
4041
context.proxy_auth_method = None
4142
context.pool_connections = 10
4243
context.pool_maxsize = 20
@@ -48,16 +49,19 @@ def http_client(self, client_context):
4849
"""Create UnifiedHttpClient instance."""
4950
return UnifiedHttpClient(client_context)
5051

51-
@pytest.mark.parametrize("status_code,path", [
52-
(429, "reason.response"),
53-
(503, "reason.response"),
54-
(500, "direct_response"),
55-
])
52+
@pytest.mark.parametrize(
53+
"status_code,path",
54+
[
55+
(429, "reason.response"),
56+
(503, "reason.response"),
57+
(500, "direct_response"),
58+
],
59+
)
5660
def test_max_retry_error_with_status_codes(self, http_client, status_code, path):
5761
"""Test MaxRetryError with various status codes and response paths."""
5862
mock_pool = Mock()
5963
max_retry_error = MaxRetryError(pool=mock_pool, url="http://test.com")
60-
64+
6165
if path == "reason.response":
6266
max_retry_error.reason = Mock()
6367
max_retry_error.reason.response = Mock()
@@ -79,12 +83,21 @@ def test_max_retry_error_with_status_codes(self, http_client, status_code, path)
7983
assert "http-code" in error.context
8084
assert error.context["http-code"] == status_code
8185

82-
@pytest.mark.parametrize("setup_func", [
83-
lambda e: None, # No setup - error with no attributes
84-
lambda e: setattr(e, "reason", None), # reason=None
85-
lambda e: (setattr(e, "reason", Mock()), setattr(e.reason, "response", None)), # reason.response=None
86-
lambda e: (setattr(e, "reason", Mock()), setattr(e.reason, "response", Mock(spec=[]))), # No status attr
87-
])
86+
@pytest.mark.parametrize(
87+
"setup_func",
88+
[
89+
lambda e: None, # No setup - error with no attributes
90+
lambda e: setattr(e, "reason", None), # reason=None
91+
lambda e: (
92+
setattr(e, "reason", Mock()),
93+
setattr(e.reason, "response", None),
94+
), # reason.response=None
95+
lambda e: (
96+
setattr(e, "reason", Mock()),
97+
setattr(e.reason, "response", Mock(spec=[])),
98+
), # No status attr
99+
],
100+
)
88101
def test_max_retry_error_missing_status(self, http_client, setup_func):
89102
"""Test MaxRetryError without status code (no crash, empty context)."""
90103
mock_pool = Mock()
@@ -104,12 +117,12 @@ def test_max_retry_error_prefers_reason_response(self, http_client):
104117
"""Test that e.reason.response.status is preferred over e.response.status."""
105118
mock_pool = Mock()
106119
max_retry_error = MaxRetryError(pool=mock_pool, url="http://test.com")
107-
120+
108121
# Set both structures with different status codes
109122
max_retry_error.reason = Mock()
110123
max_retry_error.reason.response = Mock()
111124
max_retry_error.reason.response.status = 429 # Should use this
112-
125+
113126
max_retry_error.response = Mock()
114127
max_retry_error.response.status = 500 # Should be ignored
115128

0 commit comments

Comments
 (0)