Skip to content

Commit c3f6d55

Browse files
Add custom_worker_tuner sample
1 parent 4d453de commit c3f6d55

7 files changed

Lines changed: 327 additions & 0 deletions

File tree

custom_worker_tuner/README.md

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
# Custom Worker Tuner
2+
3+
A `CustomSlotSupplier` is a sample that lets you gate slot grants on whatever you want.
4+
This sample gates on a fake DB pool: the worker only polls for a new
5+
activity when the pool has a free connection.
6+
7+
## What this sample is
8+
downstream.py - A static-capacity counter. Pretends to be a DB pool. Two methods: increment() (claim a slot, returns False if full), decrement() (release)
9+
supplier.py - The custom slot supplier. On reserve_slot it polls downstream.increment() until it succeeds. On release_slot it calls downstream.decrement()
10+
shared.py - A RunBatch workflow that runs N do_work activities in parallel. The activity just sleeps
11+
worker.py - Wires Downstream + DownstreamAwareSupplier into a WorkerTuner
12+
starter.py - Drives load
13+
14+
The flow:
15+
16+
When the downstream is at capacity, `reserve_slot` blocks until a
17+
slot frees up. The excess work piles up on the Temporal server, not
18+
inside the worker.
19+
20+
## Run
21+
22+
In three terminals from `samples-python/`:
23+
24+
```bash
25+
temporal server start-dev # terminal 1
26+
uv run custom_worker_tuner/worker.py # terminal 2
27+
uv run custom_worker_tuner/starter.py # terminal 3
28+
```
29+
30+
## What you'll see
31+
32+
The worker prints one line per slot lifecycle event:
33+
34+
```
35+
36+
TIME EVENT SLOT COUNT DETAIL
37+
────────────────────────────────────────────────────────────
38+
10:31:49.842 reserve #1 1/10 ready to poll
39+
10:31:49.842 reserve #2 2/10 ready to poll
40+
10:31:49.843 reserve #3 3/10 ready to poll
41+
10:31:49.843 reserve #4 4/10 ready to poll
42+
10:31:49.843 reserve #5 5/10 ready to poll
43+
10:31:49.843 reserve #6 6/10 ready to poll
44+
10:31:56.763 reserve #7 7/10 eager dispatch
45+
10:31:56.763 reserve #8 8/10 eager dispatch
46+
10:31:56.764 reserve #9 9/10 eager dispatch
47+
10:31:56.766 reserve #10 10/10 eager dispatch
48+
10:31:56.767 release #7 9/10 no task arrived
49+
10:31:56.768 release #8 8/10 no task arrived
50+
10:31:56.768 release #9 7/10 no task arrived
51+
10:31:56.768 reserve #11 8/10 eager dispatch
52+
10:31:56.768 reserve #12 9/10 eager dispatch
53+
10:31:56.768 reserve #13 10/10 eager dispatch
54+
10:31:56.771 used #1 10/10 activity running
55+
10:31:56.771 release #10 9/10 no task arrived
56+
```
57+
58+
Under load, with more activities than capacity, COUNT pins at
59+
10/10 — that's the supplier refusing to poll past the gate.
60+
we chose 10 because default there are 5 pollers for python sdk
61+
62+
## Knobs
63+
64+
worker.py:
65+
66+
CAPACITY — downstream capacity (the gate)
67+
POLL_INTERVAL_MS — how often the supplier rechecks when full
68+
69+
starter.py:
70+
71+
WORKFLOWS, ACTIVITIES_PER_WORKFLOW, SECONDS_PER_ACTIVITY — amount and duration of load

custom_worker_tuner/__init__.py

Whitespace-only changes.

custom_worker_tuner/downstream.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
from __future__ import annotations
2+
3+
import logging
4+
import threading
5+
6+
logger = logging.getLogger(__name__)
7+
8+
9+
class Downstream:
10+
"""A counter with a fixed capacity. Thread-safe."""
11+
12+
def __init__(self, allowed_connections: int, name: str = "downstream") -> None:
13+
self.allowed_connections = allowed_connections
14+
self.name = name
15+
self.currently_connected = 0
16+
self.connection_pool = threading.Lock()
17+
logger.info("Downstream ready: name=%s allowed_connections=%d", name, allowed_connections)
18+
19+
def increment(self) -> bool:
20+
"""allow one connection. Returns False if at capacity."""
21+
with self.connection_pool:
22+
if self.currently_connected >= self.allowed_connections:
23+
return False
24+
self.currently_connected += 1
25+
return True
26+
27+
def decrement(self) -> None:
28+
"""Release one slot. Floored at 0 so a buggy caller can't go negative."""
29+
with self.connection_pool:
30+
self.currently_connected = max(0, self.currently_connected - 1)

