Skip to content

Commit 2c05e40

Browse files
Add custom_worker_tuner sample
1 parent 4d453de commit 2c05e40

7 files changed

Lines changed: 320 additions & 0 deletions

File tree

custom_worker_tuner/README.md

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
# Custom Worker Tuner
2+
3+
A `CustomSlotSupplier` is a sample that lets you gate slot grants on whatever you want.
4+
This sample gates on a fake DB pool: the worker only polls for a new
5+
activity when the pool has a free connection.
6+
7+
## What this sample is
8+
downstream.py - A static-capacity counter. Pretends to be a DB pool. Two methods: increment() (claim a slot, returns False if full), decrement() (release)
9+
supplier.py - The custom slot supplier. On reserve_slot it polls downstream.increment() until it succeeds. On release_slot it calls downstream.decrement()
10+
shared.py - A RunBatch workflow that runs N do_work activities in parallel. The activity just sleeps
11+
worker.py - Wires Downstream + DownstreamAwareSupplier into a WorkerTuner
12+
starter.py - Drives load
13+
14+
The flow:
15+
16+
When the downstream is at capacity, `reserve_slot` blocks until a
17+
slot frees up. The excess work piles up on the Temporal server, not
18+
inside the worker.
19+
20+
## Run
21+
22+
In three terminals from `samples-python/`:
23+
24+
```bash
25+
temporal server start-dev # terminal 1
26+
uv run custom_worker_tuner/worker.py # terminal 2
27+
uv run custom_worker_tuner/starter.py # terminal 3
28+
```
29+
30+
## What you'll see
31+
32+
The worker prints one line per slot lifecycle event:
33+
34+
```
35+
36+
TIME EVENT SLOT COUNT DETAIL
37+
────────────────────────────────────────────────────────────
38+
10:31:49.842 reserve #1 1/10 ready to poll
39+
10:31:49.842 reserve #2 2/10 ready to poll
40+
10:31:49.843 reserve #3 3/10 ready to poll
41+
10:31:49.843 reserve #4 4/10 ready to poll
42+
10:31:49.843 reserve #5 5/10 ready to poll
43+
10:31:49.843 reserve #6 6/10 ready to poll
44+
10:31:56.763 reserve #7 7/10 eager dispatch
45+
10:31:56.763 reserve #8 8/10 eager dispatch
46+
10:31:56.764 reserve #9 9/10 eager dispatch
47+
10:31:56.766 reserve #10 10/10 eager dispatch
48+
10:31:56.767 release #7 9/10 no task arrived
49+
10:31:56.768 release #8 8/10 no task arrived
50+
10:31:56.768 release #9 7/10 no task arrived
51+
10:31:56.768 reserve #11 8/10 eager dispatch
52+
10:31:56.768 reserve #12 9/10 eager dispatch
53+
10:31:56.768 reserve #13 10/10 eager dispatch
54+
10:31:56.771 used #1 10/10 activity running
55+
10:31:56.771 release #10 9/10 no task arrived
56+
```
57+
58+
Under load, with more activities than capacity, COUNT pins at
59+
10/10 — that's the supplier refusing to poll past the gate.
60+
we chose 10 because default there are 5 pollers for python sdk
61+
62+
## Knobs
63+
64+
worker.py:
65+
66+
CAPACITY — downstream capacity (the gate)
67+
POLL_INTERVAL_MS — how often the supplier rechecks when full
68+
69+
starter.py:
70+
71+
WORKFLOWS, ACTIVITIES_PER_WORKFLOW, SECONDS_PER_ACTIVITY — amount and duration of load

custom_worker_tuner/__init__.py

Whitespace-only changes.

custom_worker_tuner/downstream.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
from __future__ import annotations
2+
3+
import logging
4+
import threading
5+
6+
logger = logging.getLogger(__name__)
7+
8+
9+
class Downstream:
10+
"""A counter with a fixed capacity. Thread-safe."""
11+
12+
def __init__(self, allowed_connections: int, name: str = "downstream") -> None:
13+
self.allowed_connections = allowed_connections
14+
self.name = name
15+
self.currently_connected = 0
16+
self.connection_pool = threading.Lock()
17+
logger.info(
18+
"Downstream ready: name=%s allowed_connections=%d",
19+
name,
20+
allowed_connections,
21+
)
22+
23+
def increment(self) -> bool:
24+
"""allow one connection. Returns False if at capacity."""
25+
with self.connection_pool:
26+
if self.currently_connected >= self.allowed_connections:
27+
return False
28+
self.currently_connected += 1
29+
return True
30+
31+
def decrement(self) -> None:
32+
"""Release one slot. Floored at 0 so a buggy caller can't go negative."""
33+
with self.connection_pool:
34+
self.currently_connected = max(0, self.currently_connected - 1)

