Skip to content

Commit 0429f4b

Browse files
Add custom_worker_tuner sample
1 parent 59e1f87 commit 0429f4b

7 files changed

Lines changed: 398 additions & 0 deletions

File tree

custom_worker_tuner/README.md

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
# Custom Worker Tuner
2+
3+
## Why we need a custom slot supplier
4+
5+
Neither knows about anything *outside* the worker. If you have a
6+
**downstream resource with a hard limit** — a database connection pool
7+
of 10, an API quota, a rate-limited service — those suppliers will
8+
happily accept more activities than the downstream can serve. The
9+
extra activities then stall *inside the worker*, holding slots and
10+
RAM
11+
12+
A `CustomSlotSupplier` lets you gate slot grants on whatever you want.
13+
This sample gates on a fake DB pool: the worker only polls for a new
14+
activity when the pool has a free connection.
15+
16+
## What this sample is
17+
downstream.py - A static-capacity counter. Pretends to be a DB pool. Two methods: increment() (claim a slot, returns False if full), decrement() (release)
18+
supplier.py - The custom slot supplier. On reserve_slot it polls downstream.increment() until it succeeds. On release_slot it calls downstream.decrement()
19+
shared.py - A RunBatch workflow that runs N do_work activities in parallel. The activity just sleeps
20+
worker.py - Wires Downstream + DownstreamAwareSupplier into a WorkerTuner
21+
starter.py - Drives load
22+
23+
The flow:
24+
25+
When the downstream is at capacity, `reserve_slot` blocks until a
26+
slot frees up. The excess work piles up on the Temporal server, not
27+
inside the worker.
28+
29+
## Run
30+
31+
In three terminals from `samples-python/`:
32+
33+
```bash
34+
temporal server start-dev # terminal 1
35+
uv run custom_worker_tuner/worker.py # terminal 2
36+
uv run custom_worker_tuner/starter.py # terminal 3
37+
```
38+
39+
## What you'll see
40+
41+
The worker prints one line per slot lifecycle event:
42+
43+
```
44+
45+
TIME EVENT SLOT COUNT DETAIL
46+
────────────────────────────────────────────────────────────
47+
10:31:49.842 reserve #1 1/10 ready to poll
48+
10:31:49.842 reserve #2 2/10 ready to poll
49+
10:31:49.843 reserve #3 3/10 ready to poll
50+
10:31:49.843 reserve #4 4/10 ready to poll
51+
10:31:49.843 reserve #5 5/10 ready to poll
52+
10:31:49.843 reserve #6 6/10 ready to poll
53+
10:31:56.763 reserve #7 7/10 eager dispatch
54+
10:31:56.763 reserve #8 8/10 eager dispatch
55+
10:31:56.764 reserve #9 9/10 eager dispatch
56+
10:31:56.766 reserve #10 10/10 eager dispatch
57+
10:31:56.767 release #7 9/10 no task arrived
58+
10:31:56.768 release #8 8/10 no task arrived
59+
10:31:56.768 release #9 7/10 no task arrived
60+
10:31:56.768 reserve #11 8/10 eager dispatch
61+
10:31:56.768 reserve #12 9/10 eager dispatch
62+
10:31:56.768 reserve #13 10/10 eager dispatch
63+
10:31:56.771 used #1 10/10 activity running
64+
10:31:56.771 release #10 9/10 no task arrived
65+
```
66+
67+
Under load, with more activities than capacity, COUNT pins at
68+
10/10 — that's the supplier refusing to poll past the gate.
69+
we chose 10 because default there are 5 pollers for python sdk
70+
71+
## Knobs
72+
73+
worker.py:
74+
75+
CAPACITY — downstream capacity (the gate)
76+
POLL_INTERVAL_MS — how often the supplier rechecks when full
77+
78+
starter.py:
79+
80+
WORKFLOWS, ACTIVITIES_PER_WORKFLOW, SECONDS_PER_ACTIVITY — amount and duration of load

custom_worker_tuner/__init__.py

Whitespace-only changes.

