Skip to content

Commit 58ef6bb

Browse files
committed
feat: unify signal and stream writer
1 parent 4d699af commit 58ef6bb

File tree

4 files changed

+103
-58
lines changed

4 files changed

+103
-58
lines changed

src/duron/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
from duron._core.invoke import invoke as invoke
55
from duron._core.signal import Signal as Signal
66
from duron._core.signal import SignalInterrupt as SignalInterrupt
7-
from duron._core.signal import SignalWriter as SignalWriter
87
from duron._core.stream import Stream as Stream
98
from duron._core.stream import StreamClosed as StreamClosed
109
from duron._core.stream import StreamOp as StreamOp

src/duron/_core/context.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
from collections.abc import Awaitable, Callable, Coroutine, Mapping
3232
from contextlib import AbstractAsyncContextManager
3333

34-
from duron._core.signal import Signal, SignalWriter
34+
from duron._core.signal import Signal
3535
from duron._core.stream import Stream, StreamWriter
3636
from duron._loop import EventLoop
3737
from duron.typing import TypeHint
@@ -189,7 +189,7 @@ async def create_signal(
189189
*,
190190
name: str | None = None,
191191
labels: Mapping[str, str] | None = None,
192-
) -> tuple[Signal[_T], SignalWriter[_T]]:
192+
) -> tuple[Signal[_T], StreamWriter[_T]]:
193193
"""Create a new signal within the context.
194194
195195
Args:

src/duron/_core/signal.py

Lines changed: 51 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,9 @@
22

33
import asyncio
44
import sys
5-
from asyncio.exceptions import CancelledError
65
from collections import deque
7-
from typing import TYPE_CHECKING, Generic, cast
8-
from typing_extensions import Any, Protocol, TypeVar, final
6+
from typing import TYPE_CHECKING, Final, Generic, cast
7+
from typing_extensions import Any, TypeVar, final, override
98

109
from duron._core.ops import Barrier, StreamClose, StreamCreate, StreamEmit, create_op
1110
from duron._loop import wrap_future
@@ -14,10 +13,11 @@
1413
from types import TracebackType
1514

1615
from duron._core.ops import OpAnnotations
16+
from duron._core.stream import StreamWriter
1717
from duron._loop import EventLoop
1818
from duron.typing._hint import TypeHint
1919

20-
_In = TypeVar("_In", contravariant=True) # noqa: PLC0105
20+
_InT = TypeVar("_InT", contravariant=True) # noqa: PLC0105
2121

2222

2323
class SignalInterrupt(Exception): # noqa: N818
@@ -27,53 +27,47 @@ class SignalInterrupt(Exception): # noqa: N818
2727
value: The value passed to the signal trigger that caused the interrupt.
2828
"""
2929

30-
def __init__(self, *args: object, value: object) -> None:
31-
super().__init__(*args)
32-
self.value: object = value
30+
def __init__(self, value: object) -> None:
31+
super().__init__()
32+
self.value = value
3333

34-
35-
class SignalWriter(Protocol, Generic[_In]):
36-
"""Protocol for writing values to a signal to interrupt operations."""
37-
38-
async def trigger(self, value: _In, /) -> None:
39-
"""Trigger the signal with a value, interrupting active operations.
40-
41-
Args:
42-
value: The value to send with the interrupt.
43-
"""
44-
...
45-
46-
async def close(self, /) -> None:
47-
"""Close the signal stream, preventing further triggers."""
48-
...
34+
@override
35+
def __repr__(self) -> str:
36+
return f"SignalInterrupt(value={self.value!r})"
4937

5038

5139
@final
52-
class _Writer(Generic[_In]):
40+
class SignalWriter(Generic[_InT]):
41+
"""Object for writing values to a signal to interrupt operations."""
42+
5343
__slots__ = ("_loop", "_stream_id")
5444

5545
def __init__(self, stream_id: str, loop: EventLoop) -> None:
5646
self._stream_id = stream_id
5747
self._loop = loop
5848