custom_worker_tuner/shared.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
from __future__ import annotations
2+
3+
import asyncio
4+
from dataclasses import dataclass
5+
from datetime import timedelta
6+
7+
from temporalio import activity, workflow
8+
9+
TASK_QUEUE = "custom-worker-tuner"
10+
11+
12+
@dataclass
13+
class BatchInput:
14+
activities: int
15+
seconds: float
16+
17+
18+
@activity.defn
19+
async def do_work(seconds: float) -> None:
20+
"""Sleep, simulating an I/O-bound activity."""
21+
await asyncio.sleep(seconds)
22+
23+
24+
@workflow.defn
25+
class RunBatch:
26+
"""Runs N do_work activities in parallel."""
27+
28+
@workflow.run
29+
async def run(self, inp: BatchInput) -> None:
30+
await asyncio.gather(
31+
*(
32+
workflow.execute_activity(
33+
do_work,
34+
inp.seconds,
35+
start_to_close_timeout=timedelta(minutes=2),
36+
)
37+
for _ in range(inp.activities)
38+
)
39+
)

custom_worker_tuner/starter.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
from __future__ import annotations
2+
3+
import asyncio
4+
import time
5+
import uuid
6+
7+
from temporalio.client import Client
8+
from temporalio.envconfig import ClientConfig
9+
10+
from custom_worker_tuner.shared import TASK_QUEUE, BatchInput, RunBatch
11+
12+
# Tweak these to push more or less load.
13+
WORKFLOWS = 10
14+
ACTIVITIES_PER_WORKFLOW = 20
15+
SECONDS_PER_ACTIVITY = 2.0
16+
17+
18+
async def main() -> None:
19+
config = ClientConfig.load_client_connect_config()
20+
config.setdefault("target_host", "localhost:7233")
21+
client = await Client.connect(**config)
22+
run_id = uuid.uuid4().hex[:8]
23+
inp = BatchInput(activities=ACTIVITIES_PER_WORKFLOW, seconds=SECONDS_PER_ACTIVITY)
24+
total = WORKFLOWS * ACTIVITIES_PER_WORKFLOW
25+
26+
print(
27+
f"starting {WORKFLOWS} workflows × {ACTIVITIES_PER_WORKFLOW} activities × {SECONDS_PER_ACTIVITY}s"
28+
)
29+
t0 = time.perf_counter()
30+
31+
handles = await asyncio.gather(
32+
*(
33+
client.start_workflow(
34+
RunBatch.run,
35+
inp,
36+
id=f"batch-{run_id}-{i}",
37+
task_queue=TASK_QUEUE,
38+
)
39+
for i in range(WORKFLOWS)
40+
)
41+
)
42+
await asyncio.gather(*(h.result() for h in handles))
43+
44+
wall = time.perf_counter() - t0
45+
print(f"done in {wall:.1f}s ({total} activities, {total / wall:.0f}/s)")
46+
47+
48+
if __name__ == "__main__":
49+
asyncio.run(main())

