Skip to content

Commit aea8daa

Browse files
committed
feat: add stream timing test
1 parent 821000f commit aea8daa

File tree

7 files changed

+218
-173
lines changed

7 files changed

+218
-173
lines changed

examples/agent.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -137,12 +137,14 @@ async def run_round(user_input: str) -> None:
137137
# Main conversation loop (max 100 rounds)
138138
for i in range(100):
139139
# Collect any queued messages without waiting
140-
msgs: list[str] = list(await input_.next_nowait())
140+
it = await input_.next_nowait()
141+
msgs: list[str] = list(it) if it is not None else []
141142

142143
# If no queued messages, wait for next input
143144
if not msgs:
145+
await output.send(("assistant", "[Waiting on user input...]"))
144146
m = await input_.next()
145-
msgs = [m]
147+
msgs.extend(m)
146148

147149
# Execute round with tracing and interruption support
148150
with span(f"Round #{i + 1}"):

src/duron/_core/session.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -339,10 +339,10 @@ async def _resume(self) -> bool:
339339
_ = await self._step()
340340
if is_entry(entry):
341341
if entry["source"] == "task":
342-
recvd_msgs.add(entry["id"])
343342
if not self._handle_message(o, entry):
344-
msg = "Extra messages found in log"
345-
raise RuntimeError(msg)
343+
e = "Extra messages found in log"
344+
raise RuntimeError(e)
345+
recvd_msgs.add(entry["id"])
346346
else:
347347
_ = self._handle_message(o, entry)
348348
_ = await self._step()
@@ -353,10 +353,19 @@ async def _resume(self) -> bool:
353353
self._pending_msg.pop()
354354
recvd_msgs.remove(id_)
355355

356+
pending: list[Entry] = []
357+
for msg in self._pending_msg:
358+
if msg["id"] not in recvd_msgs:
359+
pending.append(msg)
360+
else:
361+
recvd_msgs.remove(msg["id"])
362+
self._pending_msg = pending
363+
356364
if len(recvd_msgs) > 0:
357-
msg = "Extra messages found in log"
358-
raise RuntimeError(msg)
359-
return self._main.done()
365+
e = "Extra messages found in log"
366+
raise RuntimeError(e)
367+
368+
return self._main.done() and len(self._pending_msg) == 0
360369

361370
async def _start(self) -> None:
362371
if await self._resume():
@@ -599,6 +608,7 @@ def done(f: OpFuture) -> None:
599608
tracer.end_op_span(op.future_id, promise_complete_entry)
600609
await self._enqueue_log(promise_complete_entry)
601610
self._loop.post_completion(id_, result=None)
611+
602612
case _:
603613
assert_never(op)
604614

src/duron/_core/signal.py

Lines changed: 31 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,11 @@
22

33
import asyncio
44
import sys
5-
from collections import deque
6-
from typing import TYPE_CHECKING, Final, Generic, cast
5+
from dataclasses import dataclass
6+
from typing import TYPE_CHECKING, Generic, Literal, cast
77
from typing_extensions import Any, TypeVar, final, override
88

9-
from duron._core.ops import Barrier, StreamCreate, create_op
9+
from duron._core.ops import StreamCreate, create_op
1010
from duron._core.stream import OpWriter
1111

1212
if TYPE_CHECKING:
@@ -36,7 +36,10 @@ def __repr__(self) -> str:
3636
return f"SignalInterrupt(value={self.value!r})"
3737

3838

39-
_SIGNAL_TRIGGER: Final = object()
39+
@dataclass(slots=True)
40+
class _SignalState:
41+
depth: int
42+
triggered: Literal[False] | SignalInterrupt
4043

4144

4245
@final
@@ -58,22 +61,19 @@ class Signal(Generic[_T]):
5861

5962
def __init__(self, loop: EventLoop) -> None:
6063
self._loop = loop
61-
# task -> [offset, stack depth]
62-
self._tasks: dict[asyncio.Task[Any], tuple[int, int]] = {}
63-
self._trigger: deque[tuple[int, _T]] = deque()
64+
self._tasks: dict[asyncio.Task[Any], _SignalState] = {}
6465

6566
async def __aenter__(self) -> None:
6667
task = asyncio.current_task()
6768
if task is None:
6869
return
6970
assert task.get_loop() == self._loop
70-
offset, _ = await create_op(self._loop, Barrier())
71-
for toffset, value in self._trigger:
72-
if toffset > offset:
73-
raise SignalInterrupt(value=value)
74-
_, depth = self._tasks.get(task, (0, -1))
75-
self._tasks[task] = (offset, depth + 1)
76-
self._flush()
71+
if task not in self._tasks:
72+
val = _SignalState(depth=0, triggered=False)
73+
self._tasks[task] = val
74+
else:
75+
val = self._tasks[task]
76+
val.depth += 1
7777

7878
async def __aexit__(
7979
self,
@@ -84,40 +84,27 @@ async def __aexit__(
8484
task = asyncio.current_task()
8585
if task is None:
8686
return
87-
offset_end, _ = await create_op(self._loop, Barrier())
88-
89-
offset_start, depth = self._tasks.pop(task)
90-
if depth > 0:
91-
self._tasks[task] = (offset_end, depth - 1)
92-
for toffset, value in self._trigger:
93-
if (
94-
offset_start < toffset < offset_end
95-
and exc_type is asyncio.CancelledError
96-
and (args := cast("asyncio.CancelledError", exc_value).args)
97-
and args[0] is _SIGNAL_TRIGGER
98-
):
99-
if sys.version_info >= (3, 11):
100-
_ = task.uncancel()
101-
self._flush()
102-
raise SignalInterrupt(value=value)
103-
104-
def on_next(self, offset: int, value: _T) -> None:
105-
self._trigger.append((offset, value))
106-
for t, (toffset, _depth) in self._tasks.items():
107-
if toffset < offset:
108-
_ = self._loop.call_soon(t.cancel, _SIGNAL_TRIGGER)
87+
state = self._tasks.pop(task)
88+
_ = self._loop.generate_op_scope()
89+
triggered = state.triggered
90+
if state.depth > 0:
91+
state.triggered = False
92+
state.depth -= 1
93+
self._tasks[task] = state
94+
if triggered is not False:
95+
if sys.version_info >= (3, 11):
96+
_ = task.uncancel()
97+
raise triggered from None
98+
99+
def on_next(self, _offset: int, value: _T) -> None:
100+
for t, state in self._tasks.items():
101+
if state.triggered is False:
102+
state.triggered = SignalInterrupt(value=value)
103+
_ = self._loop.call_soon(t.cancel, state.triggered)
109104

110105
def on_close(self, _offset: int, _exc: Exception | None) -> None:
111106
pass
112107

113-
def _flush(self) -> None:
114-
if not self._tasks:
115-
self._trigger.clear()
116-
return
117-
min_offset = min((offset for offset, _ in self._tasks.values()))
118-
while self._trigger and self._trigger[0][0] < min_offset:
119-
_ = self._trigger.popleft()
120-
121108

122109
async def create_signal(
123110
loop: EventLoop, dtype: TypeHint[_T], name: str | None, metadata: OpMetadata

0 commit comments

Comments
 (0)