22
33import asyncio
44import sys
5- from asyncio .exceptions import CancelledError
65from collections import deque
7- from typing import TYPE_CHECKING , Generic , cast
8- from typing_extensions import Any , Protocol , TypeVar , final
6+ from typing import TYPE_CHECKING , Final , Generic , cast
7+ from typing_extensions import Any , TypeVar , final , override
98
109from duron ._core .ops import Barrier , StreamClose , StreamCreate , StreamEmit , create_op
1110from duron ._loop import wrap_future
1413 from types import TracebackType
1514
1615 from duron ._core .ops import OpAnnotations
16+ from duron ._core .stream import StreamWriter
1717 from duron ._loop import EventLoop
1818 from duron .typing ._hint import TypeHint
1919
20- _In = TypeVar ("_In " , contravariant = True ) # noqa: PLC0105
20+ _InT = TypeVar ("_InT " , contravariant = True ) # noqa: PLC0105
2121
2222
2323class SignalInterrupt (Exception ): # noqa: N818
@@ -27,53 +27,47 @@ class SignalInterrupt(Exception): # noqa: N818
2727 value: The value passed to the signal trigger that caused the interrupt.
2828 """
2929
30- def __init__ (self , * args : object , value : object ) -> None :
31- super ().__init__ (* args )
32- self .value : object = value
30+ def __init__ (self , value : object ) -> None :
31+ super ().__init__ ()
32+ self .value = value
3333
34-
35- class SignalWriter (Protocol , Generic [_In ]):
36- """Protocol for writing values to a signal to interrupt operations."""
37-
38- async def trigger (self , value : _In , / ) -> None :
39- """Trigger the signal with a value, interrupting active operations.
40-
41- Args:
42- value: The value to send with the interrupt.
43- """
44- ...
45-
46- async def close (self , / ) -> None :
47- """Close the signal stream, preventing further triggers."""
48- ...
34+ @override
35+ def __repr__ (self ) -> str :
36+ return f"SignalInterrupt(value={ self .value !r} )"
4937
5038
5139@final
52- class _Writer (Generic [_In ]):
40+ class SignalWriter (Generic [_InT ]):
41+ """Object for writing values to a signal to interrupt operations."""
42+
5343 __slots__ = ("_loop" , "_stream_id" )
5444
5545 def __init__ (self , stream_id : str , loop : EventLoop ) -> None :
5646 self ._stream_id = stream_id
5747 self ._loop = loop
5848
59- async def trigger (self , value : _In , / ) -> None :
49+ async def send (self , value : _InT ) -> None :
50+ """Trigger the signal with a value, interrupting active operations.
51+
52+ Args:
53+ value: The value to send with the interrupt.
54+ """
6055 await wrap_future (
6156 create_op (self ._loop , StreamEmit (stream_id = self ._stream_id , value = value ))
6257 )
6358
64- async def close (self , / ) -> None :
59+ async def close (self , exc : Exception | None = None ) -> None :
60+ """Close the signal stream, preventing further triggers."""
6561 await wrap_future (
66- create_op (
67- self ._loop , StreamClose (stream_id = self ._stream_id , exception = None )
68- )
62+ create_op (self ._loop , StreamClose (stream_id = self ._stream_id , exception = exc ))
6963 )
7064
7165
72- _SENTINAL = object ()
66+ _SIGNAL_TRIGGER : Final = object ()
7367
7468
7569@final
76- class Signal (Generic [_In ]):
70+ class Signal (Generic [_InT ]):
7771 """Signal context manager for interruptible operations.
7872
7973 Signal provides a mechanism for interrupting in-progress operations. When used
@@ -91,9 +85,9 @@ class Signal(Generic[_In]):
9185
9286 def __init__ (self , loop : EventLoop ) -> None :
9387 self ._loop = loop
94- # task -> [offset, refcnt ]
88+ # task -> [offset, stack depth ]
9589 self ._tasks : dict [asyncio .Task [Any ], tuple [int , int ]] = {}
96- self ._trigger : deque [tuple [int , _In ]] = deque ()
90+ self ._trigger : deque [tuple [int , _InT ]] = deque ()
9791
9892 async def __aenter__ (self ) -> None :
9993 task = asyncio .current_task ()
@@ -104,8 +98,8 @@ async def __aenter__(self) -> None:
10498 for toffset , value in self ._trigger :
10599 if toffset > offset :
106100 raise SignalInterrupt (value = value )
107- _ , refcnt = self ._tasks .get (task , (0 , 0 ))
108- self ._tasks [task ] = (offset , refcnt + 1 )
101+ _ , depth = self ._tasks .get (task , (0 , - 1 ))
102+ self ._tasks [task ] = (offset , depth + 1 )
109103 self ._flush ()
110104
111105 async def __aexit__ (
@@ -117,43 +111,49 @@ async def __aexit__(
117111 task = asyncio .current_task ()
118112 if task is None :
119113 return
120- offset_start , refcnt = self ._tasks .pop (task )
121114 offset_end = await create_op (self ._loop , Barrier ())
122- if refcnt > 1 :
123- self ._tasks [task ] = (offset_end , refcnt - 1 )
115+
116+ offset_start , depth = self ._tasks .pop (task )
117+ if depth > 0 :
118+ self ._tasks [task ] = (offset_end , depth - 1 )
124119 for toffset , value in self ._trigger :
125- if offset_start < toffset < offset_end :
126- if sys .version_info >= (3 , 11 ) and exc_type is CancelledError :
127- assert exc_value # noqa: S101
128- assert exc_value .args [0 ] is _SENTINAL # noqa: S101
120+ if (
121+ offset_start < toffset < offset_end
122+ and exc_type is asyncio .CancelledError
123+ and (args := cast ("asyncio.CancelledError" , exc_value ).args )
124+ and args [0 ] is _SIGNAL_TRIGGER
125+ ):
126+ if sys .version_info >= (3 , 11 ):
129127 _ = task .uncancel ()
128+ self ._flush ()
130129 raise SignalInterrupt (value = value )
131130
132- def on_next (self , offset : int , value : _In ) -> None :
131+ def on_next (self , offset : int , value : _InT ) -> None :
133132 self ._trigger .append ((offset , value ))
134- for t , (toffset , _refcnt ) in self ._tasks .items ():
133+ for t , (toffset , _depth ) in self ._tasks .items ():
135134 if toffset < offset :
136- _ = self ._loop .call_soon (t .cancel , _SENTINAL )
135+ _ = self ._loop .call_soon (t .cancel , _SIGNAL_TRIGGER )
137136
138137 def on_close (self , _offset : int , _exc : Exception | None ) -> None :
139138 pass
140139
141140 def _flush (self ) -> None :
142- assert len (self ._tasks ) > 0 # noqa: S101
143- min_offset = min (offset for offset , _ in self ._tasks .values ())
141+ if not self ._tasks :
142+ self ._trigger .clear ()
143+ return
144+ min_offset = min ((offset for offset , _ in self ._tasks .values ()))
144145 while self ._trigger and self ._trigger [0 ][0 ] < min_offset :
145146 _ = self ._trigger .popleft ()
146147
147148
148149async def create_signal (
149- loop : EventLoop , dtype : TypeHint [_In ], annotations : OpAnnotations
150- ) -> tuple [Signal [_In ], SignalWriter [_In ]]:
151- assert asyncio .get_running_loop () is loop # noqa: S101
152- s : Signal [_In ] = Signal (loop )
150+ loop : EventLoop , dtype : TypeHint [_InT ], annotations : OpAnnotations
151+ ) -> tuple [Signal [_InT ], StreamWriter [_InT ]]:
152+ s : Signal [_InT ] = Signal (loop )
153153 sid = await create_op (
154154 loop ,
155155 StreamCreate (
156156 dtype = dtype , observer = cast ("Signal[object]" , s ), annotations = annotations
157157 ),
158158 )
159- return (s , _Writer (sid , loop ))
159+ return (s , SignalWriter (sid , loop ))
0 commit comments