59-
async def trigger(self, value: _In, /) -> None:
49+
async def send(self, value: _InT) -> None:
50+
"""Trigger the signal with a value, interrupting active operations.
51+
52+
Args:
53+
value: The value to send with the interrupt.
54+
"""
6055
await wrap_future(
6156
create_op(self._loop, StreamEmit(stream_id=self._stream_id, value=value))
6257
)
6358

64-
async def close(self, /) -> None:
59+
async def close(self, exc: Exception | None = None) -> None:
60+
"""Close the signal stream, preventing further triggers."""
6561
await wrap_future(
66-
create_op(
67-
self._loop, StreamClose(stream_id=self._stream_id, exception=None)
68-
)
62+
create_op(self._loop, StreamClose(stream_id=self._stream_id, exception=exc))
6963
)
7064

7165

72-
_SENTINAL = object()
66+
_SIGNAL_TRIGGER: Final = object()
7367

7468

7569
@final
76-
class Signal(Generic[_In]):
70+
class Signal(Generic[_InT]):
7771
"""Signal context manager for interruptible operations.
7872
7973
Signal provides a mechanism for interrupting in-progress operations. When used
@@ -91,9 +85,9 @@ class Signal(Generic[_In]):
9185

9286
def __init__(self, loop: EventLoop) -> None:
9387
self._loop = loop
94-
# task -> [offset, refcnt]
88+
# task -> [offset, stack depth]
9589
self._tasks: dict[asyncio.Task[Any], tuple[int, int]] = {}
96-
self._trigger: deque[tuple[int, _In]] = deque()
90+
self._trigger: deque[tuple[int, _InT]] = deque()
9791

9892
async def __aenter__(self) -> None:
9993
task = asyncio.current_task()
@@ -104,8 +98,8 @@ async def __aenter__(self) -> None:
10498
for toffset, value in self._trigger:
10599
if toffset > offset:
106100
raise SignalInterrupt(value=value)
107-
_, refcnt = self._tasks.get(task, (0, 0))
108-
self._tasks[task] = (offset, refcnt + 1)
101+
_, depth = self._tasks.get(task, (0, -1))
102+
self._tasks[task] = (offset, depth + 1)
109103
self._flush()
110104

111105
async def __aexit__(
@@ -117,43 +111,49 @@ async def __aexit__(
117111
task = asyncio.current_task()
118112
if task is None:
119113
return
120-
offset_start, refcnt = self._tasks.pop(task)
121114
offset_end = await create_op(self._loop, Barrier())
122-
if refcnt > 1:
123-
self._tasks[task] = (offset_end, refcnt - 1)
115+
116+
offset_start, depth = self._tasks.pop(task)
117+
if depth > 0:
118+
self._tasks[task] = (offset_end, depth - 1)
124119
for toffset, value in self._trigger:
125-
if offset_start < toffset < offset_end:
126-
if sys.version_info >= (3, 11) and exc_type is CancelledError:
127-
assert exc_value # noqa: S101
128-
assert exc_value.args[0] is _SENTINAL # noqa: S101
120+
if (
121+
offset_start < toffset < offset_end
122+
and exc_type is asyncio.CancelledError
123+
and (args := cast("asyncio.CancelledError", exc_value).args)
124+
and args[0] is _SIGNAL_TRIGGER
125+
):
126+
if sys.version_info >= (3, 11):
129127
_ = task.uncancel()
128+
self._flush()
130129
raise SignalInterrupt(value=value)
131130

132-
def on_next(self, offset: int, value: _In) -> None:
131+
def on_next(self, offset: int, value: _InT) -> None:
133132
self._trigger.append((offset, value))
134-
for t, (toffset, _refcnt) in self._tasks.items():
133+
for t, (toffset, _depth) in self._tasks.items():
135134
if toffset < offset:
136-
_ = self._loop.call_soon(t.cancel, _SENTINAL)
135+
_ = self._loop.call_soon(t.cancel, _SIGNAL_TRIGGER)
137136

138137
def on_close(self, _offset: int, _exc: Exception | None) -> None:
139138
pass
140139

141140
def _flush(self) -> None:
142-
assert len(self._tasks) > 0 # noqa: S101
143-
min_offset = min(offset for offset, _ in self._tasks.values())
141+
if not self._tasks:
142+
self._trigger.clear()
143+
return
144+
min_offset = min((offset for offset, _ in self._tasks.values()))
144145
while self._trigger and self._trigger[0][0] < min_offset:
145146
_ = self._trigger.popleft()
146147

147148

148149
async def create_signal(
149-
loop: EventLoop, dtype: TypeHint[_In], annotations: OpAnnotations
150-
) -> tuple[Signal[_In], SignalWriter[_In]]:
151-
assert asyncio.get_running_loop() is loop # noqa: S101
152-
s: Signal[_In] = Signal(loop)
150+
loop: EventLoop, dtype: TypeHint[_InT], annotations: OpAnnotations
151+
) -> tuple[Signal[_InT], StreamWriter[_InT]]:
152+
s: Signal[_InT] = Signal(loop)
153153
sid = await create_op(
154154
loop,
155155
StreamCreate(
156156
dtype=dtype, observer=cast("Signal[object]", s), annotations=annotations
157157
),
158158
)
159-
return (s, _Writer(sid, loop))
159+
return (s, SignalWriter(sid, loop))

tests/test_signal.py

Lines changed: 50 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
from __future__ import annotations
22

33
import asyncio
4+
import contextlib
45
from typing import cast
56

67
import pytest
78

8-
from duron import Context, SignalInterrupt, durable, invoke
9+
from duron import Context, Signal, SignalInterrupt, durable, invoke
910
from duron.contrib.storage import MemoryLogStorage
11+
from duron.typing._hint import Provided
1012

1113

1214
@pytest.mark.asyncio
@@ -15,13 +17,13 @@ async def test_signal() -> None:
1517
async def activity(ctx: Context) -> list[int]:
1618
signal, handle = await ctx.create_signal(int)
1719

18-
await handle.trigger(1)
20+
await handle.send(1)
1921

2022
async def trigger() -> None:
2123
await asyncio.sleep(0.1)
22-
await handle.trigger(2)
24+
await handle.send(2)
2325
await asyncio.sleep(0.1)
24-
await handle.trigger(3)
26+
await handle.send(3)
2527

2628
t = asyncio.create_task(trigger())
2729

@@ -49,3 +51,47 @@ async def trigger() -> None:
4951
async with invoke(activity, log) as t:
5052
await t.resume()
5153
assert await t.wait() == [2, 3]
54+
55+
56+
@pytest.mark.asyncio
57+
async def test_signal_timing() -> None:
58+
@durable()
59+
async def activity(ctx: Context, s: Signal[int] = Provided) -> list[list[int]]:
60+
async def tracker(signal: Signal[int], t: float) -> list[int]:
61+
values: list[int] = []
62+
while True:
63+
await asyncio.sleep(t)
64+
try:
65+
async with signal:
66+
await asyncio.sleep(9999)
67+
except SignalInterrupt as e:
68+
values.append(cast("int", e.value))
69+
if len(values) > 10:
70+
return values
71+
72+
rnd = ctx.random()
73+
return await asyncio.gather(*[
74+
(tracker(s, rnd.random() * 0.01)) for _ in range(4)
75+
])
76+
77+
log = MemoryLogStorage()
78+
async with invoke(activity, log) as t:
79+
signal = t.open_stream("s", "w")
80+
await t.start()
81+
82+
async def push() -> None:
83+
i = 0
84+
while True:
85+
i += 1
86+
await asyncio.sleep(0.001)
87+
await signal.send(i)
88+
89+
pusher = asyncio.create_task(push())
90+
a = await t.wait()
91+
pusher.cancel()
92+
with contextlib.suppress(asyncio.CancelledError):
93+
await pusher
94+
95+
async with invoke(activity, log) as t:
96+
await t.resume()
97+
assert await t.wait() == a

0 commit comments

Comments
 (0)