Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 72 additions & 0 deletions custom_worker_tuner/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
# Custom Worker Tuner

A `CustomSlotSupplier` lets you gate slot grants on whatever you want.
This sample gates on a fake DB pool: the worker only polls for a new
activity when the pool has a free connection.

**Note:** This sample is illustrative only. It is not intended for production-grade use cases.

## What this sample is
- `db_pool.py` - A fixed-capacity fake pool backed by a `BoundedSemaphore`. It has two methods: `acquire(blocking=True)` claims a connection (and returns `False` if the pool is full when called non-blocking), and `release()` returns a connection.
- `supplier.py` - The custom slot supplier. `reserve_slot` waits on `connection_pool.acquire()` until a connection is free; `try_reserve_slot` makes the same attempt without blocking. `release_slot` calls `connection_pool.release()`.
- `shared.py` - A `RunBatch` workflow that runs N `do_work` activities in parallel. The activity just sleeps.
- `worker.py` - Wires `FakeDatabaseConnectionPool` + `PoolSlotSupplier` into a `WorkerTuner`.
- `starter.py` - Drives load.

The flow:

When the pool is at capacity, `reserve_slot` blocks until a
connection frees up. The excess work piles up on the Temporal server, not
inside the worker.

## Run

In three terminals from `samples-python/`:

```bash
temporal server start-dev # terminal 1
uv run custom_worker_tuner/worker.py # terminal 2
uv run custom_worker_tuner/starter.py # terminal 3
```

## What you'll see

The worker prints one line per slot lifecycle event:

```

TIME EVENT SLOT COUNT DETAIL
────────────────────────────────────────────────────────────
10:31:49.842 reserve #1 1/10 ready to poll
10:31:49.842 reserve #2 2/10 ready to poll
10:31:49.843 reserve #3 3/10 ready to poll
10:31:49.843 reserve #4 4/10 ready to poll
10:31:49.843 reserve #5 5/10 ready to poll
10:31:49.843 reserve #6 6/10 ready to poll
10:31:56.763 reserve #7 7/10 eager dispatch
10:31:56.763 reserve #8 8/10 eager dispatch
10:31:56.764 reserve #9 9/10 eager dispatch
10:31:56.766 reserve #10 10/10 eager dispatch
10:31:56.767 release #7 9/10 no task arrived
10:31:56.768 release #8 8/10 no task arrived
10:31:56.768 release #9 7/10 no task arrived
10:31:56.768 reserve #11 8/10 eager dispatch
10:31:56.768 reserve #12 9/10 eager dispatch
10:31:56.768 reserve #13 10/10 eager dispatch
10:31:56.771 used #1 10/10 activity running
10:31:56.771 release #10 9/10 no task arrived
```

Under load, with more activities than capacity, COUNT pins at
10/10 — that's the supplier refusing to poll past the gate.
We chose a capacity of 10 because, by default, the Python SDK uses 5 pollers.

## Knobs

`worker.py`:

- `CAPACITY` — pool capacity (the gate)

`starter.py`:

- `WORKFLOWS`, `ACTIVITIES_PER_WORKFLOW`, `SECONDS_PER_ACTIVITY` — amount and duration of load
Empty file.
33 changes: 33 additions & 0 deletions custom_worker_tuner/db_pool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from __future__ import annotations

import logging
import threading

logger = logging.getLogger(__name__)


