|
2 | 2 |
|
3 | 3 | from temporalio import workflow |
4 | 4 |
|
5 | | - |
6 | 5 | with workflow.unsafe.imports_passed_through(): |
7 | 6 | from pydantic import BaseModel |
8 | 7 | from openai_agents.adapters.activity_model import ModelStubProvider |
@@ -136,55 +135,64 @@ def init_agents() -> Agent[AirlineAgentContext]: |
136 | 135 | return triage_agent |
137 | 136 |
|
138 | 137 |
|
139 | | -### RUN |
| 138 | +class ProcessUserMessageInput(BaseModel): |
| 139 | + user_input: str |
| 140 | + chat_length: int |
| 141 | + |
140 | 142 |
|
141 | 143 | @workflow.defn(sandboxed=False) |
142 | 144 | class CustomerServiceWorkflow: |
143 | 145 |
|
144 | | - def __init__(self): |
| 146 | + def __init__(self, input_items: list[TResponseInputItem] = None): |
145 | 147 | self.run_config = RunConfig(model_provider=ModelStubProvider()) |
146 | 148 | self.chat_history = [] |
147 | | - self.user_input = None |
| 149 | + self.current_agent: Agent[AirlineAgentContext] = init_agents() |
| 150 | + self.context = AirlineAgentContext() |
| 151 | + self.input_items = [] if input_items is None else input_items |
148 | 152 |
|
149 | 153 | @workflow.run |
150 | 154 | async def run(self, input_items: list[TResponseInputItem] = None): |
151 | | - input_items = [] #input_items or [] |
152 | | - current_agent: Agent[AirlineAgentContext] = init_agents() |
153 | | - context = AirlineAgentContext() |
154 | | - while not workflow.info().is_continue_as_new_suggested(): |
155 | | - await workflow.wait_condition(lambda: self.user_input is not None) |
156 | | - self.chat_history.append(f"User: {self.user_input}") |
157 | | - with trace("Customer service", group_id=workflow.info().workflow_id): |
158 | | - input_items.append({"content": self.user_input, "role": "user"}) |
159 | | - result = await Runner.run(current_agent, input_items, context=context, run_config=self.run_config) |
160 | | - |
161 | | - for new_item in result.new_items: |
162 | | - agent_name = new_item.agent.name |
163 | | - if isinstance(new_item, MessageOutputItem): |
164 | | - self.chat_history.append(f"{agent_name}: {ItemHelpers.text_message_output(new_item)}") |
165 | | - elif isinstance(new_item, HandoffOutputItem): |
166 | | - self.chat_history.append( |
167 | | - f"Handed off from {new_item.source_agent.name} to {new_item.target_agent.name}" |
168 | | - ) |
169 | | - elif isinstance(new_item, ToolCallItem): |
170 | | - self.chat_history.append(f"{agent_name}: Calling a tool") |
171 | | - elif isinstance(new_item, ToolCallOutputItem): |
172 | | - self.chat_history.append(f"{agent_name}: Tool call output: {new_item.output}") |
173 | | - else: |
174 | | - self.chat_history.append(f"{agent_name}: Skipping item: {new_item.__class__.__name__}") |
175 | | - input_items = result.to_input_list() |
176 | | - current_agent = result.last_agent |
177 | | - self.user_input = None |
178 | | - workflow.continue_as_new(input_items) |
| 155 | + await workflow.wait_condition( |
| 156 | + lambda: workflow.info().is_continue_as_new_suggested() and workflow.all_handlers_finished()) |
| 157 | + workflow.continue_as_new(self.input_items) |
179 | 158 |
|
180 | 159 | @workflow.query |
181 | 160 | def get_chat_history(self) -> list[str]: |
182 | 161 | return self.chat_history |
183 | 162 |
|
184 | 163 | @workflow.update |
185 | | - async def process_user_message(self, user_input: str) -> list[str]: |
186 | | - await workflow.wait_condition(lambda: self.user_input is None) |
187 | | - self.user_input = user_input |
| 164 | + async def process_user_message(self, input: ProcessUserMessageInput) -> list[str]: |
188 | 165 | length = len(self.chat_history) |
189 | | - await workflow.wait_condition(lambda: self.user_input is None) |
| 166 | + self.chat_history.append(f"User: {input.user_input}") |
| 167 | + with trace("Customer service", group_id=workflow.info().workflow_id): |
| 168 | + self.input_items.append({"content": input.user_input, "role": "user"}) |
| 169 | + result = await Runner.run(self.current_agent, self.input_items, context=self.context, |
| 170 | + run_config=self.run_config) |
| 171 | + |
| 172 | + for new_item in result.new_items: |
| 173 | + agent_name = new_item.agent.name |
| 174 | + if isinstance(new_item, MessageOutputItem): |
| 175 | + self.chat_history.append(f"{agent_name}: {ItemHelpers.text_message_output(new_item)}") |
| 176 | + elif isinstance(new_item, HandoffOutputItem): |
| 177 | + self.chat_history.append( |
| 178 | + f"Handed off from {new_item.source_agent.name} to {new_item.target_agent.name}" |
| 179 | + ) |
| 180 | + elif isinstance(new_item, ToolCallItem): |
| 181 | + self.chat_history.append(f"{agent_name}: Calling a tool") |
| 182 | + elif isinstance(new_item, ToolCallOutputItem): |
| 183 | + self.chat_history.append(f"{agent_name}: Tool call output: {new_item.output}") |
| 184 | + else: |
| 185 | + self.chat_history.append(f"{agent_name}: Skipping item: {new_item.__class__.__name__}") |
| 186 | + self.input_items = result.to_input_list() |
| 187 | + self.current_agent = result.last_agent |
| 188 | + workflow.set_current_details("\n\n".join(self.chat_history)) |
190 | 189 | return self.chat_history[length:] |
| 190 | + |
| 191 | + @process_user_message.validator |
| 192 | + def validate_process_user_message(self, input: ProcessUserMessageInput) -> None: |
| 193 | + if not input.user_input: |
| 194 | + raise ValueError("User input cannot be empty.") |
| 195 | + if len(input.user_input) > 1000: |
| 196 | + raise ValueError("User input is too long. Please limit to 1000 characters.") |
| 197 | + if input.chat_length != len(self.chat_history): |
| 198 | + raise ValueError("Stale chat history. Please refresh the chat.") |
0 commit comments