custom_worker_tuner/supplier.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
from __future__ import annotations
2+
3+
import asyncio
4+
import itertools
5+
import logging
6+
7+
from temporalio.worker import (
8+
CustomSlotSupplier,
9+
SlotMarkUsedContext,
10+
SlotPermit,
11+
SlotReleaseContext,
12+
SlotReserveContext,
13+
)
14+
15+
from custom_worker_tuner.downstream import Downstream
16+
17+
logger = logging.getLogger(__name__)
18+
19+
_slot_id_gen = itertools.count(1)
20+
21+
22+
class _Permit(SlotPermit):
23+
"""SlotPermit subclass that just carries a sequential id for logs."""
24+
25+
def __init__(self, slot_id: int) -> None:
26+
super().__init__()
27+
self.slot_id = slot_id
28+
29+
30+
class DownstreamAwareSupplier(CustomSlotSupplier):
31+
def __init__(self, downstream: Downstream, poll_interval_ms: int = 100) -> None:
32+
self.downstream = downstream
33+
self.poll_interval_ms = poll_interval_ms
34+
logger.info(
35+
"DownstreamAwareSupplier ready: downstream=%s poll_interval_ms=%d",
36+
downstream.name,
37+
poll_interval_ms,
38+
)
39+
40+
async def reserve_slot(self, ctx: SlotReserveContext) -> SlotPermit:
41+
"""block downstream until it has capacity to get incremented and then grant a slot."""
42+
slot_id = next(_slot_id_gen)
43+
while not self.downstream.increment():
44+
await asyncio.sleep(self.poll_interval_ms / 1000.0)
45+
self._log("reserve", slot_id, "ready to poll")
46+
return _Permit(slot_id)
47+
48+
def try_reserve_slot(self, ctx: SlotReserveContext) -> SlotPermit | None:
49+
"""Eager path: can i run this activity right now?"""
50+
if self.downstream.increment():
51+
slot_id = next(_slot_id_gen)
52+
self._log("reserve", slot_id, "eager dispatch")
53+
return _Permit(slot_id)
54+
return None
55+
56+
def mark_slot_used(self, ctx: SlotMarkUsedContext) -> None:
57+
"""A task arrived for a reserved slot"""
58+
slot_id = getattr(ctx.permit, "slot_id", "?")
59+
self._log("used", slot_id, "activity running")
60+
61+
def release_slot(self, ctx: SlotReleaseContext) -> None:
62+
"""Return the slot to the downstream."""
63+
slot_id = getattr(ctx.permit, "slot_id", "?")
64+
detail = "no task arrived" if ctx.slot_info is None else "activity done"
65+
self.downstream.decrement()
66+
self._log("release", slot_id, detail)
67+
68+
def _log(self, event: str, slot_id, note: str) -> None:
69+
count = f"{self.downstream.currently_connected}/{self.downstream.allowed_connections}"
70+
logger.info(f"{event:<8} {count:>5} {note}")

custom_worker_tuner/worker.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
from __future__ import annotations
2+
3+
import asyncio
4+
import logging
5+
6+
from temporalio.client import Client
7+
from temporalio.envconfig import ClientConfig
8+
from temporalio.worker import FixedSizeSlotSupplier, Worker, WorkerTuner
9+
10+
from custom_worker_tuner.downstream import Downstream
11+
from custom_worker_tuner.shared import TASK_QUEUE, RunBatch, do_work
12+
from custom_worker_tuner.supplier import DownstreamAwareSupplier
13+
14+
CAPACITY = 10 # number of connections allowed at a time
15+
POLL_INTERVAL_MS = 500
16+
LOG_LEVEL = "INFO" # flip to "DEBUG" to see every increment/decrement
17+
18+
19+
async def main() -> None:
20+
logging.basicConfig(
21+
level=getattr(logging, LOG_LEVEL.upper(), logging.INFO),
22+
format="%(asctime)s.%(msecs)03d %(message)s",
23+
datefmt="%H:%M:%S",
24+
)
25+
26+
config = ClientConfig.load_client_connect_config()
27+
config.setdefault("target_host", "localhost:7233")
28+
client = await Client.connect(**config)
29+
30+
downstream = Downstream(allowed_connections=CAPACITY, name="db")
31+
supplier = DownstreamAwareSupplier(downstream, poll_interval_ms=POLL_INTERVAL_MS)
32+
tuner = WorkerTuner.create_composite(
33+
workflow_supplier=FixedSizeSlotSupplier(100),
34+
activity_supplier=supplier,
35+
local_activity_supplier=FixedSizeSlotSupplier(100),
36+
nexus_supplier=FixedSizeSlotSupplier(100),
37+
)
38+
39+
worker = Worker(
40+
client,
41+
task_queue=TASK_QUEUE,
42+
workflows=[RunBatch],
43+
activities=[do_work],
44+
tuner=tuner,
45+
)
46+
47+
print(f"\nworker started — capacity={CAPACITY}, poll={POLL_INTERVAL_MS}ms\n")
48+
print("TIME EVENT COUNT DETAIL")
49+
print("─" * 60)
50+
await worker.run()
51+
52+
53+
if __name__ == "__main__":
54+
try:
55+
asyncio.run(main())
56+
except KeyboardInterrupt:
57+
pass

0 commit comments

Comments
 (0)