class FakeDatabaseConnectionPool:
    """Pretend connection pool with a fixed capacity, backed by a semaphore.

    A ``threading.BoundedSemaphore`` enforces the capacity. A separate
    lock-protected counter records how many connections are checked out, so
    ``in_use`` does not have to read the semaphore's private ``_value``
    attribute (an undocumented implementation detail).

    NOTE(review): a reviewer of this change suggested ``asyncio.Semaphore``
    so that async callers would not need ``asyncio.to_thread`` to wait on
    ``acquire``. This class is still thread-based.

    Args:
        allowed_connections: Maximum number of connections that may be held
            at the same time.
        name: Label for this pool, used in log output.

    Raises:
        ValueError: If ``allowed_connections`` is negative (raised by
            ``threading.BoundedSemaphore``).
    """

    def __init__(self, allowed_connections: int, name: str = "db") -> None:
        self.allowed_connections = allowed_connections
        self.name = name
        self._connection_pool = threading.BoundedSemaphore(allowed_connections)
        # Number of connections currently checked out; guarded by the lock.
        self._in_use = 0
        self._in_use_lock = threading.Lock()
        logger.info(
            "FakeDatabaseConnectionPool ready: name=%s allowed_connections=%d",
            name,
            allowed_connections,
        )

    def acquire(self, blocking: bool = True) -> bool:
        """Claim a connection.

        Args:
            blocking: When true, wait until a connection is free. When false,
                return immediately.

        Returns:
            True if a connection was claimed, False if the pool was full and
            ``blocking`` was false.
        """
        acquired = self._connection_pool.acquire(blocking=blocking)
        if acquired:
            # Count only after the semaphore grant so in_use never exceeds
            # allowed_connections.
            with self._in_use_lock:
                self._in_use += 1
        return acquired

    def release(self) -> None:
        """Return a connection to the pool.

        Raises:
            ValueError: If called more times than ``acquire`` succeeded.
        """
        with self._in_use_lock:
            if self._in_use <= 0:
                raise ValueError("release() called more times than acquire()")
            # Decrement before freeing the semaphore: a concurrent acquire
            # may then briefly see a low count, but never one above capacity.
            self._in_use -= 1
        self._connection_pool.release()

    @property
    def in_use(self) -> int:
        """Number of connections currently checked out."""
        with self._in_use_lock:
            return self._in_use
39 changes: 39 additions & 0 deletions custom_worker_tuner/shared.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from __future__ import annotations

import asyncio
from dataclasses import dataclass
from datetime import timedelta

from temporalio import activity, workflow

# Task queue name imported by both starter.py and worker.py.
TASK_QUEUE = "custom-worker-tuner"


@dataclass
class BatchInput:
    """Input for the ``RunBatch`` workflow."""

    # Number of do_work activities the workflow runs in parallel.
    activities: int
    # Duration, in seconds, passed to each do_work activity.
    seconds: float


@activity.defn
async def do_work(seconds: float) -> None:
    """Sleep for ``seconds``, simulating an I/O-bound activity.

    Args:
        seconds: How long to sleep.
    """
    await asyncio.sleep(seconds)


@workflow.defn
class RunBatch:
    """Workflow that fans out ``do_work`` activities and waits for all of them."""

    @workflow.run
    async def run(self, inp: BatchInput) -> None:
        """Run ``inp.activities`` copies of ``do_work`` concurrently.

        Args:
            inp: Number of activities to run and the sleep duration passed
                to each one.
        """
        # Schedule every activity first, then wait for the whole batch.
        pending = [
            workflow.execute_activity(
                do_work,
                inp.seconds,
                start_to_close_timeout=timedelta(minutes=2),
            )
            for _ in range(inp.activities)
        ]
        await asyncio.gather(*pending)
49 changes: 49 additions & 0 deletions custom_worker_tuner/starter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from __future__ import annotations

import asyncio
import time
import uuid

from temporalio.client import Client
from temporalio.envconfig import ClientConfig

from custom_worker_tuner.shared import TASK_QUEUE, BatchInput, RunBatch

# Tweak these to push more or less load.
# WORKFLOWS: number of RunBatch workflows started by one run of this script.
# ACTIVITIES_PER_WORKFLOW: number of do_work activities each workflow runs.
# SECONDS_PER_ACTIVITY: sleep duration passed to each activity.
WORKFLOWS = 10
ACTIVITIES_PER_WORKFLOW = 20
SECONDS_PER_ACTIVITY = 2.0


