Skip to content

Commit 906cd96

Browse files
committed
feat: refactor stream
1 parent 58ef6bb commit 906cd96

File tree

13 files changed

+482
-603
lines changed

13 files changed

+482
-603
lines changed

docs/getting-started.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -245,8 +245,9 @@ async def main():
245245

246246
await job.start()
247247

248-
async for message in stream:
249-
print(f"Received: {message}")
248+
async with stream as s:
249+
async for message in s:
250+
print(f"Received: {message}")
250251

251252
await job.wait()
252253
```

examples/agent.py

Lines changed: 76 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525
from duron.tracing import Tracer, span
2626

2727
if TYPE_CHECKING:
28+
from collections.abc import Awaitable
29+
2830
from duron.typing import JSONValue, TypeHint
2931

3032
client = AsyncOpenAI()
@@ -62,73 +64,72 @@ async def agent_fn(
6264
"content": "You are a helpful assistant!",
6365
},
6466
]
65-
async with input_ as inp:
66-
i = 0
67-
while True:
68-
msgs: list[str] = [msgs async for _, msgs in inp.next_nowait(ctx)]
69-
if not msgs:
70-
_, m = await inp.next()
71-
msgs = [m]
72-
73-
history.append({
74-
"role": "user",
75-
"content": "\n".join(msgs),
76-
})
77-
await output.send(("user", "\n".join(msgs)))
78-
with span(f"Round #{i}"):
79-
i += 1
80-
while True:
81-
try:
82-
async with signal:
83-
result = await completion(
84-
ctx,
85-
messages=history,
86-
)
87-
if result.choices[0].message.content:
88-
await output.send((
89-
"assistant",
90-
result.choices[0].message.content,
91-
))
92-
history.append({
93-
"role": "assistant",
94-
"content": result.choices[0].message.content,
95-
"tool_calls": [
96-
{
97-
"id": toolcall.id,
98-
"type": "function",
99-
"function": {
100-
"name": toolcall.function.name,
101-
"arguments": toolcall.function.arguments,
102-
},
103-
}
104-
for toolcall in result.choices[0].message.tool_calls
105-
or []
106-
if toolcall.type == "function"
107-
],
108-
})
109-
if not result.choices[0].message.tool_calls:
110-
break
111-
112-
tasks: list[asyncio.Task[tuple[str, str]]] = []
113-
for tool_call in result.choices[0].message.tool_calls:
114-
await output.send(("call", tool_call.model_dump_json()))
115-
tasks.append(
116-
asyncio.create_task(ctx.run(call_tool, tool_call))
117-
)
118-
for id_, tool_result in await asyncio.gather(*tasks):
119-
await output.send(("tool", tool_result))
120-
history.append({
121-
"role": "tool",
122-
"tool_call_id": id_,
123-
"content": tool_result,
124-
})
125-
except SignalInterrupt:
126-
await output.send(("assistant", "[Interrupted]"))
67+
i = 0
68+
while True:
69+
msgs: list[str] = [msgs async for msgs in input_.next_nowait(ctx)]
70+
if not msgs:
71+
m = await input_.next()
72+
msgs = [m]
73+
74+
history.append({
75+
"role": "user",
76+
"content": "\n".join(msgs),
77+
})
78+
await output.send(("user", "\n".join(msgs)))
79+
with span(f"Round #{i}"):
80+
i += 1
81+
while True:
82+
try:
83+
async with signal:
84+
result = await completion(
85+
ctx,
86+
messages=history,
87+
)
88+
if result.choices[0].message.content:
89+
await output.send((
90+
"assistant",
91+
result.choices[0].message.content,
92+
))
12793
history.append({
12894
"role": "assistant",
129-
"content": "[Interrupted]",
95+
"content": result.choices[0].message.content,
96+
"tool_calls": [
97+
{
98+
"id": toolcall.id,
99+
"type": "function",
100+
"function": {
101+
"name": toolcall.function.name,
102+
"arguments": toolcall.function.arguments,
103+
},
104+
}
105+
for toolcall in result.choices[0].message.tool_calls
106+
or []
107+
if toolcall.type == "function"
108+
],
130109
})
131-
break
110+
if not result.choices[0].message.tool_calls:
111+
break
112+
113+
tasks: list[asyncio.Task[tuple[str, str]]] = []
114+
for tool_call in result.choices[0].message.tool_calls:
115+
await output.send(("call", tool_call.model_dump_json()))
116+
tasks.append(
117+
asyncio.create_task(ctx.run(call_tool, tool_call))
118+
)
119+
for id_, tool_result in await asyncio.gather(*tasks):
120+
await output.send(("tool", tool_result))
121+
history.append({
122+
"role": "tool",
123+
"tool_call_id": id_,
124+
"content": tool_result,
125+
})
126+
except SignalInterrupt:
127+
await output.send(("assistant", "[Interrupted]"))
128+
history.append({
129+
"role": "assistant",
130+
"content": "[Interrupted]",
131+
})
132+
break
132133

133134

134135
@duron.effect
@@ -162,13 +163,13 @@ async def main() -> None:
162163
async with duron.invoke(
163164
agent_fn, log_storage, tracer=Tracer(args.session_id)
164165
) as job:
165-
input_stream: StreamWriter[str] = job.open_stream("input_", "w")
166-
signal_stream: StreamWriter[None] = job.open_stream("signal", "w")
167-
stream: Stream[tuple[str, str]] = job.open_stream("output", "r")
166+
input_stream: Awaitable[StreamWriter[str]] = job.open_stream("input_", "w")
167+
signal_stream: Awaitable[StreamWriter[None]] = job.open_stream("signal", "w")
168+
stream: Awaitable[Stream[tuple[str, str]]] = job.open_stream("output", "r")
168169

169170
async def reader() -> None:
170171
console = Console()
171-
async for role, result in stream:
172+
async for role, result in await stream:
172173
match role:
173174
case "user":
174175
console.print("[bold cyan] USER[/bold cyan]", result)
@@ -182,20 +183,21 @@ async def reader() -> None:
182183
console.print("[bold magenta] ???[/bold magenta]", result)
183184

184185
async def writer() -> None:
186+
signal_stream_ = await signal_stream
187+
input_stream_ = await input_stream
185188
while True:
186189
await asyncio.sleep(0)
187190
m = await asyncio.to_thread(input, "> ")
188191
if m.strip():
189192
if m == "!":
190-
await signal_stream.send(None)
193+
await signal_stream_.send(None)
191194
else:
192-
await input_stream.send(m)
195+
await input_stream_.send(m)
193196

194-
bg = [asyncio.create_task(reader()), asyncio.create_task(writer())]
195197
await job.start()
196-
await job.wait()
197-
for t in bg:
198-
await t
198+
await asyncio.gather(
199+
job.wait(), asyncio.create_task(reader()), asyncio.create_task(writer())
200+
)
199201

200202

201203
async def completion(

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ ban-relative-imports = "all"
9494

9595
[tool.ruff.lint.flake8-tidy-imports.banned-api]
9696
"typing.TypedDict".msg = "Use typing_extensions.TypedDict instead."
97+
"typing_extensions.AsyncContextManager".msg = "Use contextlib.AbstractAsyncContextManager instead."
9798

9899
[tool.ruff.format]
99100
preview = true

src/duron/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
from duron._core.signal import SignalInterrupt as SignalInterrupt
77
from duron._core.stream import Stream as Stream
88
from duron._core.stream import StreamClosed as StreamClosed
9-
from duron._core.stream import StreamOp as StreamOp
109
from duron._core.stream import StreamWriter as StreamWriter
1110
from duron._decorator.durable import durable as durable
1211
from duron._decorator.effect import effect as effect

src/duron/_core/context.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -90,9 +90,9 @@ async def run(
9090
if isinstance(fn, StatefulFn):
9191
async with self.stream(
9292
cast("StatefulFn[_P, _T, Any]", fn), *args, **kwargs
93-
) as stream:
93+
) as (stream, result):
9494
await stream.discard()
95-
return await stream
95+
return await result
9696

9797
if isinstance(fn, EffectFn):
9898
callable_ = fn.fn
@@ -127,7 +127,7 @@ async def wrapper( # noqa: RUF029
127127

128128
def stream(
129129
self, fn: StatefulFn[_P, _T, _S], /, *args: _P.args, **kwargs: _P.kwargs
130-
) -> AbstractAsyncContextManager[Stream[_S, _T]]:
130+
) -> AbstractAsyncContextManager[tuple[Stream[_S], Awaitable[_T]]]:
131131
"""Stream stateful function partial results.
132132
133133
Args:
@@ -155,7 +155,10 @@ async def create_stream(
155155
*,
156156
name: str | None = None,
157157
labels: Mapping[str, str] | None = None,
158-
) -> tuple[Stream[_T, None], StreamWriter[_T]]:
158+
) -> tuple[
159+
Stream[_T],
160+
StreamWriter[_T],
161+
]:
159162
"""Create a new stream within the context.
160163
161164
Args:
@@ -189,7 +192,10 @@ async def create_signal(
189192
*,
190193
name: str | None = None,
191194
labels: Mapping[str, str] | None = None,
192-
) -> tuple[Signal[_T], StreamWriter[_T]]:
195+
) -> tuple[
196+
Signal[_T],
197+
StreamWriter[_T],
198+
]:
193199
"""Create a new signal within the context.
194200
195201
Args:

0 commit comments

Comments
 (0)