Skip to content

Commit 1ce1294

Browse files
Address review feedback
1 parent 2c05e40 commit 1ce1294

5 files changed

Lines changed: 54 additions & 62 deletions

File tree

custom_worker_tuner/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ This sample gates on a fake DB pool: the worker only polls for a new
55
activity when the pool has a free connection.
66

77
## What this sample is
8-
downstream.py - A static-capacity counter. Pretends to be a DB pool. Two methods: increment() (claim a slot, returns False if full), decrement() (release)
8+
db_pool.py - A static-capacity counter. Pretends to be a DB pool. Two methods: increment() (claim a slot, returns False if full), decrement() (release)
99
supplier.py - The custom slot supplier. On reserve_slot it polls downstream.increment() until it succeeds. On release_slot it calls downstream.decrement()
1010
shared.py - A RunBatch workflow that runs N do_work activities in parallel. The activity just sleeps
1111
worker.py - Wires Downstream + DownstreamAwareSupplier into a WorkerTuner

custom_worker_tuner/db_pool.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
from __future__ import annotations
2+
3+
import logging
4+
import threading
5+
6+
logger = logging.getLogger(__name__)
7+
8+
9+
class FakeDatabaseConnectionPool:
10+
"""Pretend connection pool with a fixed capacity, backed by a Semaphore."""
11+
12+
def __init__(self, allowed_connections: int, name: str = "db") -> None:
13+
self.allowed_connections = allowed_connections
14+
self.name = name
15+
self._connection_pool = threading.BoundedSemaphore(allowed_connections)
16+
logger.info(
17+
"FakeDatabaseConnectionPool ready: name=%s allowed_connections=%d",
18+
name,
19+
allowed_connections,
20+
)
21+
22+
def acquire(self, blocking: bool = True) -> bool:
23+
"""Claim a connection. When blocking, waits until one is free."""
24+
return self._connection_pool.acquire(blocking=blocking)
25+
26+
def release(self) -> None:
27+
"""Return a connection to the pool."""
28+
self._connection_pool.release()
29+
30+
@property
31+
def in_use(self) -> int:
32+
"""Derived from the semaphore — single source of truth."""
33+
return self.allowed_connections - self._connection_pool._value

custom_worker_tuner/downstream.py

Lines changed: 0 additions & 34 deletions
This file was deleted.

custom_worker_tuner/supplier.py

Lines changed: 13 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
SlotReserveContext,
1313
)
1414

15-
from custom_worker_tuner.downstream import Downstream
15+
from custom_worker_tuner.db_pool import FakeDatabaseConnectionPool
1616

1717
logger = logging.getLogger(__name__)
1818

@@ -27,44 +27,38 @@ def __init__(self, slot_id: int) -> None:
2727
self.slot_id = slot_id
2828

2929

30-
class DownstreamAwareSupplier(CustomSlotSupplier):
31-
def __init__(self, downstream: Downstream, poll_interval_ms: int = 100) -> None:
32-
self.downstream = downstream
33-
self.poll_interval_ms = poll_interval_ms
34-
logger.info(
35-
"DownstreamAwareSupplier ready: downstream=%s poll_interval_ms=%d",
36-
downstream.name,
37-
poll_interval_ms,
38-
)
30+
class PoolSlotSupplier(CustomSlotSupplier):
31+
"""Hands out slots only when the backing pool has a free connection."""
32+
33+
def __init__(self, connection_pool: FakeDatabaseConnectionPool) -> None:
34+
self.connection_pool = connection_pool
35+
logger.info("PoolSlotSupplier ready: connection_pool=%s", connection_pool.name)
3936

