2525from duron .tracing import Tracer , span
2626
2727if TYPE_CHECKING :
28+ from collections .abc import Awaitable
29+
2830 from duron .typing import JSONValue , TypeHint
2931
3032client = AsyncOpenAI ()
@@ -62,73 +64,72 @@ async def agent_fn(
6264 "content" : "You are a helpful assistant!" ,
6365 },
6466 ]
65- async with input_ as inp :
66- i = 0
67- while True :
68- msgs : list [str ] = [msgs async for _ , msgs in inp .next_nowait (ctx )]
69- if not msgs :
70- _ , m = await inp .next ()
71- msgs = [m ]
72-
73- history .append ({
74- "role" : "user" ,
75- "content" : "\n " .join (msgs ),
76- })
77- await output .send (("user" , "\n " .join (msgs )))
78- with span (f"Round #{ i } " ):
79- i += 1
80- while True :
81- try :
82- async with signal :
83- result = await completion (
84- ctx ,
85- messages = history ,
86- )
87- if result .choices [0 ].message .content :
88- await output .send ((
89- "assistant" ,
90- result .choices [0 ].message .content ,
91- ))
92- history .append ({
93- "role" : "assistant" ,
94- "content" : result .choices [0 ].message .content ,
95- "tool_calls" : [
96- {
97- "id" : toolcall .id ,
98- "type" : "function" ,
99- "function" : {
100- "name" : toolcall .function .name ,
101- "arguments" : toolcall .function .arguments ,
102- },
103- }
104- for toolcall in result .choices [0 ].message .tool_calls
105- or []
106- if toolcall .type == "function"
107- ],
108- })
109- if not result .choices [0 ].message .tool_calls :
110- break
111-
112- tasks : list [asyncio .Task [tuple [str , str ]]] = []
113- for tool_call in result .choices [0 ].message .tool_calls :
114- await output .send (("call" , tool_call .model_dump_json ()))
115- tasks .append (
116- asyncio .create_task (ctx .run (call_tool , tool_call ))
117- )
118- for id_ , tool_result in await asyncio .gather (* tasks ):
119- await output .send (("tool" , tool_result ))
120- history .append ({
121- "role" : "tool" ,
122- "tool_call_id" : id_ ,
123- "content" : tool_result ,
124- })
125- except SignalInterrupt :
126- await output .send (("assistant" , "[Interrupted]" ))
67+ i = 0
68+ while True :
69+ msgs : list [str ] = [msgs async for msgs in input_ .next_nowait (ctx )]
70+ if not msgs :
71+ m = await input_ .next ()
72+ msgs = [m ]
73+
74+ history .append ({
75+ "role" : "user" ,
76+ "content" : "\n " .join (msgs ),
77+ })
78+ await output .send (("user" , "\n " .join (msgs )))
79+ with span (f"Round #{ i } " ):
80+ i += 1
81+ while True :
82+ try :
83+ async with signal :
84+ result = await completion (
85+ ctx ,
86+ messages = history ,
87+ )
88+ if result .choices [0 ].message .content :
89+ await output .send ((
90+ "assistant" ,
91+ result .choices [0 ].message .content ,
92+ ))
12793 history .append ({
12894 "role" : "assistant" ,
129- "content" : "[Interrupted]" ,
95+ "content" : result .choices [0 ].message .content ,
96+ "tool_calls" : [
97+ {
98+ "id" : toolcall .id ,
99+ "type" : "function" ,
100+ "function" : {
101+ "name" : toolcall .function .name ,
102+ "arguments" : toolcall .function .arguments ,
103+ },
104+ }
105+ for toolcall in result .choices [0 ].message .tool_calls
106+ or []
107+ if toolcall .type == "function"
108+ ],
130109 })
131- break
110+ if not result .choices [0 ].message .tool_calls :
111+ break
112+
113+ tasks : list [asyncio .Task [tuple [str , str ]]] = []
114+ for tool_call in result .choices [0 ].message .tool_calls :
115+ await output .send (("call" , tool_call .model_dump_json ()))
116+ tasks .append (
117+ asyncio .create_task (ctx .run (call_tool , tool_call ))
118+ )
119+ for id_ , tool_result in await asyncio .gather (* tasks ):
120+ await output .send (("tool" , tool_result ))
121+ history .append ({
122+ "role" : "tool" ,
123+ "tool_call_id" : id_ ,
124+ "content" : tool_result ,
125+ })
126+ except SignalInterrupt :
127+ await output .send (("assistant" , "[Interrupted]" ))
128+ history .append ({
129+ "role" : "assistant" ,
130+ "content" : "[Interrupted]" ,
131+ })
132+ break
132133
133134
134135@duron .effect
@@ -162,13 +163,13 @@ async def main() -> None:
162163 async with duron .invoke (
163164 agent_fn , log_storage , tracer = Tracer (args .session_id )
164165 ) as job :
165- input_stream : StreamWriter [str ] = job .open_stream ("input_" , "w" )
166- signal_stream : StreamWriter [None ] = job .open_stream ("signal" , "w" )
167- stream : Stream [tuple [str , str ]] = job .open_stream ("output" , "r" )
166+ input_stream : Awaitable [ StreamWriter [str ] ] = job .open_stream ("input_" , "w" )
167+ signal_stream : Awaitable [ StreamWriter [None ] ] = job .open_stream ("signal" , "w" )
168+ stream : Awaitable [ Stream [tuple [str , str ] ]] = job .open_stream ("output" , "r" )
168169
169170 async def reader () -> None :
170171 console = Console ()
171- async for role , result in stream :
172+ async for role , result in await stream :
172173 match role :
173174 case "user" :
174175 console .print ("[bold cyan] USER[/bold cyan]" , result )
@@ -182,20 +183,21 @@ async def reader() -> None:
182183 console .print ("[bold magenta] ???[/bold magenta]" , result )
183184
184185 async def writer () -> None :
186+ signal_stream_ = await signal_stream
187+ input_stream_ = await input_stream
185188 while True :
186189 await asyncio .sleep (0 )
187190 m = await asyncio .to_thread (input , "> " )
188191 if m .strip ():
189192 if m == "!" :
190- await signal_stream .send (None )
193+ await signal_stream_ .send (None )
191194 else :
192- await input_stream .send (m )
195+ await input_stream_ .send (m )
193196
194- bg = [asyncio .create_task (reader ()), asyncio .create_task (writer ())]
195197 await job .start ()
196- await job . wait ()
197- for t in bg :
198- await t
198+ await asyncio . gather (
199+ job . wait (), asyncio . create_task ( reader ()), asyncio . create_task ( writer ())
200+ )
199201
200202
201203async def completion (
0 commit comments