Skip to content

Commit 268026a

Browse files
committed
feat: stream injection by type
1 parent a5f2e09 commit 268026a

File tree

18 files changed

+358
-202
lines changed

18 files changed

+358
-202
lines changed

examples/agent.py

Lines changed: 106 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import readline
88
from pathlib import Path
99
from typing import TYPE_CHECKING, Literal, cast
10-
from typing_extensions import override
10+
from typing_extensions import Any, override
1111

1212
from openai import AsyncOpenAI, pydantic_function_tool
1313
from openai.lib.streaming.chat import ChatCompletionStreamState
@@ -20,12 +20,11 @@
2020
from rich.console import Console
2121

2222
import duron
23+
from duron import Deferred, Signal, SignalInterrupt, Stream, StreamWriter
2324
from duron.codec import Codec
2425
from duron.contrib.storage import FileLogStorage
2526

2627
if TYPE_CHECKING:
27-
from typing import Any
28-
2928
from duron.codec import JSONValue
3029
from duron.typing import TypeHint
3130

@@ -51,73 +50,83 @@ def decode_json(self, encoded: JSONValue, expected_type: TypeHint[Any]) -> objec
5150
return cast("object", TypeAdapter(expected_type).validate_python(encoded))
5251

5352

54-
@duron.op
55-
async def do_input() -> str: # noqa: RUF029
56-
try:
57-
return input("> ") # noqa: ASYNC250
58-
except EOFError:
59-
os._exit(0)
60-
except KeyboardInterrupt:
61-
os._exit(1)
62-
63-
6453
@duron.fn(codec=PydanticCodec())
65-
async def agent_fn(ctx: duron.Context) -> None:
66-
console = Console()
54+
async def agent_fn(
55+
ctx: duron.Context,
56+
input_: Stream[str] = Deferred,
57+
signal: Signal[None] = Deferred,
58+
output: StreamWriter[tuple[str, str]] = Deferred,
59+
) -> None:
6760
history: list[ChatCompletionMessageParam] = [
6861
{
6962
"role": "system",
7063
"content": "You are a helpful assistant!",
7164
},
7265
]
73-
while True:
74-
msg = await ctx.run(do_input)
75-
history.append({
76-
"role": "user",
77-
"content": msg,
78-
})
79-
console.print("[bold cyan] USER[/bold cyan]", msg)
66+
async with input_ as inp:
8067
while True:
81-
result = await completion(
82-
ctx,
83-
messages=history,
84-
)
85-
if result.choices[0].message.content:
86-
console.print(
87-
"[bold red]ASSISTANT[/bold red] ", result.choices[0].message.content
88-
)
68+
msgs: list[str] = [msgs async for _, msgs in inp.next_nowait(ctx)]
69+
if not msgs:
70+
_, m = await inp.next()
71+
msgs = [m]
72+
8973
history.append({
90-
"role": "assistant",
91-
"content": result.choices[0].message.content,
92-
"tool_calls": [
93-
{
94-
"id": toolcall.id,
95-
"type": "function",
96-
"function": {
97-
"name": toolcall.function.name,
98-
"arguments": toolcall.function.arguments,
99-
},
100-
}
101-
for toolcall in result.choices[0].message.tool_calls or []
102-
if toolcall.type == "function"
103-
],
74+
"role": "user",
75+
"content": "\n".join(msgs),
10476
})
105-
if not result.choices[0].message.tool_calls:
106-
break
107-
108-
tasks: list[asyncio.Task[tuple[str, str]]] = []
109-
for tool_call in result.choices[0].message.tool_calls:
110-
console.print("[bold yellow] CALL[/bold yellow]", tool_call.id)
111-
console.print(tool_call.model_dump_json())
112-
tasks.append(asyncio.create_task(ctx.run(call_tool, None, tool_call)))
113-
for id_, tool_result in await asyncio.gather(*tasks):
114-
console.print("[bold cyan] TOOL[/bold cyan]", id_)
115-
console.print(tool_result)
116-
history.append({
117-
"role": "tool",
118-
"tool_call_id": id_,
119-
"content": tool_result,
120-
})
77+
await output.send(("user", "\n".join(msgs)))
78+
while True:
79+
try:
80+
async with signal:
81+
result = await completion(
82+
ctx,
83+
messages=history,
84+
)
85+
if result.choices[0].message.content:
86+
await output.send((
87+
"assistant",
88+
result.choices[0].message.content,
89+
))
90+
history.append({
91+
"role": "assistant",
92+
"content": result.choices[0].message.content,
93+
"tool_calls": [
94+
{
95+
"id": toolcall.id,
96+
"type": "function",
97+
"function": {
98+
"name": toolcall.function.name,
99+
"arguments": toolcall.function.arguments,
100+
},
101+
}
102+
for toolcall in result.choices[0].message.tool_calls
103+
or []
104+
if toolcall.type == "function"
105+
],
106+
})
107+
if not result.choices[0].message.tool_calls:
108+
break
109+
110+
tasks: list[asyncio.Task[tuple[str, str]]] = []
111+
for tool_call in result.choices[0].message.tool_calls:
112+
await output.send(("call", tool_call.model_dump_json()))
113+
tasks.append(
114+
asyncio.create_task(ctx.run(call_tool, None, tool_call))
115+
)
116+
for id_, tool_result in await asyncio.gather(*tasks):
117+
await output.send(("tool", tool_result))
118+
history.append({
119+
"role": "tool",
120+
"tool_call_id": id_,
121+
"content": tool_result,
122+
})
123+
except SignalInterrupt:
124+
await output.send(("assistant", "[Interrupted]"))
125+
history.append({
126+
"role": "assistant",
127+
"content": "[Interrupted]",
128+
})
129+
break
121130

