Skip to content

Commit df75ba1

Browse files
committed
Add _retry_server_directed_only mode for Retry-After header compliance
When enabled, the connector only retries on 429/503 if the server includes a Retry-After header in the response. This prevents duplicate side effects for non-idempotent ExecuteStatement operations where the server has not explicitly signaled that retry is safe. The new opt-in parameter `_retry_server_directed_only` threads through ClientContext, all three DatabricksRetryPolicy construction sites (Thrift, SEA, UnifiedHttpClient), and the retry policy's should_retry/is_retry methods. Default behavior (retry without requiring the header) is unchanged. Signed-off-by: Shubham Dhal <shubham.dhal@databricks.com>
1 parent 30286ad commit df75ba1

File tree

7 files changed

+145
-2
lines changed

7 files changed

+145
-2
lines changed

src/databricks/sql/auth/common.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ def __init__(
4747
retry_stop_after_attempts_duration: Optional[float] = None,
4848
retry_delay_default: Optional[float] = None,
4949
retry_dangerous_codes: Optional[List[int]] = None,
50+
retry_server_directed_only: Optional[bool] = None,
5051
proxy_auth_method: Optional[str] = None,
5152
pool_connections: Optional[int] = None,
5253
pool_maxsize: Optional[int] = None,
@@ -79,6 +80,7 @@ def __init__(
7980
)
8081
self.retry_delay_default = retry_delay_default or 5.0
8182
self.retry_dangerous_codes = retry_dangerous_codes or []
83+
self.retry_server_directed_only = bool(retry_server_directed_only)
8284
self.proxy_auth_method = proxy_auth_method
8385
self.pool_connections = pool_connections or 10
8486
self.pool_maxsize = pool_maxsize or 20

src/databricks/sql/auth/retry.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ def __init__(
9494
stop_after_attempts_duration: float,
9595
delay_default: float,
9696
force_dangerous_codes: List[int],
97+
server_directed_only: bool = False,
9798
urllib3_kwargs: dict = {},
9899
):
99100
# These values do not change from one command to the next
@@ -103,6 +104,7 @@ def __init__(
103104
self.stop_after_attempts_duration = stop_after_attempts_duration
104105
self._delay_default = delay_default
105106
self.force_dangerous_codes = force_dangerous_codes
107+
self.server_directed_only = server_directed_only
106108

107109
# the urllib3 kwargs are a mix of configuration (some of which we override)
108110
# and counters like `total` or `connect` which may change between successive retries
@@ -202,6 +204,7 @@ def new(
202204
stop_after_attempts_duration=self.stop_after_attempts_duration,
203205
delay_default=self.delay_default,
204206
force_dangerous_codes=self.force_dangerous_codes,
207+
server_directed_only=self.server_directed_only,
205208
urllib3_kwargs={},
206209
)
207210

@@ -323,7 +326,9 @@ def get_backoff_time(self) -> float:
323326

324327
return proposed_backoff
325328

326-
def should_retry(self, method: str, status_code: int) -> Tuple[bool, str]:
329+
def should_retry(
330+
self, method: str, status_code: int, has_retry_after: bool = False
331+
) -> Tuple[bool, str]:
327332
"""This method encapsulates the connector's approach to retries.
328333
329334
We always retry a request unless one of these conditions is met:
@@ -381,6 +386,12 @@ def should_retry(self, method: str, status_code: int) -> Tuple[bool, str]:
381386
if not self._is_method_retryable(method):
382387
return False, "Only POST requests are retried"
383388

389+
# In server_directed_only mode, only retry when the server explicitly signals
390+
# it's safe via a Retry-After header. This prevents duplicate side effects for
391+
# non-idempotent operations.
392+
if self.server_directed_only and not has_retry_after:
393+
return (False, "server_directed_only mode: no Retry-After header present")
394+
384395
# Request failed with 404 and was a GetOperationStatus. This is not recoverable. Don't retry.
385396
if status_code == 404 and self.command_type == CommandType.GET_OPERATION_STATUS:
386397
return (
@@ -450,7 +461,7 @@ def is_retry(
450461
Logs a debug message if the request will be retried
451462
"""
452463

453-
should_retry, msg = self.should_retry(method, status_code)
464+
should_retry, msg = self.should_retry(method, status_code, has_retry_after)
454465

455466
if should_retry:
456467
logger.debug(msg)

src/databricks/sql/backend/sea/utils/http_client.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,9 @@ def __init__(
9090
)
9191
self._retry_delay_default = kwargs.get("_retry_delay_default", 5.0)
9292
self.force_dangerous_codes = kwargs.get("_retry_dangerous_codes", [])
93+
self._retry_server_directed_only = kwargs.get(
94+
"_retry_server_directed_only", False
95+
)
9396

9497
# Connection pooling settings
9598
self.max_connections = kwargs.get("max_connections", 10)
@@ -114,6 +117,7 @@ def __init__(
114117
stop_after_attempts_duration=self._retry_stop_after_attempts_duration,
115118
delay_default=self._retry_delay_default,
116119
force_dangerous_codes=self.force_dangerous_codes,
120+
server_directed_only=self._retry_server_directed_only,
117121
urllib3_kwargs=urllib3_kwargs,
118122
)
119123
else:

src/databricks/sql/backend/thrift_backend.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,9 @@ def __init__(
189189
" This behaviour is deprecated and will be removed in a future release."
190190
)
191191
self.force_dangerous_codes = kwargs.get("_retry_dangerous_codes", [])
192+
self._retry_server_directed_only = kwargs.get(
193+
"_retry_server_directed_only", False
194+
)
192195

193196
additional_transport_args = {}
194197

@@ -215,6 +218,7 @@ def __init__(
215218
stop_after_attempts_duration=self._retry_stop_after_attempts_duration,
216219
delay_default=self._retry_delay_default,
217220
force_dangerous_codes=self.force_dangerous_codes,
221+
server_directed_only=self._retry_server_directed_only,
218222
urllib3_kwargs=urllib3_kwargs,
219223
)
220224

src/databricks/sql/common/unified_http_client.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ def _setup_pool_managers(self):
9999
stop_after_attempts_duration=self.config.retry_stop_after_attempts_duration,
100100
delay_default=self.config.retry_delay_default,
101101
force_dangerous_codes=self.config.retry_dangerous_codes,
102+
server_directed_only=self.config.retry_server_directed_only,
102103
)
103104

104105
# Initialize the required attributes that DatabricksRetryPolicy expects

src/databricks/sql/utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -919,6 +919,7 @@ def build_client_context(server_hostname: str, version: str, **kwargs):
919919
),
920920
retry_delay_default=kwargs.get("_retry_delay_default"),
921921
retry_dangerous_codes=kwargs.get("_retry_dangerous_codes"),
922+
retry_server_directed_only=kwargs.get("_retry_server_directed_only"),
922923
proxy_auth_method=kwargs.get("_proxy_auth_method"),
923924
pool_connections=kwargs.get("_pool_connections"),
924925
pool_maxsize=kwargs.get("_pool_maxsize"),