custom_worker_tuner/downstream.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
"""A static-capacity counter representing a downstream resource.
2+
3+
This is the thing the supplier gates on. It pretends to be a database
4+
connection pool, an HTTP rate limiter, an external quota — anything
5+
finite. We just track an integer.
6+
"""
7+
8+
from __future__ import annotations
9+
10+
import logging
11+
import threading
12+
13+
logger = logging.getLogger(__name__)
14+
15+
16+
class Downstream:
17+
"""A counter with a fixed capacity. Thread-safe."""
18+
19+
def __init__(self, allowed_connections: int, name: str = "downstream") -> None:
20+
self.allowed_connections = allowed_connections
21+
self.name = name
22+
self.currently_connected = 0
23+
self.connection_pool = threading.Lock()
24+
logger.info("Downstream ready: name=%s allowed_connections=%d", name, allowed_connections)
25+
26+
def increment(self) -> bool:
27+
"""allow one connection. Returns False if at capacity."""
28+
with self.connection_pool:
29+
if self.currently_connected >= self.allowed_connections:
30+
return False
31+
self.currently_connected += 1
32+
return True
33+
34+
def decrement(self) -> None:
35+
"""Release one slot. Floored at 0 so a buggy caller can't go negative."""
36+
with self.connection_pool:
37+
self.currently_connected = max(0, self.currently_connected - 1)

custom_worker_tuner/shared.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
"""Workflow + activity shared between worker.py and starter.py."""
2+
3+
from __future__ import annotations
4+
5+
import asyncio
6+
from dataclasses import dataclass
7+
from datetime import timedelta
8+
9+
from temporalio import activity, workflow
10+
11+
TASK_QUEUE = "custom-worker-tuner"
12+
13+
14+
@dataclass
15+
class BatchInput:
16+
activities: int
17+
seconds: float
18+
19+
20+
@activity.defn
21+
async def do_work(seconds: float) -> None:
22+
"""Sleep, simulating an I/O-bound activity."""
23+
await asyncio.sleep(seconds)
24+
25+
@workflow.defn
26+
class RunBatch:
27+
"""Runs N do_work activities in parallel."""
28+
29+
@workflow.run
30+
async def run(self, inp: BatchInput) -> None:
31+
await asyncio.gather(
32+
*(
33+
workflow.execute_activity(
34+
do_work,
35+
inp.seconds,
36+
start_to_close_timeout=timedelta(minutes=2),
37+
)
38+
for _ in range(inp.activities)
39+
)
40+
)

custom_worker_tuner/starter.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
"""Drives load against the worker so you can see the supplier gating.
2+
3+
Run from samples-python/::
4+
5+
python -m custom_worker_tuner.starter
6+
"""
7+
8+
from __future__ import annotations
9+
10+
import asyncio
11+
import time
12+
import uuid
13+
14+
from temporalio.client import Client
15+
from temporalio.envconfig import ClientConfig
16+
17+
from custom_worker_tuner.shared import TASK_QUEUE, BatchInput, RunBatch
18+
19+
# Tweak these to push more or less load.
20+
WORKFLOWS = 10
21+
ACTIVITIES_PER_WORKFLOW = 20
22+
SECONDS_PER_ACTIVITY = 2.0
23+
24+
25+
async def main() -> None:
26+
config = ClientConfig.load_client_connect_config()
27+
config.setdefault("target_host", "localhost:7233")
28+
client = await Client.connect(**config)
29+
run_id = uuid.uuid4().hex[:8]
30+
inp = BatchInput(activities=ACTIVITIES_PER_WORKFLOW, seconds=SECONDS_PER_ACTIVITY)
31+
total = WORKFLOWS * ACTIVITIES_PER_WORKFLOW
32+
33+
print(f"starting {WORKFLOWS} workflows × {ACTIVITIES_PER_WORKFLOW} activities × {SECONDS_PER_ACTIVITY}s")
34+
t0 = time.perf_counter()
35+
36+
handles = await asyncio.gather(
37+
*(
38+
client.start_workflow(
39+
RunBatch.run,
40+
inp,
41+
id=f"batch-{run_id}-{i}",
42+
task_queue=TASK_QUEUE,
43+
)
44+
for i in range(WORKFLOWS)
45+
)
46+
)
47+
await asyncio.gather(*(h.result() for h in handles))
48+
49+
wall = time.perf_counter() - t0
50+
print(f"done in {wall:.1f}s ({total} activities, {total / wall:.0f}/s)")
51+
52+
53+
if __name__ == "__main__":
54+
asyncio.run(main())