122131

123132
@duron.op
@@ -149,8 +158,45 @@ async def main() -> None:
149158

150159
log_storage = FileLogStorage(Path("logs") / f"{args.session_id}.jsonl")
151160
async with agent_fn.invoke(log_storage) as job:
161+
input_stream: StreamWriter[str] = job.open_stream("input_", "w")
162+
signal_stream: StreamWriter[None] = job.open_stream("signal", "w")
163+
stream: Stream[tuple[str, str]] = job.open_stream("output", "r")
164+
165+
async def reader() -> None:
166+
console = Console()
167+
async for role, result in stream:
168+
match role:
169+
case "user":
170+
console.print("[bold cyan] USER[/bold cyan]", result)
171+
case "assistant":
172+
console.print("[bold red]ASSISTANT[/bold red] ", result)
173+
case "tool":
174+
console.print("[bold cyan] TOOL[/bold cyan]", result)
175+
case "call":
176+
console.print("[bold yellow] CALL[/bold yellow]", result)
177+
case _:
178+
console.print("[bold magenta] ???[/bold magenta]", result)
179+
180+
async def writer() -> None:
181+
try:
182+
while True:
183+
await asyncio.sleep(0)
184+
m = await asyncio.to_thread(input, "> ")
185+
if m.strip():
186+
if m == "!":
187+
await signal_stream.send(None)
188+
else:
189+
await input_stream.send(m)
190+
except EOFError:
191+
os._exit(0)
192+
except KeyboardInterrupt:
193+
os._exit(1)
194+
195+
bg = [asyncio.create_task(reader()), asyncio.create_task(writer())]
152196
await job.start()
153197
await job.wait()
198+
for t in bg:
199+
await t
154200

155201

156202
async def completion(

pyproject.toml

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@ build-backend = "uv_build"
1616
[dependency-groups]
1717
dev = [
1818
"pydantic>=2.11.9",
19-
"pytest>=7.0",
2019
"pytest-asyncio>=0.21",
2120
"pytest-codspeed>=4.1.1",
21+
"pytest>=7.0",
2222
]
2323
type-checking = [
2424
"basedmypy>=2.10.0",
@@ -28,14 +28,14 @@ lint = [
2828
"ruff>=0.12.8",
2929
]
3030
docs = [
31-
"mkdocs>=1.6.1",
3231
"mkdocs-material>=9.6.18",
3332
"mkdocs-mermaid2-plugin>=1.2.1",
33+
"mkdocs>=1.6.1",
3434
"mkdocstrings[python]>=0.30.0",
3535
]
3636
examples = [
37-
"pydantic>=2.11.9",
3837
"openai>=2.2.0",
38+
"pydantic>=2.11.9",
3939
"rich>=14.1.0",
4040
]
4141

@@ -60,6 +60,7 @@ extra-standard-library = ["typing_extensions"]
6060
[tool.ruff.lint.flake8-type-checking]
6161
runtime-evaluated-base-classes = ["typing_extensions.TypedDict"]
6262
runtime-evaluated-decorators = ["duron.fn", "duron.op"]
63+
exempt-modules = ["typing", "typing_extensions"]
6364

6465
[tool.ruff.lint.flake8-tidy-imports.banned-api]
6566
"typing.TypedDict".msg = "Use typing_extensions.TypedDict instead."

src/duron/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,4 @@
99
from ._core.stream import StreamWriter as StreamWriter
1010
from ._decorator.fn import fn as fn
1111
from ._decorator.op import op as op
12+
from .typing import Deferred as Deferred

src/duron/_core/context.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,12 @@
44
import binascii
55
from contextvars import ContextVar
66
from random import Random
7-
from typing import (
8-
TYPE_CHECKING,
7+
from typing import TYPE_CHECKING, cast
8+
from typing_extensions import (
99
Any,
10+
AsyncContextManager,
1011
ParamSpec,
1112
TypeVar,
12-
cast,
1313
final,
1414
overload,
1515
)
@@ -24,14 +24,14 @@
2424
from collections.abc import Callable, Coroutine
2525
from contextvars import Token
2626
from types import TracebackType
27-
from typing_extensions import AsyncContextManager
2827

2928
from duron._core.options import RunOptions
3029
from duron._core.signal import Signal, SignalWriter
3130
from duron._core.stream import Stream, StreamWriter
3231
from duron._decorator.fn import Fn
3332
from duron._loop import EventLoop
3433
from duron.codec import JSONValue
34+
from duron.typing import TypeHint
3535

3636
_T = TypeVar("_T")
3737
_S = TypeVar("_S")
@@ -152,7 +152,7 @@ def run_stream(
152152

153153
async def create_stream(
154154
self,
155-
dtype: type[_T],
155+
dtype: TypeHint[_T],
156156
*,
157157
external: bool = False,
158158
metadata: dict[str, JSONValue] | None = None,
@@ -169,7 +169,7 @@ async def create_stream(
169169

170170
async def create_signal(
171171
self,
172-
dtype: type[_T],
172+
dtype: TypeHint[_T],
173173
*,
174174
metadata: dict[str, JSONValue] | None = None,
175175
) -> tuple[Signal[_T], SignalWriter[_T]]:

0 commit comments

Comments
 (0)