custom_worker_tuner/shared.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
from __future__ import annotations
2+
3+
import asyncio
4+
from dataclasses import dataclass
5+
from datetime import timedelta
6+
7+
from temporalio import activity, workflow
8+
9+
TASK_QUEUE = "custom-worker-tuner"
10+
11+
12+
@dataclass
13+
class BatchInput:
14+
activities: int
15+
seconds: float
16+
17+
18+
@activity.defn
19+
async def do_work(seconds: float) -> None:
20+
"""Sleep, simulating an I/O-bound activity."""
21+
await asyncio.sleep(seconds)
22+
23+
@workflow.defn
24+
class RunBatch:
25+
"""Runs N do_work activities in parallel."""
26+
27+
@workflow.run
28+
async def run(self, inp: BatchInput) -> None:
29+
await asyncio.gather(
30+
*(
31+
workflow.execute_activity(
32+
do_work,
33+
inp.seconds,
34+
start_to_close_timeout=timedelta(minutes=2),
35+
)
36+
for _ in range(inp.activities)
37+
)
38+
)

custom_worker_tuner/starter.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
from __future__ import annotations
2+
3+
import asyncio
4+
import time
5+
import uuid
6+
7+
from temporalio.client import Client
8+
from temporalio.envconfig import ClientConfig
9+
10+
from custom_worker_tuner.shared import TASK_QUEUE, BatchInput, RunBatch
11+
12+
# Tweak these to push more or less load.
13+
WORKFLOWS = 10
14+
ACTIVITIES_PER_WORKFLOW = 20
15+
SECONDS_PER_ACTIVITY = 2.0
16+
17+
18+
async def main() -> None:
19+
config = ClientConfig.load_client_connect_config()
20+
config.setdefault("target_host", "localhost:7233")
21+
client = await Client.connect(**config)
22+
run_id = uuid.uuid4().hex[:8]
23+
inp = BatchInput(activities=ACTIVITIES_PER_WORKFLOW, seconds=SECONDS_PER_ACTIVITY)
24+
total = WORKFLOWS * ACTIVITIES_PER_WORKFLOW
25+
26+
print(f"starting {WORKFLOWS} workflows × {ACTIVITIES_PER_WORKFLOW} activities × {SECONDS_PER_ACTIVITY}s")
27+
t0 = time.perf_counter()
28+
29+
handles = await asyncio.gather(
30+
*(
31+
client.start_workflow(
32+
RunBatch.run,
33+
inp,
34+
id=f"batch-{run_id}-{i}",
35+
task_queue=TASK_QUEUE,
36+
)
37+
for i in range(WORKFLOWS)
38+
)
39+
)
40+
await asyncio.gather(*(h.result() for h in handles))
41+
42+
wall = time.perf_counter() - t0
43+
print(f"done in {wall:.1f}s ({total} activities, {total / wall:.0f}/s)")
44+
45+
46+
if __name__ == "__main__":
47+
asyncio.run(main())

