Skip to content

Commit 013456b

Browse files
Advertise Python worker admission capacity
Issue: zorporation/durable-workflow#488 Loop-ID: build-03
1 parent d9d7ff0 commit 013456b

5 files changed

Lines changed: 82 additions & 0 deletions

File tree

README.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,11 @@ fingerprints during registration. Re-registering the same `worker_id` with a
156156
changed class body for an already advertised workflow type raises immediately;
157157
restart the worker process with a new id before serving changed workflow code.
158158

159+
Workers also advertise their local workflow and activity concurrency limits
160+
during registration. Tune `max_concurrent_workflow_tasks` and
161+
`max_concurrent_activity_tasks` on `Worker(...)` to align local semaphores with
162+
the server's task-queue admission and operator visibility surfaces.
163+
159164
## Replay captured histories
160165

161166
Use `Replayer` to debug a captured history without connecting to a live server:

src/durable_workflow/client.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1134,6 +1134,8 @@ async def register_worker(
11341134
supported_workflow_types: list[str] | None = None,
11351135
workflow_definition_fingerprints: dict[str, str] | None = None,
11361136
supported_activity_types: list[str] | None = None,
1137+
max_concurrent_workflow_tasks: int | None = None,
1138+
max_concurrent_activity_tasks: int | None = None,
11371139
runtime: str = "python",
11381140
sdk_version: str | None = None,
11391141
) -> Any:
@@ -1145,6 +1147,11 @@ async def register_worker(
11451147
"""
11461148
if sdk_version is None:
11471149
sdk_version = DEFAULT_SDK_VERSION
1150+
if max_concurrent_workflow_tasks is not None and max_concurrent_workflow_tasks < 1:
1151+
raise ValueError("max_concurrent_workflow_tasks must be at least 1")
1152+
if max_concurrent_activity_tasks is not None and max_concurrent_activity_tasks < 1:
1153+
raise ValueError("max_concurrent_activity_tasks must be at least 1")
1154+
11481155
body: dict[str, Any] = {
11491156
"worker_id": worker_id,
11501157
"task_queue": task_queue,
@@ -1155,6 +1162,10 @@ async def register_worker(
11551162
}
11561163
if workflow_definition_fingerprints is not None:
11571164
body["workflow_definition_fingerprints"] = workflow_definition_fingerprints
1165+
if max_concurrent_workflow_tasks is not None:
1166+
body["max_concurrent_workflow_tasks"] = max_concurrent_workflow_tasks
1167+
if max_concurrent_activity_tasks is not None:
1168+
body["max_concurrent_activity_tasks"] = max_concurrent_activity_tasks
11581169
return await self._request("POST", "/worker/register", worker=True, json=body)
11591170

11601171
async def poll_workflow_task(

src/durable_workflow/worker.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,14 @@ def __init__(
241241
self.activities = {_activity_name(a): a for a in activities}
242242
self.worker_id = worker_id or f"py-worker-{uuid.uuid4().hex[:8]}"
243243
_guard_worker_workflow_fingerprints(self.worker_id, self.workflow_definition_fingerprints)
244+
if max_concurrent_workflow_tasks < 1:
245+
raise ValueError("max_concurrent_workflow_tasks must be at least 1")
246+
if max_concurrent_activity_tasks < 1:
247+
raise ValueError("max_concurrent_activity_tasks must be at least 1")
248+
244249
self._poll_timeout = poll_timeout
250+
self.max_concurrent_workflow_tasks = max_concurrent_workflow_tasks
251+
self.max_concurrent_activity_tasks = max_concurrent_activity_tasks
245252
self._stop = asyncio.Event()
246253
self._wf_semaphore = asyncio.Semaphore(max_concurrent_workflow_tasks)
247254
self._act_semaphore = asyncio.Semaphore(max_concurrent_activity_tasks)
@@ -319,6 +326,8 @@ async def _register(self) -> None:
319326
supported_workflow_types=list(self.workflows),
320327
workflow_definition_fingerprints=self.workflow_definition_fingerprints,
321328
supported_activity_types=list(self.activities),
329+
max_concurrent_workflow_tasks=self.max_concurrent_workflow_tasks,
330+
max_concurrent_activity_tasks=self.max_concurrent_activity_tasks,
322331
)
323332
log.info("worker %s registered on %s", self.worker_id, self.task_queue)
324333

tests/test_client.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -665,6 +665,36 @@ async def test_register(self, client: Client) -> None:
665665
assert body["runtime"] == "python"
666666
assert body["workflow_definition_fingerprints"] == {"greeter": "sha256:abc"}
667667

668+
@pytest.mark.asyncio
669+
async def test_register_sends_worker_capacity_when_configured(self, client: Client) -> None:
670+
resp = _mock_response(201, {"worker_id": "w1", "registered": True})
671+
with patch.object(client._http, "request", new_callable=AsyncMock, return_value=resp) as mock:
672+
await client.register_worker(
673+
worker_id="w1",
674+
task_queue="q1",
675+
max_concurrent_workflow_tasks=3,
676+
max_concurrent_activity_tasks=7,
677+
)
678+
body = mock.call_args.kwargs.get("json") or mock.call_args[1].get("json")
679+
assert body["max_concurrent_workflow_tasks"] == 3
680+
assert body["max_concurrent_activity_tasks"] == 7
681+
682+
@pytest.mark.asyncio
683+
async def test_register_rejects_non_positive_worker_capacity(self, client: Client) -> None:
684+
with pytest.raises(ValueError, match="max_concurrent_workflow_tasks"):
685+
await client.register_worker(
686+
worker_id="w1",
687+
task_queue="q1",
688+
max_concurrent_workflow_tasks=0,
689+
)
690+
691+
with pytest.raises(ValueError, match="max_concurrent_activity_tasks"):
692+
await client.register_worker(
693+
worker_id="w1",
694+
task_queue="q1",
695+
max_concurrent_activity_tasks=0,
696+
)
697+
668698
@pytest.mark.asyncio
669699
async def test_register_advertises_installed_package_version(self, client: Client) -> None:
670700
from importlib.metadata import version as _pkg_version

tests/test_worker.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,33 @@ async def test_register(self, mock_client: AsyncMock) -> None:
138138
assert "test-wf" in call_kwargs["supported_workflow_types"]
139139
assert call_kwargs["workflow_definition_fingerprints"]["test-wf"].startswith("sha256:")
140140
assert "test-act" in call_kwargs["supported_activity_types"]
141+
assert call_kwargs["max_concurrent_workflow_tasks"] == 10
142+
assert call_kwargs["max_concurrent_activity_tasks"] == 10
143+
144+
@pytest.mark.asyncio
145+
async def test_register_advertises_custom_concurrency_limits(self, mock_client: AsyncMock) -> None:
146+
worker = Worker(
147+
mock_client,
148+
task_queue="q1",
149+
workflows=[TestWorkflow],
150+
activities=[echo_activity],
151+
worker_id="w-capacity",
152+
max_concurrent_workflow_tasks=3,
153+
max_concurrent_activity_tasks=7,
154+
)
155+
await worker._register()
156+
call_kwargs = mock_client.register_worker.call_args.kwargs
157+
assert call_kwargs["max_concurrent_workflow_tasks"] == 3
158+
assert call_kwargs["max_concurrent_activity_tasks"] == 7
159+
160+
def test_constructor_rejects_non_positive_concurrency_limits(
161+
self, mock_client: AsyncMock
162+
) -> None:
163+
with pytest.raises(ValueError, match="max_concurrent_workflow_tasks"):
164+
Worker(mock_client, task_queue="q1", max_concurrent_workflow_tasks=0)
165+
166+
with pytest.raises(ValueError, match="max_concurrent_activity_tasks"):
167+
Worker(mock_client, task_queue="q1", max_concurrent_activity_tasks=0)
141168

142169
def test_constructor_rejects_changed_workflow_definition_for_same_worker_id(
143170
self, mock_client: AsyncMock

0 commit comments

Comments
 (0)