sample for custom worker tuner #314
Open
deepika-awasthi wants to merge 3 commits into main from deepika/custom-worker-tuner
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,72 @@ | ||
| # Custom Worker Tuner | ||
|
|
||
| A `CustomSlotSupplier` lets you gate slot grants on whatever condition you want. | ||
| This sample gates on a fake DB pool: the worker only polls for a new | ||
| activity when the pool has a free connection. | ||
|
|
||
| **Note:** This sample is illustrative only. It shouldn't be used for production grade use-cases. | ||
|
|
||
| ## What this sample is | ||
| - `db_pool.py`: A fixed-capacity fake pool backed by a `BoundedSemaphore`. Two methods: `acquire(blocking=True)` (claims a slot; returns `False` if the pool is full and the call is non-blocking) and `release()` (returns a slot). | ||
| - `supplier.py`: The custom slot supplier. `reserve_slot` blocks on `connection_pool.acquire()` until a slot is free; `try_reserve_slot` does the same without blocking; `release_slot` calls `connection_pool.release()`. | ||
| - `shared.py`: A `RunBatch` workflow that runs N `do_work` activities in parallel. The activity just sleeps. | ||
| - `worker.py`: Wires `FakeDatabaseConnectionPool` and `PoolSlotSupplier` into a `WorkerTuner`. | ||
| - `starter.py`: Drives load. | ||
|
|
||
| The flow: | ||
|
|
||
| When the pool is at capacity, `reserve_slot` blocks until a | ||
| connection frees up. The excess work piles up on the Temporal server, not | ||
| inside the worker. | ||
|
|
||
| ## Run | ||
|
|
||
| In three terminals from `samples-python/`: | ||
|
|
||
| ```bash | ||
| temporal server start-dev # terminal 1 | ||
| uv run custom_worker_tuner/worker.py # terminal 2 | ||
| uv run custom_worker_tuner/starter.py # terminal 3 | ||
| ``` | ||
|
|
||
| ## What you'll see | ||
|
|
||
| The worker prints one line per slot lifecycle event: | ||
|
|
||
| ``` | ||
|
|
||
| TIME EVENT SLOT COUNT DETAIL | ||
| ──────────────────────────────────────────────────────────── | ||
| 10:31:49.842 reserve #1 1/10 ready to poll | ||
| 10:31:49.842 reserve #2 2/10 ready to poll | ||
| 10:31:49.843 reserve #3 3/10 ready to poll | ||
| 10:31:49.843 reserve #4 4/10 ready to poll | ||
| 10:31:49.843 reserve #5 5/10 ready to poll | ||
| 10:31:49.843 reserve #6 6/10 ready to poll | ||
| 10:31:56.763 reserve #7 7/10 eager dispatch | ||
| 10:31:56.763 reserve #8 8/10 eager dispatch | ||
| 10:31:56.764 reserve #9 9/10 eager dispatch | ||
| 10:31:56.766 reserve #10 10/10 eager dispatch | ||
| 10:31:56.767 release #7 9/10 no task arrived | ||
| 10:31:56.768 release #8 8/10 no task arrived | ||
| 10:31:56.768 release #9 7/10 no task arrived | ||
| 10:31:56.768 reserve #11 8/10 eager dispatch | ||
| 10:31:56.768 reserve #12 9/10 eager dispatch | ||
| 10:31:56.768 reserve #13 10/10 eager dispatch | ||
| 10:31:56.771 used #1 10/10 activity running | ||
| 10:31:56.771 release #10 9/10 no task arrived | ||
| ``` | ||
|
|
||
| Under load, with more activities than capacity, COUNT pins at | ||
| 10/10: the supplier is refusing to poll past the gate. | ||
| We chose 10 because the Python SDK uses 5 pollers by default. | ||
|
|
||
| ## Knobs | ||
|
|
||
| worker.py: | ||
|
|
||
| CAPACITY — pool capacity (the gate) | ||
|
|
||
| starter.py: | ||
|
|
||
| WORKFLOWS, ACTIVITIES_PER_WORKFLOW, SECONDS_PER_ACTIVITY — amount and duration of load |
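The gating flow the README describes can be tried without a Temporal server. The following is a minimal sketch, not part of the PR: `run_gated`, the small `CAPACITY`, and the in-flight counter are hypothetical names used only for illustration. It mirrors the sample's approach of waiting on a `threading.BoundedSemaphore` through `asyncio.to_thread` and records the peak number of fake activities running at once.

```python
import asyncio
import threading

CAPACITY = 3  # hypothetical; smaller than the sample's 10 to keep the run short


async def run_gated(tasks: int, capacity: int, seconds: float = 0.01) -> int:
    """Run `tasks` fake activities behind a capacity gate; return peak concurrency."""
    gate = threading.BoundedSemaphore(capacity)
    in_flight = 0
    peak = 0

    async def one_task() -> None:
        nonlocal in_flight, peak
        # Like reserve_slot in the sample: wait in a helper thread until the gate opens.
        await asyncio.to_thread(gate.acquire)
        try:
            in_flight += 1
            peak = max(peak, in_flight)
            await asyncio.sleep(seconds)  # stands in for the activity body
        finally:
            in_flight -= 1
            gate.release()  # like release_slot in the sample

    await asyncio.gather(*(one_task() for _ in range(tasks)))
    return peak


peak = asyncio.run(run_gated(tasks=12, capacity=CAPACITY))
print(f"peak in-flight: {peak}/{CAPACITY}")
```

The peak never exceeds the gate's capacity, which is the same effect as COUNT pinning at 10/10 in the worker log.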
Empty file.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,33 @@ | ||
| from __future__ import annotations | ||
|
|
||
| import logging | ||
| import threading | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
|
|
||
| class FakeDatabaseConnectionPool: | ||
| """Pretend connection pool with a fixed capacity, backed by a Semaphore.""" | ||
|
|
||
| def __init__(self, allowed_connections: int, name: str = "db") -> None: | ||
| self.allowed_connections = allowed_connections | ||
| self.name = name | ||
| self._connection_pool = threading.BoundedSemaphore(allowed_connections) | ||
| logger.info( | ||
| "FakeDatabaseConnectionPool ready: name=%s allowed_connections=%d", | ||
| name, | ||
| allowed_connections, | ||
| ) | ||
|
|
||
| def acquire(self, blocking: bool = True) -> bool: | ||
| """Claim a connection. When blocking, waits until one is free.""" | ||
| return self._connection_pool.acquire(blocking=blocking) | ||
|
|
||
| def release(self) -> None: | ||
| """Return a connection to the pool.""" | ||
| self._connection_pool.release() | ||
|
|
||
| @property | ||
| def in_use(self) -> int: | ||
| """In-use count, derived from the semaphore's private `_value` counter (a CPython implementation detail).""" | ||
| return self.allowed_connections - self._connection_pool._value | ||
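The `in_use` property above reads the semaphore's private `_value` attribute. One alternative, sketched here under the assumption that an explicit counter is acceptable, is to track checkouts directly. `CountingPool` is a hypothetical name and is not part of the PR.

```python
import threading


class CountingPool:
    """Hypothetical variant of the fake pool that counts checkouts itself."""

    def __init__(self, allowed_connections: int) -> None:
        self.allowed_connections = allowed_connections
        self._sem = threading.BoundedSemaphore(allowed_connections)
        self._lock = threading.Lock()
        self._in_use = 0

    def acquire(self, blocking: bool = True) -> bool:
        """Claim a connection; returns False if full and non-blocking."""
        acquired = self._sem.acquire(blocking=blocking)
        if acquired:
            with self._lock:
                self._in_use += 1
        return acquired

    def release(self) -> None:
        """Return a connection; decrement first so in_use never exceeds capacity."""
        with self._lock:
            if self._in_use == 0:
                raise ValueError("release() without a matching acquire()")
            self._in_use -= 1
        self._sem.release()

    @property
    def in_use(self) -> int:
        with self._lock:
            return self._in_use


pool = CountingPool(allowed_connections=2)
pool.acquire()
pool.acquire()
print(pool.acquire(blocking=False), pool.in_use)  # False 2
pool.release()
print(pool.in_use)  # 1
```

The counter costs one extra lock, but it avoids depending on an attribute that the `threading` documentation does not expose.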
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,39 @@ | ||
| from __future__ import annotations | ||
|
|
||
| import asyncio | ||
| from dataclasses import dataclass | ||
| from datetime import timedelta | ||
|
|
||
| from temporalio import activity, workflow | ||
|
|
||
| TASK_QUEUE = "custom-worker-tuner" | ||
|
|
||
|
|
||
| @dataclass | ||
| class BatchInput: | ||
| activities: int | ||
| seconds: float | ||
|
|
||
|
|
||
| @activity.defn | ||
| async def do_work(seconds: float) -> None: | ||
| """Sleep, simulating an I/O-bound activity.""" | ||
| await asyncio.sleep(seconds) | ||
|
|
||
|
|
||
| @workflow.defn | ||
| class RunBatch: | ||
| """Runs N do_work activities in parallel.""" | ||
|
|
||
| @workflow.run | ||
| async def run(self, inp: BatchInput) -> None: | ||
| await asyncio.gather( | ||
| *( | ||
| workflow.execute_activity( | ||
| do_work, | ||
| inp.seconds, | ||
| start_to_close_timeout=timedelta(minutes=2), | ||
| ) | ||
| for _ in range(inp.activities) | ||
| ) | ||
| ) |
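The fan-out in `RunBatch` relies on `asyncio.gather` to keep all N activities in flight at once. A plain-asyncio sketch of that pattern follows; it uses no Temporal APIs, and `run_batch` is a hypothetical helper. It shows that N concurrent sleeps take roughly one sleep's duration, not N of them.

```python
import asyncio
import time


async def do_work(seconds: float) -> None:
    """Sleep, standing in for an I/O-bound activity."""
    await asyncio.sleep(seconds)


async def run_batch(activities: int, seconds: float) -> float:
    """Start all sleeps together and return the elapsed wall time."""
    t0 = time.perf_counter()
    await asyncio.gather(*(do_work(seconds) for _ in range(activities)))
    return time.perf_counter() - t0


elapsed = asyncio.run(run_batch(activities=20, seconds=0.05))
print(f"20 sleeps of 0.05s took {elapsed:.2f}s")
```

In the real workflow the slot supplier decides how many of those activities actually run at a time; the workflow code itself only schedules them.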
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,49 @@ | ||
| from __future__ import annotations | ||
|
|
||
| import asyncio | ||
| import time | ||
| import uuid | ||
|
|
||
| from temporalio.client import Client | ||
| from temporalio.envconfig import ClientConfig | ||
|
|
||
| from custom_worker_tuner.shared import TASK_QUEUE, BatchInput, RunBatch | ||
|
|
||
| # Tweak these to push more or less load. | ||
| WORKFLOWS = 10 | ||
| ACTIVITIES_PER_WORKFLOW = 20 | ||
| SECONDS_PER_ACTIVITY = 2.0 | ||
|
|
||
|
|
||
| async def main() -> None: | ||
| config = ClientConfig.load_client_connect_config() | ||
| config.setdefault("target_host", "localhost:7233") | ||
| client = await Client.connect(**config) | ||
| run_id = uuid.uuid4().hex[:8] | ||
| inp = BatchInput(activities=ACTIVITIES_PER_WORKFLOW, seconds=SECONDS_PER_ACTIVITY) | ||
| total = WORKFLOWS * ACTIVITIES_PER_WORKFLOW | ||
|
|
||
| print( | ||
| f"starting {WORKFLOWS} workflows × {ACTIVITIES_PER_WORKFLOW} activities × {SECONDS_PER_ACTIVITY}s" | ||
| ) | ||
| t0 = time.perf_counter() | ||
|
|
||
| handles = await asyncio.gather( | ||
| *( | ||
| client.start_workflow( | ||
| RunBatch.run, | ||
| inp, | ||
| id=f"batch-{run_id}-{i}", | ||
| task_queue=TASK_QUEUE, | ||
| ) | ||
| for i in range(WORKFLOWS) | ||
| ) | ||
| ) | ||
| await asyncio.gather(*(h.result() for h in handles)) | ||
|
|
||
| wall = time.perf_counter() - t0 | ||
| print(f"done in {wall:.1f}s ({total} activities, {total / wall:.0f}/s)") | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| asyncio.run(main()) |
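With the starter's defaults and a pool capacity of 10, a rough lower bound on the wall time printed at the end can be worked out by hand. The sketch below assumes a single worker, activities of identical duration, and no polling or scheduling overhead; `min_wall_seconds` and `max_throughput` are hypothetical helpers.

```python
import math


def min_wall_seconds(
    workflows: int,
    activities_per_workflow: int,
    seconds_per_activity: float,
    capacity: int,
) -> float:
    """Lower bound: activities run in waves of at most `capacity`."""
    total = workflows * activities_per_workflow
    waves = math.ceil(total / capacity)
    return waves * seconds_per_activity


def max_throughput(capacity: int, seconds_per_activity: float) -> float:
    """Upper bound on activities per second through the gate."""
    return capacity / seconds_per_activity


print(min_wall_seconds(10, 20, 2.0, 10))  # 40.0
print(max_throughput(10, 2.0))  # 5.0
```

A measured run should come in at or above 40 seconds and at or below 5 activities per second; a figure far outside that range would suggest the gate is not the limiting factor.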
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,64 @@ | ||
| from __future__ import annotations | ||
|
|
||
| import asyncio | ||
| import itertools | ||
| import logging | ||
|
|
||
| from temporalio.worker import ( | ||
| CustomSlotSupplier, | ||
| SlotMarkUsedContext, | ||
| SlotPermit, | ||
| SlotReleaseContext, | ||
| SlotReserveContext, | ||
| ) | ||
|
|
||
| from custom_worker_tuner.db_pool import FakeDatabaseConnectionPool | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
| _slot_id_gen = itertools.count(1) | ||
|
|
||
|
|
||
| class _Permit(SlotPermit): | ||
| """SlotPermit subclass that just carries a sequential id for logs.""" | ||
|
|
||
| def __init__(self, slot_id: int) -> None: | ||
| super().__init__() | ||
| self.slot_id = slot_id | ||
|
|
||
|
|
||
| class PoolSlotSupplier(CustomSlotSupplier): | ||
| """Hands out slots only when the backing pool has a free connection.""" | ||
|
|
||
| def __init__(self, connection_pool: FakeDatabaseConnectionPool) -> None: | ||
| self.connection_pool = connection_pool | ||
| logger.info("PoolSlotSupplier ready: connection_pool=%s", connection_pool.name) | ||
|
|
||
| async def reserve_slot(self, ctx: SlotReserveContext) -> SlotPermit: | ||
| """Block until the pool has capacity, then grant a slot.""" | ||
| await asyncio.to_thread(self.connection_pool.acquire) | ||
| slot_id = next(_slot_id_gen) | ||
| self._log("reserve", slot_id, "ready to poll") | ||
| return _Permit(slot_id) | ||
|
|
||
| def try_reserve_slot(self, ctx: SlotReserveContext) -> SlotPermit | None: | ||
| """Eager path: try to claim a slot without blocking.""" | ||
| if self.connection_pool.acquire(blocking=False): | ||
| slot_id = next(_slot_id_gen) | ||
| self._log("reserve", slot_id, "eager dispatch") | ||
| return _Permit(slot_id) | ||
| return None | ||
|
|
||
| def mark_slot_used(self, ctx: SlotMarkUsedContext) -> None: | ||
| slot_id = getattr(ctx.permit, "slot_id", "?") | ||
| self._log("used", slot_id, "activity running") | ||
|
|
||
| def release_slot(self, ctx: SlotReleaseContext) -> None: | ||
| slot_id = getattr(ctx.permit, "slot_id", "?") | ||
| detail = "no task arrived" if ctx.slot_info is None else "activity done" | ||
| self.connection_pool.release() | ||
| self._log("release", slot_id, detail) | ||
|
|
||
| def _log(self, event: str, slot_id, note: str) -> None: | ||
| count = f"{self.connection_pool.in_use}/{self.connection_pool.allowed_connections}" | ||
| logger.info(f"{event:<8} #{slot_id:<4} {count:>5} {note}") | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,56 @@ | ||
| from __future__ import annotations | ||
|
|
||
| import asyncio | ||
| import logging | ||
|
|
||
| from temporalio.client import Client | ||
| from temporalio.envconfig import ClientConfig | ||
| from temporalio.worker import FixedSizeSlotSupplier, Worker, WorkerTuner | ||
|
|
||
| from custom_worker_tuner.db_pool import FakeDatabaseConnectionPool | ||
| from custom_worker_tuner.shared import TASK_QUEUE, RunBatch, do_work | ||
| from custom_worker_tuner.supplier import PoolSlotSupplier | ||
|
|
||
| CAPACITY = 10 # number of pool connections (and concurrent activities) | ||
| LOG_LEVEL = "INFO" | ||
|
|
||
|
|
||
| async def main() -> None: | ||
| logging.basicConfig( | ||
| level=getattr(logging, LOG_LEVEL.upper(), logging.INFO), | ||
| format="%(asctime)s.%(msecs)03d %(message)s", | ||
| datefmt="%H:%M:%S", | ||
| ) | ||
|
|
||
| config = ClientConfig.load_client_connect_config() | ||
| config.setdefault("target_host", "localhost:7233") | ||
| client = await Client.connect(**config) | ||
|
|
||
| pool = FakeDatabaseConnectionPool(allowed_connections=CAPACITY, name="db") | ||
| supplier = PoolSlotSupplier(pool) | ||
| tuner = WorkerTuner.create_composite( | ||
| workflow_supplier=FixedSizeSlotSupplier(100), | ||
| activity_supplier=supplier, | ||
| local_activity_supplier=FixedSizeSlotSupplier(100), | ||
| nexus_supplier=FixedSizeSlotSupplier(100), | ||
| ) | ||
|
|
||
| worker = Worker( | ||
| client, | ||
| task_queue=TASK_QUEUE, | ||
| workflows=[RunBatch], | ||
| activities=[do_work], | ||
| tuner=tuner, | ||
| ) | ||
|
|
||
| print(f"\nworker started — capacity={CAPACITY}\n") | ||
| print("TIME         EVENT    SLOT  COUNT DETAIL") | ||
| print("─" * 60) | ||
| await worker.run() | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| try: | ||
| asyncio.run(main()) | ||
| except KeyboardInterrupt: | ||
| pass |
You want to use `asyncio.Semaphore`... that's why you have to do the nasty `asyncio.to_thread` thing.
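One way to read this suggestion is a pool that waits on the event loop instead of in a helper thread. The sketch below is hypothetical and not from the PR. Because `asyncio.Semaphore.acquire()` is a coroutine with no synchronous non-blocking form, it swaps in a plain counter plus an `asyncio.Event` so that a synchronous `try_acquire` is also possible. It assumes every call happens on the event-loop thread, which has not been verified for the slot supplier callbacks.

```python
import asyncio


class AsyncFakePool:
    """Hypothetical asyncio-native pool; single-threaded use only."""

    def __init__(self, allowed_connections: int) -> None:
        self.allowed_connections = allowed_connections
        self.in_use = 0
        self._freed = asyncio.Event()

    def try_acquire(self) -> bool:
        """Non-blocking claim, usable from a synchronous eager path."""
        if self.in_use >= self.allowed_connections:
            return False
        self.in_use += 1
        return True

    async def acquire(self) -> None:
        """Wait on the event loop until a connection is free."""
        while not self.try_acquire():
            self._freed.clear()
            await self._freed.wait()  # woken by release(); then re-check

    def release(self) -> None:
        """Return a connection and wake any waiters so they can re-check."""
        self.in_use -= 1
        self._freed.set()


async def demo() -> list:
    pool = AsyncFakePool(1)
    events = []
    pool.try_acquire()  # take the only connection
    events.append(f"try while full: {pool.try_acquire()}")
    waiter = asyncio.create_task(pool.acquire())
    await asyncio.sleep(0)  # let the waiter start and park
    events.append(f"waiter done before release: {waiter.done()}")
    pool.release()
    await waiter
    events.append(f"in use after handoff: {pool.in_use}")
    return events


events = asyncio.run(demo())
print(events)
```

With a pool like this, `reserve_slot` could await `acquire()` directly. Whether that is safe depends on which thread the SDK uses for `try_reserve_slot` and `release_slot`, so the thread-based version in the PR may still be the more conservative choice.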