Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 28 additions & 1 deletion src/runpod_flash/endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@
_POLL_MAX_INTERVAL = 5.0
_POLL_BACKOFF_FACTOR = 1.5

# threshold of consecutive transient httpx errors during wait() polling.
# each failed poll is logged and retried until the streak reaches this
# count; the error that reaches it is re-raised (so with 5, four errors are
# retried and the fifth propagates). the streak resets on any successful poll.
_POLL_MAX_CONSECUTIVE_ERRORS = 5


class _ClientCoroutine:
"""wraps a coroutine from a client-mode HTTP call.
Expand Down Expand Up @@ -137,8 +141,11 @@ async def wait(self, timeout: Optional[float] = None) -> "EndpointJob":
import asyncio
import time

import httpx

deadline = (time.monotonic() + timeout) if timeout is not None else None
interval = _POLL_INITIAL_INTERVAL
consecutive_errors = 0

while not self.done:
if deadline is not None and time.monotonic() >= deadline:
Expand All @@ -152,7 +159,27 @@ async def wait(self, timeout: Optional[float] = None) -> "EndpointJob":
f"job {self.id} did not complete within {timeout}s "
f"(last status: {self._data.get('status', 'UNKNOWN')})"
)
await self.status()
try:
await self.status()
except (httpx.TransportError, httpx.TimeoutException) as e:
# transient network / protocol / timeout error from the
# runpod api. the underlying job is still healthy, so back
# off and retry rather than aborting wait().
# HTTPStatusError (4xx/5xx from raise_for_status) is NOT
# caught here: 4xx auth/config bugs must fail loud.
consecutive_errors += 1
log.debug(
"transient httpx error polling job %s (%d/%d): %s",
self.id,
consecutive_errors,
_POLL_MAX_CONSECUTIVE_ERRORS,
e,
)
if consecutive_errors >= _POLL_MAX_CONSECUTIVE_ERRORS:
raise
interval = min(interval * _POLL_BACKOFF_FACTOR, _POLL_MAX_INTERVAL)
continue
consecutive_errors = 0
interval = min(interval * _POLL_BACKOFF_FACTOR, _POLL_MAX_INTERVAL)

return self
Expand Down
126 changes: 126 additions & 0 deletions tests/unit/test_endpoint_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,132 @@ async def test_wait_timeout_raises(self):
await job.wait(timeout=0.3)


@pytest.fixture
def fast_poll(monkeypatch):
    """patch the endpoint module's poll intervals down to milliseconds.

    keeps the retry tests from waiting on real-world sleep durations.
    monkeypatch restores the original values when the test finishes.
    """
    tiny_intervals = (
        ("runpod_flash.endpoint._POLL_INITIAL_INTERVAL", 0.001),
        ("runpod_flash.endpoint._POLL_MAX_INTERVAL", 0.005),
    )
    for target, seconds in tiny_intervals:
        monkeypatch.setattr(target, seconds)


class TestEndpointJobWaitTransientErrors:
    """wait() polling must ride out transient httpx failures (AE-3154).

    every test scripts the endpoint's ``_api_get`` with an AsyncMock and
    then checks how ``EndpointJob.wait()`` reacts to the scripted sequence.
    """

    @staticmethod
    def _make_job():
        """build an (endpoint, job) pair with the job starting in IN_QUEUE."""
        endpoint = Endpoint(id="ep-1")
        endpoint._endpoint_url = "https://api.runpod.ai/v2/ep-1"
        return endpoint, EndpointJob({"id": "j-1", "status": "IN_QUEUE"}, endpoint)

    @pytest.mark.asyncio
    async def test_transient_error_then_success(self, fast_poll):
        """a single RemoteProtocolError followed by COMPLETED is survivable."""
        import httpx

        endpoint, job = self._make_job()
        endpoint._api_get = AsyncMock(
            side_effect=[
                httpx.RemoteProtocolError("server disconnected"),
                {"id": "j-1", "status": "COMPLETED", "output": {"r": 1}},
            ]
        )

        returned = await job.wait()

        # wait() hands back the same job object, now in its final state
        assert returned is job
        assert job._data["status"] == "COMPLETED"
        assert job.output == {"r": 1}
        # one failed poll + one successful poll
        assert endpoint._api_get.call_count == 2

    @pytest.mark.asyncio
    async def test_repeated_transient_errors_exceed_threshold(self, fast_poll):
        """an unbroken run of RemoteProtocolErrors eventually propagates."""
        import httpx

        from runpod_flash.endpoint import _POLL_MAX_CONSECUTIVE_ERRORS

        endpoint, job = self._make_job()
        always_dropping = httpx.RemoteProtocolError("server disconnected")
        endpoint._api_get = AsyncMock(side_effect=always_dropping)

        with pytest.raises(httpx.RemoteProtocolError):
            await job.wait()

        # polling stops on the attempt that reaches the threshold
        assert endpoint._api_get.call_count == _POLL_MAX_CONSECUTIVE_ERRORS

    @pytest.mark.asyncio
    async def test_counter_resets_on_successful_poll(self, fast_poll):
        """sub-threshold error bursts split by good polls never abort wait()."""
        import httpx

        endpoint, job = self._make_job()

        # burst of 2, a good poll, burst of 4, two good polls (last one final)
        scripted_polls = [
            httpx.RemoteProtocolError("drop 1"),
            httpx.RemoteProtocolError("drop 2"),
            {"id": "j-1", "status": "IN_PROGRESS"},
            httpx.RemoteProtocolError("drop 3"),
            httpx.RemoteProtocolError("drop 4"),
            httpx.RemoteProtocolError("drop 5"),
            httpx.RemoteProtocolError("drop 6"),
            {"id": "j-1", "status": "IN_PROGRESS"},
            {"id": "j-1", "status": "COMPLETED", "output": {"r": 1}},
        ]
        expected_calls = len(scripted_polls)
        endpoint._api_get = AsyncMock(side_effect=scripted_polls)

        returned = await job.wait()

        assert returned is job
        assert job._data["status"] == "COMPLETED"
        # every scripted poll was consumed, none skipped and none extra
        assert endpoint._api_get.call_count == expected_calls

    @pytest.mark.asyncio
    async def test_http_status_error_not_swallowed(self, fast_poll):
        """a 4xx HTTPStatusError surfaces on the first poll (auth/config bugs)."""
        import httpx

        endpoint, job = self._make_job()

        status_request = httpx.Request(
            "GET", "https://api.runpod.ai/v2/ep-1/status/j-1"
        )
        unauthorized = httpx.HTTPStatusError(
            "401 unauthorized",
            request=status_request,
            response=httpx.Response(401, request=status_request),
        )
        endpoint._api_get = AsyncMock(side_effect=unauthorized)

        with pytest.raises(httpx.HTTPStatusError):
            await job.wait()

        # a single call means the error was not retried
        assert endpoint._api_get.call_count == 1

    @pytest.mark.asyncio
    async def test_timeout_still_authoritative(self, fast_poll, monkeypatch):
        """the deadline wins over the error threshold: TimeoutError is raised.

        the threshold is pushed far beyond what the deadline permits, so
        several httpx errors get suppressed by the retry path before the
        deadline trips -- the test covers more than the pre-sleep guard.
        """
        import httpx

        monkeypatch.setattr("runpod_flash.endpoint._POLL_MAX_CONSECUTIVE_ERRORS", 1000)

        endpoint, job = self._make_job()
        always_dropping = httpx.RemoteProtocolError("server disconnected")
        endpoint._api_get = AsyncMock(side_effect=always_dropping)

        with pytest.raises(TimeoutError, match="did not complete within"):
            await job.wait(timeout=0.05)

        # at least two polls happened, so at least one httpx error was
        # suppressed and retried before the deadline fired.
        assert endpoint._api_get.call_count >= 2


# -- Endpoint.run / runsync / cancel --


Expand Down
Loading