4037
async def reserve_slot(self, ctx: SlotReserveContext) -> SlotPermit:
41-
"""block downstream until it has capacity to get incremented and then grant a slot."""
38+
"""Block until the pool has capacity, then grant a slot."""
39+
await asyncio.to_thread(self.connection_pool.acquire)
4240
slot_id = next(_slot_id_gen)
43-
while not self.downstream.increment():
44-
await asyncio.sleep(self.poll_interval_ms / 1000.0)
4541
self._log("reserve", slot_id, "ready to poll")
4642
return _Permit(slot_id)
4743

4844
def try_reserve_slot(self, ctx: SlotReserveContext) -> SlotPermit | None:
49-
"""Eager path: can i run this activity right now?"""
50-
if self.downstream.increment():
45+
"""Eager path: try to claim a slot without blocking."""
46+
if self.connection_pool.acquire(blocking=False):
5147
slot_id = next(_slot_id_gen)
5248
self._log("reserve", slot_id, "eager dispatch")
5349
return _Permit(slot_id)
5450
return None
5551

5652
def mark_slot_used(self, ctx: SlotMarkUsedContext) -> None:
57-
"""A task arrived for a reserved slot"""
5853
slot_id = getattr(ctx.permit, "slot_id", "?")
5954
self._log("used", slot_id, "activity running")
6055

6156
def release_slot(self, ctx: SlotReleaseContext) -> None:
62-
"""Return the slot to the downstream."""
6357
slot_id = getattr(ctx.permit, "slot_id", "?")
6458
detail = "no task arrived" if ctx.slot_info is None else "activity done"
65-
self.downstream.decrement()
59+
self.connection_pool.release()
6660
self._log("release", slot_id, detail)
6761

6862
def _log(self, event: str, slot_id, note: str) -> None:
69-
count = f"{self.downstream.currently_connected}/{self.downstream.allowed_connections}"
63+
count = f"{self.connection_pool.in_use}/{self.connection_pool.allowed_connections}"
7064
logger.info(f"{event:<8} {count:>5} {note}")

custom_worker_tuner/worker.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,12 @@
77
from temporalio.envconfig import ClientConfig
88
from temporalio.worker import FixedSizeSlotSupplier, Worker, WorkerTuner
99

10-
from custom_worker_tuner.downstream import Downstream
10+
from custom_worker_tuner.db_pool import FakeDatabaseConnectionPool
1111
from custom_worker_tuner.shared import TASK_QUEUE, RunBatch, do_work
12-
from custom_worker_tuner.supplier import DownstreamAwareSupplier
12+
from custom_worker_tuner.supplier import PoolSlotSupplier
1313

14-
CAPACITY = 10 # number of connections allowed at a time
15-
POLL_INTERVAL_MS = 500
16-
LOG_LEVEL = "INFO" # flip to "DEBUG" to see every increment/decrement
14+
CAPACITY = 10 # number of pool connections (and concurrent activities)
15+
LOG_LEVEL = "INFO"
1716

1817

1918
async def main() -> None:
@@ -27,8 +26,8 @@ async def main() -> None:
2726
config.setdefault("target_host", "localhost:7233")
2827
client = await Client.connect(**config)
2928

30-
downstream = Downstream(allowed_connections=CAPACITY, name="db")
31-
supplier = DownstreamAwareSupplier(downstream, poll_interval_ms=POLL_INTERVAL_MS)
29+
pool = FakeDatabaseConnectionPool(allowed_connections=CAPACITY, name="db")
30+
supplier = PoolSlotSupplier(pool)
3231
tuner = WorkerTuner.create_composite(
3332
workflow_supplier=FixedSizeSlotSupplier(100),
3433
activity_supplier=supplier,
@@ -44,7 +43,7 @@ async def main() -> None:
4443
tuner=tuner,
4544
)
4645

47-
print(f"\nworker started — capacity={CAPACITY}, poll={POLL_INTERVAL_MS}ms\n")
46+
print(f"\nworker started — capacity={CAPACITY}\n")
4847
print("TIME EVENT COUNT DETAIL")
4948
print("─" * 60)
5049
await worker.run()

0 commit comments

Comments
 (0)