async def main() -> None:
    """Start ``WORKFLOWS`` RunBatch workflows, wait for all, and print throughput."""
    # Use the loaded client config, falling back to a local server address
    # when no target host is set.
    connect_kwargs = ClientConfig.load_client_connect_config()
    connect_kwargs.setdefault("target_host", "localhost:7233")
    client = await Client.connect(**connect_kwargs)

    # A short random tag keeps workflow ids unique across runs of this script.
    run_id = uuid.uuid4().hex[:8]
    batch_input = BatchInput(
        activities=ACTIVITIES_PER_WORKFLOW,
        seconds=SECONDS_PER_ACTIVITY,
    )
    total = WORKFLOWS * ACTIVITIES_PER_WORKFLOW

    print(
        f"starting {WORKFLOWS} workflows × {ACTIVITIES_PER_WORKFLOW} activities × {SECONDS_PER_ACTIVITY}s"
    )
    started_at = time.perf_counter()

    # Start every workflow concurrently, then wait for every result.
    start_calls = [
        client.start_workflow(
            RunBatch.run,
            batch_input,
            id=f"batch-{run_id}-{i}",
            task_queue=TASK_QUEUE,
        )
        for i in range(WORKFLOWS)
    ]
    handles = await asyncio.gather(*start_calls)
    await asyncio.gather(*(handle.result() for handle in handles))

    elapsed = time.perf_counter() - started_at
    print(f"done in {elapsed:.1f}s ({total} activities, {total / elapsed:.0f}/s)")


# Script entry point: run the load driver to completion.
if __name__ == "__main__":
    asyncio.run(main())
64 changes: 64 additions & 0 deletions custom_worker_tuner/supplier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
from __future__ import annotations

import asyncio
import itertools
import logging

from temporalio.worker import (
CustomSlotSupplier,
SlotMarkUsedContext,
SlotPermit,
SlotReleaseContext,
SlotReserveContext,
)

from custom_worker_tuner.db_pool import FakeDatabaseConnectionPool

logger = logging.getLogger(__name__)

# Module-wide counter that hands out sequential slot ids (1, 2, 3, ...).
_slot_id_gen = itertools.count(1)


class _Permit(SlotPermit):
    """SlotPermit subclass that just carries a sequential id for logs."""

    def __init__(self, slot_id: int) -> None:
        super().__init__()
        # Sequence number assigned when the slot was granted.
        self.slot_id = slot_id


class PoolSlotSupplier(CustomSlotSupplier):
    """Hands out slots only when the backing pool has a free connection.

    Every granted slot holds one connection from ``connection_pool`` until
    ``release_slot`` is called for it.

    Args:
        connection_pool: Pool whose capacity gates slot grants.
        poll_interval: Seconds to sleep between non-blocking acquire attempts
            while ``reserve_slot`` waits for capacity.
    """

    def __init__(
        self,
        connection_pool: FakeDatabaseConnectionPool,
        poll_interval: float = 0.01,
    ) -> None:
        self.connection_pool = connection_pool
        self.poll_interval = poll_interval
        logger.info("PoolSlotSupplier ready: connection_pool=%s", connection_pool.name)

    async def reserve_slot(self, ctx: SlotReserveContext) -> SlotPermit:
        """Wait until the pool has capacity, then grant a slot.

        The wait is a loop of non-blocking acquires separated by
        ``asyncio.sleep``. Cancellation can therefore only land while no
        connection is held. Waiting in a helper thread instead would leave
        that thread blocked after a cancel, and it would later take a
        connection that no permit owns.

        Returns:
            A permit carrying a fresh sequential slot id.
        """
        while not self.connection_pool.acquire(blocking=False):
            await asyncio.sleep(self.poll_interval)
        return self._grant("ready to poll")

    def try_reserve_slot(self, ctx: SlotReserveContext) -> SlotPermit | None:
        """Eager path: try to claim a slot without blocking.

        Returns:
            A permit if the pool had a free connection, otherwise None.
        """
        if not self.connection_pool.acquire(blocking=False):
            return None
        return self._grant("eager dispatch")

    def mark_slot_used(self, ctx: SlotMarkUsedContext) -> None:
        """Log that a previously reserved slot is now in use."""
        slot_id = getattr(ctx.permit, "slot_id", "?")
        self._log("used", slot_id, "activity running")

    def release_slot(self, ctx: SlotReleaseContext) -> None:
        """Return the slot's connection to the pool and log the release."""
        slot_id = getattr(ctx.permit, "slot_id", "?")
        # A release with no slot info is logged as a slot that never got a task.
        detail = "no task arrived" if ctx.slot_info is None else "activity done"
        self.connection_pool.release()
        self._log("release", slot_id, detail)

    def _grant(self, note: str) -> SlotPermit:
        """Create and log a permit for a connection that was just acquired."""
        slot_id = next(_slot_id_gen)
        self._log("reserve", slot_id, note)
        return _Permit(slot_id)

    def _log(self, event: str, slot_id: int | str, note: str) -> None:
        """Log one slot lifecycle line: event, pool usage, detail, and slot id."""
        pool = self.connection_pool
        count = f"{pool.in_use}/{pool.allowed_connections}"
        # The slot id is appended to the detail text so the line still has
        # the same three leading columns as before.
        logger.info("%-8s %5s %s (slot #%s)", event, count, note, slot_id)