custom_worker_tuner/supplier.py

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
"""DownstreamAwareSupplier — a CustomSlotSupplier that gates on the downstream.
2+
3+
The supplier owns four methods the Temporal SDK calls:
4+
5+
reserve_slot async, on the asyncio loop — wait until allowed to poll
6+
try_reserve_slot sync, on a Rust thread — fast eager path
7+
mark_slot_used sync, on a Rust thread — task arrived
8+
release_slot sync, on a Rust thread — slot done
9+
10+
Our rule for all four: the worker may only run one activity per unit
11+
of downstream capacity. So:
12+
13+
reserve_slot poll downstream.increment until True
14+
try_reserve_slot one shot at downstream.increment
15+
mark_slot_used nothing (we don't need it)
16+
release_slot downstream.decrement
17+
18+
That's the entire supplier. Everything else is contract polish.
19+
20+
Sample, not production
21+
----------------------
22+
Real production code would also handle: cancellation in the tiny
23+
window between ``increment()`` and returning the permit (one-slot
24+
leak per cancel event), idempotency against a duplicate
25+
``release_slot``, and exception walls so a misbehaving downstream
26+
doesn't put the SDK into a tight retry loop. We skip those here for
27+
clarity. See the README for what to add when you copy this.
28+
"""
29+
30+
from __future__ import annotations
31+
32+
import asyncio
33+
import itertools
34+
import logging
35+
36+
from temporalio.worker import (
37+
CustomSlotSupplier,
38+
SlotMarkUsedContext,
39+
SlotPermit,
40+
SlotReleaseContext,
41+
SlotReserveContext,
42+
)
43+
44+
from custom_worker_tuner.downstream import Downstream
45+
46+
logger = logging.getLogger(__name__)
47+
48+
# A single global counter so every slot grant gets a unique short ID we
49+
# can grep for. itertools.count is atomic under CPython's GIL.
50+
_slot_id_gen = itertools.count(1)
51+
52+
53+
class _Permit(SlotPermit):
54+
"""SlotPermit subclass that just carries a sequential id for logs."""
55+
56+
def __init__(self, slot_id: int) -> None:
57+
super().__init__()
58+
self.slot_id = slot_id
59+
60+
61+
class DownstreamAwareSupplier(CustomSlotSupplier):
62+
def __init__(self, downstream: Downstream, poll_interval_ms: int = 100) -> None:
63+
self.downstream = downstream
64+
self.poll_interval_ms = poll_interval_ms
65+
logger.info(
66+
"DownstreamAwareSupplier ready: downstream=%s poll_interval_ms=%d",
67+
downstream.name,
68+
poll_interval_ms,
69+
)
70+
71+
async def reserve_slot(self, ctx: SlotReserveContext) -> SlotPermit:
72+
"""block downstream until it has capacity to get incremented and then grant a slot."""
73+
slot_id = next(_slot_id_gen)
74+
while not self.downstream.increment():
75+
await asyncio.sleep(self.poll_interval_ms / 1000.0)
76+
self._log("reserve", slot_id, "ready to poll")
77+
return _Permit(slot_id)
78+
79+
def try_reserve_slot(self, ctx: SlotReserveContext) -> SlotPermit | None:
80+
"""Eager path: can i run this activity right now?"""
81+
if self.downstream.increment():
82+
slot_id = next(_slot_id_gen)
83+
self._log("reserve", slot_id, "eager dispatch")
84+
return _Permit(slot_id)
85+
return None
86+
87+
def mark_slot_used(self, ctx: SlotMarkUsedContext) -> None:
88+
"""A task arrived for a reserved slot"""
89+
slot_id = getattr(ctx.permit, "slot_id", "?")
90+
self._log("used", slot_id, "activity running")
91+
92+
def release_slot(self, ctx: SlotReleaseContext) -> None:
93+
"""Return the slot to the downstream."""
94+
slot_id = getattr(ctx.permit, "slot_id", "?")
95+
# ctx.slot_info is None when the poll timed out — the slot was
96+
# reserved but no task ever arrived. Surface it so it's not
97+
# confused with a normal completion.
98+
detail = "no task arrived" if ctx.slot_info is None else "activity done"
99+
self.downstream.decrement()
100+
self._log("release", slot_id, detail)
101+
102+
# ----- internals -----
103+
104+
def _log(self, event: str, slot_id, note: str) -> None:
105+
"""Emit one line in the column format::
106+
107+
EVENT SLOT COUNT NOTE
108+
reserve #209 10/10
109+
wait #210 10/10 full
110+
"""
111+
count = f"{self.downstream.currently_connected}/{self.downstream.allowed_connections}"
112+
logger.info(f"{event:<8} #{slot_id!s:<4} {count:>5} {note}")

0 commit comments

Comments
 (0)