custom_worker_tuner/supplier.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
from __future__ import annotations
2+
3+
import asyncio
4+
import itertools
5+
import logging
6+
7+
from temporalio.worker import (
8+
CustomSlotSupplier,
9+
SlotMarkUsedContext,
10+
SlotPermit,
11+
SlotReleaseContext,
12+
SlotReserveContext,
13+
)
14+
15+
from custom_worker_tuner.downstream import Downstream
16+
17+
logger = logging.getLogger(__name__)
18+
19+
# A single global counter so every slot grant gets a unique short ID we
20+
# can grep for. itertools.count is atomic under CPython's GIL.
21+
_slot_id_gen = itertools.count(1)
22+
23+
24+
class _Permit(SlotPermit):
25+
"""SlotPermit subclass that just carries a sequential id for logs."""
26+
27+
def __init__(self, slot_id: int) -> None:
28+
super().__init__()
29+
self.slot_id = slot_id
30+
31+
32+
class DownstreamAwareSupplier(CustomSlotSupplier):
33+
def __init__(self, downstream: Downstream, poll_interval_ms: int = 100) -> None:
34+
self.downstream = downstream
35+
self.poll_interval_ms = poll_interval_ms
36+
logger.info(
37+
"DownstreamAwareSupplier ready: downstream=%s poll_interval_ms=%d",
38+
downstream.name,
39+
poll_interval_ms,
40+
)
41+
42+
async def reserve_slot(self, ctx: SlotReserveContext) -> SlotPermit:
43+
"""block downstream until it has capacity to get incremented and then grant a slot."""
44+
slot_id = next(_slot_id_gen)
45+
while not self.downstream.increment():
46+
await asyncio.sleep(self.poll_interval_ms / 1000.0)
47+
self._log("reserve", slot_id, "ready to poll")
48+
return _Permit(slot_id)
49+
50+
def try_reserve_slot(self, ctx: SlotReserveContext) -> SlotPermit | None:
51+
"""Eager path: can i run this activity right now?"""
52+
if self.downstream.increment():
53+
slot_id = next(_slot_id_gen)
54+
self._log("reserve", slot_id, "eager dispatch")
55+
return _Permit(slot_id)
56+
return None
57+
58+
def mark_slot_used(self, ctx: SlotMarkUsedContext) -> None:
59+
"""A task arrived for a reserved slot"""
60+
slot_id = getattr(ctx.permit, "slot_id", "?")
61+
self._log("used", slot_id, "activity running")
62+
63+
def release_slot(self, ctx: SlotReleaseContext) -> None:
64+
"""Return the slot to the downstream."""
65+
slot_id = getattr(ctx.permit, "slot_id", "?")
66+
# ctx.slot_info is None when the poll timed out — the slot was
67+
# reserved but no task ever arrived. Surface it so it's not
68+
# confused with a normal completion.
69+
detail = "no task arrived" if ctx.slot_info is None else "activity done"
70+
self.downstream.decrement()
71+
self._log("release", slot_id, detail)
72+
73+
# ----- internals -----
74+
75+
def _log(self, event: str, slot_id, note: str) -> None:
76+
"""Emit one line in the column format::
77+
78+
EVENT SLOT COUNT NOTE
79+
reserve #209 10/10
80+
wait #210 10/10 full
81+
"""
82+
count = f"{self.downstream.currently_connected}/{self.downstream.allowed_connections}"
83+
logger.info(f"{event:<8} #{slot_id!s:<4} {count:>5} {note}")

custom_worker_tuner/worker.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
from __future__ import annotations
2+
3+
import asyncio
4+
import logging
5+
6+
from temporalio.client import Client
7+
from temporalio.envconfig import ClientConfig
8+
from temporalio.worker import FixedSizeSlotSupplier, Worker, WorkerTuner
9+
10+
from custom_worker_tuner.downstream import Downstream
11+
from custom_worker_tuner.shared import TASK_QUEUE, RunBatch, do_work
12+
from custom_worker_tuner.supplier import DownstreamAwareSupplier
13+
14+
15+
CAPACITY = 10 # number of connections allowed at a time
16+
POLL_INTERVAL_MS = 500
17+
LOG_LEVEL = "INFO" # flip to "DEBUG" to see every increment/decrement
18+
19+
20+
async def main() -> None:
21+
logging.basicConfig(
22+
level=getattr(logging, LOG_LEVEL.upper(), logging.INFO),
23+
format="%(asctime)s.%(msecs)03d %(message)s",
24+
datefmt="%H:%M:%S",
25+
)
26+
27+
config = ClientConfig.load_client_connect_config()
28+
config.setdefault("target_host", "localhost:7233")
29+
client = await Client.connect(**config)
30+
31+
downstream = Downstream(allowed_connections=CAPACITY, name="db")
32+
supplier = DownstreamAwareSupplier(downstream, poll_interval_ms=POLL_INTERVAL_MS)
33+
tuner = WorkerTuner.create_composite(
34+
workflow_supplier=FixedSizeSlotSupplier(100),
35+
activity_supplier=supplier,
36+
local_activity_supplier=FixedSizeSlotSupplier(100),
37+
nexus_supplier=FixedSizeSlotSupplier(100),
38+
)
39+
40+
worker = Worker(
41+
client,
42+
task_queue=TASK_QUEUE,
43+
workflows=[RunBatch],
44+
activities=[do_work],
45+
tuner=tuner,
46+
)
47+
48+
print(f"\nworker started — capacity={CAPACITY}, poll={POLL_INTERVAL_MS}ms\n")
49+
print("TIME EVENT SLOT COUNT DETAIL")
50+
print("─" * 60)
51+
await worker.run()
52+
53+
54+
if __name__ == "__main__":
55+
try:
56+
asyncio.run(main())
57+
except KeyboardInterrupt:
58+
pass

0 commit comments

Comments
 (0)