56 changes: 56 additions & 0 deletions custom_worker_tuner/worker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from __future__ import annotations

import asyncio
import logging

from temporalio.client import Client
from temporalio.envconfig import ClientConfig
from temporalio.worker import FixedSizeSlotSupplier, Worker, WorkerTuner

from custom_worker_tuner.db_pool import FakeDatabaseConnectionPool
from custom_worker_tuner.shared import TASK_QUEUE, RunBatch, do_work
from custom_worker_tuner.supplier import PoolSlotSupplier

# CAPACITY sizes the fake connection pool that gates activity slots.
CAPACITY = 10 # number of pool connections (and concurrent activities)
# Name of a logging level; main() falls back to INFO if the name is unknown.
LOG_LEVEL = "INFO"


async def main() -> None:
    """Configure logging, connect, and run a worker whose activity slots are pool-gated."""
    log_level = getattr(logging, LOG_LEVEL.upper(), logging.INFO)
    logging.basicConfig(
        level=log_level,
        format="%(asctime)s.%(msecs)03d %(message)s",
        datefmt="%H:%M:%S",
    )

    # Use the loaded client config, falling back to a local server address
    # when no target host is set.
    connect_kwargs = ClientConfig.load_client_connect_config()
    connect_kwargs.setdefault("target_host", "localhost:7233")
    client = await Client.connect(**connect_kwargs)

    # Only the activity supplier is custom; the other three are fixed-size.
    db_pool = FakeDatabaseConnectionPool(allowed_connections=CAPACITY, name="db")
    activity_slots = PoolSlotSupplier(db_pool)
    worker = Worker(
        client,
        task_queue=TASK_QUEUE,
        workflows=[RunBatch],
        activities=[do_work],
        tuner=WorkerTuner.create_composite(
            workflow_supplier=FixedSizeSlotSupplier(100),
            activity_supplier=activity_slots,
            local_activity_supplier=FixedSizeSlotSupplier(100),
            nexus_supplier=FixedSizeSlotSupplier(100),
        ),
    )

    # Banner and column header for the slot log lines that follow.
    banner = (
        f"\nworker started — capacity={CAPACITY}\n",
        "TIME EVENT COUNT DETAIL",
        "─" * 60,
    )
    for line in banner:
        print(line)
    await worker.run()


# Script entry point: run the worker until it is interrupted.
if __name__ == "__main__":
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        # Swallow Ctrl-C so stopping the worker does not print a traceback.
        pass
Loading