tests/unit/test_retry.py

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,3 +83,123 @@ def test_excessive_retry_attempts_error(self, t_mock, retry_policy):
8383
retry_policy.sleep(HTTPResponse(status=503))
8484
# Internally urllib3 calls the increment function generating a new instance for every retry
8585
retry_policy = retry_policy.increment()
86+
87+
@pytest.fixture()
88+
def server_directed_retry_policy(self) -> DatabricksRetryPolicy:
89+
return DatabricksRetryPolicy(
90+
delay_min=1,
91+
delay_max=30,
92+
stop_after_attempts_count=3,
93+
stop_after_attempts_duration=900,
94+
delay_default=2,
95+
force_dangerous_codes=[],
96+
server_directed_only=True,
97+
)
98+
99+
def test_server_directed_only__retries_with_retry_after(
100+
self, server_directed_retry_policy
101+
):
102+
"""429 + Retry-After header → should retry"""
103+
server_directed_retry_policy._retry_start_time = time.time()
104+
server_directed_retry_policy.command_type = CommandType.OTHER
105+
should_retry, msg = server_directed_retry_policy.should_retry(
106+
"POST", 429, has_retry_after=True
107+
)
108+
assert should_retry is True
109+
110+
def test_server_directed_only__no_retry_without_retry_after(
111+
self, server_directed_retry_policy
112+
):
113+
"""429 without Retry-After header → no retry"""
114+
server_directed_retry_policy._retry_start_time = time.time()
115+
server_directed_retry_policy.command_type = CommandType.OTHER
116+
should_retry, msg = server_directed_retry_policy.should_retry(
117+
"POST", 429, has_retry_after=False
118+
)
119+
assert should_retry is False
120+
assert "server_directed_only" in msg
121+
122+
def test_server_directed_only__no_retry_503_without_header(
123+
self, server_directed_retry_policy
124+
):
125+
"""503 without Retry-After header → no retry"""
126+
server_directed_retry_policy._retry_start_time = time.time()
127+
server_directed_retry_policy.command_type = CommandType.OTHER
128+
should_retry, msg = server_directed_retry_policy.should_retry(
129+
"POST", 503, has_retry_after=False
130+
)
131+
assert should_retry is False
132+
assert "server_directed_only" in msg
133+
134+
def test_server_directed_only__overrides_dangerous_codes(self):
135+
"""force_dangerous_codes=[500] + no Retry-After → no retry in server_directed_only mode"""
136+
policy = DatabricksRetryPolicy(
137+
delay_min=1,
138+
delay_max=30,
139+
stop_after_attempts_count=3,
140+
stop_after_attempts_duration=900,
141+
delay_default=2,
142+
force_dangerous_codes=[500],
143+
server_directed_only=True,
144+
)
145+
policy._retry_start_time = time.time()
146+
policy.command_type = CommandType.EXECUTE_STATEMENT
147+
should_retry, msg = policy.should_retry("POST", 500, has_retry_after=False)
148+
assert should_retry is False
149+
assert "server_directed_only" in msg
150+
151+
def test_server_directed_only__non_retryable_codes_unaffected(
152+
self, server_directed_retry_policy
153+
):
154+
"""401/403/501 still don't retry even with Retry-After header"""
155+
server_directed_retry_policy._retry_start_time = time.time()
156+
server_directed_retry_policy.command_type = CommandType.OTHER
157+
for code in [401, 403, 501]:
158+
should_retry, msg = server_directed_retry_policy.should_retry(
159+
"POST", code, has_retry_after=True
160+
)
161+
assert should_retry is False, f"Code {code} should never retry"
162+
163+
def test_default_mode_unchanged(self, retry_policy):
164+
"""server_directed_only=False preserves existing behavior — 429 retries without header"""
165+
retry_policy._retry_start_time = time.time()
166+
retry_policy.command_type = CommandType.OTHER
167+
should_retry, msg = retry_policy.should_retry(
168+
"POST", 429, has_retry_after=False
169+
)
170+
assert should_retry is True
171+
172+
def test_server_directed_only__survives_new(self, server_directed_retry_policy):
173+
"""urllib3 calls .new() between retries to create a fresh policy instance.
174+
Verify that server_directed_only is carried over and still enforced."""
175+
server_directed_retry_policy._retry_start_time = time.time()
176+
server_directed_retry_policy.command_type = CommandType.OTHER
177+
new_policy = server_directed_retry_policy.new()
178+
assert new_policy.server_directed_only is True
179+
# The new instance should still block retries without Retry-After
180+
should_retry, msg = new_policy.should_retry("POST", 429, has_retry_after=False)
181+
assert should_retry is False
182+
assert "server_directed_only" in msg
183+
184+
def test_server_directed_only__execute_statement_with_retry_after(
185+
self, server_directed_retry_policy
186+
):
187+
"""EXECUTE_STATEMENT + 429 + Retry-After header → retry"""
188+
server_directed_retry_policy._retry_start_time = time.time()
189+
server_directed_retry_policy.command_type = CommandType.EXECUTE_STATEMENT
190+
should_retry, msg = server_directed_retry_policy.should_retry(
191+
"POST", 429, has_retry_after=True
192+
)
193+
assert should_retry is True
194+
195+
def test_404_does_not_retry_for_any_command_type(self, retry_policy):
196+
"""Test that 404 never retries for any CommandType"""
197+
retry_policy._retry_start_time = time.time()
198+
199+
# Test for each CommandType
200+
for command_type in CommandType:
201+
retry_policy.command_type = command_type
202+
should_retry, msg = retry_policy.should_retry("POST", 404)
203+
204+
assert should_retry is False, f"404 should not retry for {command_type}"
205+
assert "404" in msg or "NOT_FOUND" in msg

0 commit comments

Comments
 (0)