22
33import asyncio
44import sys
5- from collections import deque
6- from typing import TYPE_CHECKING , Final , Generic , cast
5+ from dataclasses import dataclass
6+ from typing import TYPE_CHECKING , Generic , Literal , cast
77from typing_extensions import Any , TypeVar , final , override
88
9- from duron ._core .ops import Barrier , StreamCreate , create_op
9+ from duron ._core .ops import StreamCreate , create_op
1010from duron ._core .stream import OpWriter
1111
1212if TYPE_CHECKING :
@@ -36,7 +36,10 @@ def __repr__(self) -> str:
3636 return f"SignalInterrupt(value={ self .value !r} )"
3737
3838
39- _SIGNAL_TRIGGER : Final = object ()
39+ @dataclass (slots = True )
40+ class _SignalState :
41+ depth : int
42+ triggered : Literal [False ] | SignalInterrupt
4043
4144
4245@final
@@ -58,22 +61,19 @@ class Signal(Generic[_T]):
5861
5962 def __init__ (self , loop : EventLoop ) -> None :
6063 self ._loop = loop
61- # task -> [offset, stack depth]
62- self ._tasks : dict [asyncio .Task [Any ], tuple [int , int ]] = {}
63- self ._trigger : deque [tuple [int , _T ]] = deque ()
64+ self ._tasks : dict [asyncio .Task [Any ], _SignalState ] = {}
6465
6566 async def __aenter__ (self ) -> None :
6667 task = asyncio .current_task ()
6768 if task is None :
6869 return
6970 assert task .get_loop () == self ._loop
70- offset , _ = await create_op (self ._loop , Barrier ())
71- for toffset , value in self ._trigger :
72- if toffset > offset :
73- raise SignalInterrupt (value = value )
74- _ , depth = self ._tasks .get (task , (0 , - 1 ))
75- self ._tasks [task ] = (offset , depth + 1 )
76- self ._flush ()
71+ if task not in self ._tasks :
72+ val = _SignalState (depth = 0 , triggered = False )
73+ self ._tasks [task ] = val
74+ else :
75+ val = self ._tasks [task ]
76+ val .depth += 1
7777
7878 async def __aexit__ (
7979 self ,
@@ -84,40 +84,27 @@ async def __aexit__(
8484 task = asyncio .current_task ()
8585 if task is None :
8686 return
87- offset_end , _ = await create_op (self ._loop , Barrier ())
88-
89- offset_start , depth = self ._tasks .pop (task )
90- if depth > 0 :
91- self ._tasks [task ] = (offset_end , depth - 1 )
92- for toffset , value in self ._trigger :
93- if (
94- offset_start < toffset < offset_end
95- and exc_type is asyncio .CancelledError
96- and (args := cast ("asyncio.CancelledError" , exc_value ).args )
97- and args [0 ] is _SIGNAL_TRIGGER
98- ):
99- if sys .version_info >= (3 , 11 ):
100- _ = task .uncancel ()
101- self ._flush ()
102- raise SignalInterrupt (value = value )
103-
104- def on_next (self , offset : int , value : _T ) -> None :
105- self ._trigger .append ((offset , value ))
106- for t , (toffset , _depth ) in self ._tasks .items ():
107- if toffset < offset :
108- _ = self ._loop .call_soon (t .cancel , _SIGNAL_TRIGGER )
87+ state = self ._tasks .pop (task )
88+ _ = self ._loop .generate_op_scope ()
89+ triggered = state .triggered
90+ if state .depth > 0 :
91+ state .triggered = False
92+ state .depth -= 1
93+ self ._tasks [task ] = state
94+ if triggered is not False :
95+ if sys .version_info >= (3 , 11 ):
96+ _ = task .uncancel ()
97+ raise triggered from None
98+
99+ def on_next (self , _offset : int , value : _T ) -> None :
100+ for t , state in self ._tasks .items ():
101+ if state .triggered is False :
102+ state .triggered = SignalInterrupt (value = value )
103+ _ = self ._loop .call_soon (t .cancel , state .triggered )
109104
110105 def on_close (self , _offset : int , _exc : Exception | None ) -> None :
111106 pass
112107
113- def _flush (self ) -> None :
114- if not self ._tasks :
115- self ._trigger .clear ()
116- return
117- min_offset = min ((offset for offset , _ in self ._tasks .values ()))
118- while self ._trigger and self ._trigger [0 ][0 ] < min_offset :
119- _ = self ._trigger .popleft ()
120-
121108
122109async def create_signal (
123110 loop : EventLoop , dtype : TypeHint [_T ], name : str | None , metadata : OpMetadata
0 commit comments