Skip to content

Commit 38e1192

Browse files
committed
turn on workflow sandboxing
1 parent 8e5ea40 commit 38e1192

12 files changed

Lines changed: 84 additions & 78 deletions

openai_agents/README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
# POC of Temporal integration with OpenAI Agents SDK
1+
# Pre-release of Temporal integration with OpenAI Agents SDK
22

3-
Temporal equivalent of sample code from OpenAI Agents SDK:
3+
These sampoles are adatapted from the code in the OpenAI Agents SDK:
44

55
https://github.com/openai/openai-agents-python/tree/main/examples
66

openai_agents/customer_service_client.py renamed to openai_agents/run_customer_service_client.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,15 @@
22
import asyncio
33
from ftplib import print_line
44

5+
from temporalio import workflow
56
from temporalio.client import Client, WorkflowQueryRejectedError, WorkflowUpdateFailedError
67
from temporalio.common import WorkflowIDReusePolicy, QueryRejectCondition
78
from temporalio.service import RPCError, RPCStatusCode
89

9-
from openai_agents.workflows.customer_service_workflow import CustomerServiceWorkflow, ProcessUserMessageInput
10+
with workflow.unsafe.imports_passed_through():
11+
from temporalio.contrib.openai_agents.open_ai_data_converter import open_ai_data_converter
12+
13+
from openai_agents.workflows.customer_service_workflow import CustomerServiceWorkflow, ProcessUserMessageInput
1014

1115

1216
async def main():
@@ -15,7 +19,10 @@ async def main():
1519
args = parser.parse_args()
1620

1721
# Create client connected to server at the given address
18-
client = await Client.connect("localhost:7233")
22+
client = await Client.connect(
23+
"localhost:7233",
24+
data_converter=open_ai_data_converter,
25+
)
1926

2027
handle = client.get_workflow_handle(args.conversation_id)
2128

openai_agents/run_hello_world_workflow.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
set_open_ai_agent_temporal_overrides,
1010
)
1111

12-
1312
async def main():
1413
# Create client connected to server at the given address
1514
client = await Client.connect("localhost:7233")

openai_agents/run_research_workflow.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,12 @@
33
from temporalio.client import Client
44
from temporalio.common import WorkflowIDReusePolicy
55

6-
# from openai_agents.adapters.open_ai_converter import open_ai_data_converter
76
from openai_agents.workflows.research_bot_workflow import ResearchWorkflow
87

9-
108
async def main():
119
# Create client connected to server at the given address
1210
client = await Client.connect(
1311
"localhost:7233",
14-
# data_converter=open_ai_data_converter
1512
)
1613

1714
# Execute a workflow

openai_agents/run_tools_workflow.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,11 @@
33
from temporalio.client import Client
44
from temporalio.common import WorkflowIDReusePolicy
55

6-
# from openai_agents.adapters.open_ai_converter import open_ai_data_converter
76
from openai_agents.workflows.tools_workflow import ToolsWorkflow
87

9-
108
async def main():
119
# Create client connected to server at the given address
12-
client = await Client.connect(
13-
"localhost:7233",
14-
#data_converter=open_ai_data_converter
15-
)
10+
client = await Client.connect("localhost:7233")
1611

1712
# Execute a workflow
1813
result = await client.execute_workflow(ToolsWorkflow.run, "What is the weather in Tokio?", id="tools-workflow",

openai_agents/run_worker.py

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,21 @@
33
import asyncio
44
import concurrent.futures
55

6+
from temporalio import workflow
7+
68
from temporalio.client import Client
79
from temporalio.worker import Worker
810
from temporalio.contrib.openai_agents.invoke_model_activity import ModelActivity
11+
from temporalio.contrib.openai_agents.open_ai_data_converter import open_ai_data_converter
912

10-
from openai_agents.workflows.hello_world_workflow import HelloWorldAgent
11-
from openai_agents.workflows.tools_workflow import ToolsWorkflow
12-
from openai_agents.workflows.research_bot_workflow import ResearchWorkflow
13-
from openai_agents.workflows.customer_service_workflow import CustomerServiceWorkflow
14-
from openai_agents.workflows.agents_as_tools_workflow import AgentsAsToolsWorkflow
13+
with workflow.unsafe.imports_passed_through():
14+
from openai_agents.workflows.hello_world_workflow import HelloWorldAgent
15+
from openai_agents.workflows.tools_workflow import ToolsWorkflow
16+
from openai_agents.workflows.research_bot_workflow import ResearchWorkflow
17+
from openai_agents.workflows.customer_service_workflow import CustomerServiceWorkflow
18+
from openai_agents.workflows.agents_as_tools_workflow import AgentsAsToolsWorkflow
1519

16-
from openai_agents.workflows.get_weather_activity import get_weather
20+
from openai_agents.workflows.get_weather_activity import get_weather
1721

1822

1923
from temporalio.contrib.openai_agents.temporal_openai_agents import (
@@ -24,21 +28,27 @@
2428
async def main():
2529
with set_open_ai_agent_temporal_overrides():
2630
# Create client connected to server at the given address
27-
client = await Client.connect("localhost:7233")
31+
client = await Client.connect(
32+
"localhost:7233",
33+
data_converter=open_ai_data_converter,
34+
)
2835

2936
model_activity = ModelActivity(model_provider=None)
30-
with concurrent.futures.ThreadPoolExecutor(max_workers=10) as activity_executor:
37+
with concurrent.futures.ThreadPoolExecutor(max_workers=20) as activity_executor:
3138
worker = Worker(
3239
client,
3340
task_queue="my-task-queue",
34-
workflows=[HelloWorldAgent,
41+
workflows=[
42+
HelloWorldAgent,
3543
ToolsWorkflow,
3644
ResearchWorkflow,
3745
CustomerServiceWorkflow,
3846
AgentsAsToolsWorkflow,
39-
],
40-
activities=[model_activity.invoke_model_activity, get_weather],
41-
# get_weather
47+
],
48+
activities=[
49+
model_activity.invoke_model_activity,
50+
get_weather,
51+
],
4252
activity_executor=activity_executor,
4353
)
4454
await worker.run()

openai_agents/workflows/agents_as_tools_workflow.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def synthesizer_agent() -> Agent:
6161
)
6262

6363

64-
@workflow.defn(sandboxed=False)
64+
@workflow.defn
6565
class AgentsAsToolsWorkflow:
6666
@workflow.run
6767
async def run(self, msg: str) -> str:
@@ -78,7 +78,7 @@ async def run(self, msg: str) -> str:
7878
if isinstance(item, MessageOutputItem):
7979
text = ItemHelpers.text_message_output(item)
8080
if text:
81-
print(f" - Translation step: {text}")
81+
workflow.logger.info(f" - Translation step: {text}")
8282

8383
synthesizer_result = await Runner.run(
8484
synthesizer, orchestrator_result.to_input_list(), run_config=config

openai_agents/workflows/customer_service_workflow.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# from __future__ import annotations as _annotations
1+
from __future__ import annotations as _annotations
22

33
from temporalio import workflow
44

@@ -139,7 +139,7 @@ class ProcessUserMessageInput(BaseModel):
139139
chat_length: int
140140

141141

142-
@workflow.defn(sandboxed=False)
142+
@workflow.defn
143143
class CustomerServiceWorkflow:
144144

145145
def __init__(self, input_items: list[TResponseInputItem] = None):

openai_agents/workflows/hello_world_workflow.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,10 @@
11
from temporalio import workflow
22

3+
# Import agent Agent and Runner
4+
with workflow.unsafe.imports_passed_through():
5+
from agents import Agent, Runner
36

4-
# Import our activity, passing it through the sandbox
5-
# with workflow.unsafe.imports_passed_through():
6-
from agents import Agent, Runner, RunConfig
7-
8-
9-
@workflow.defn(sandboxed=False)
7+
@workflow.defn
108
class HelloWorldAgent:
119
@workflow.run
1210
async def run(self, prompt: str) -> str:

openai_agents/workflows/research_agents/research_manager.py

Lines changed: 39 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -2,21 +2,22 @@
22

33
import asyncio
44

5-
# with workflow.unsafe.imports_passed_through():
6-
from rich.console import Console
5+
from temporalio import workflow
76

8-
from agents import Runner, custom_span, gen_trace_id, trace, RunConfig
9-
10-
from openai_agents.workflows.research_agents.planner_agent import WebSearchPlan, WebSearchItem, new_planner_agent
11-
from openai_agents.workflows.research_agents.printer import Printer
12-
from openai_agents.workflows.research_agents.search_agent import new_search_agent
13-
from openai_agents.workflows.research_agents.writer_agent import ReportData, new_writer_agent
7+
with workflow.unsafe.imports_passed_through():
8+
# TODO: Restore progress updates
9+
# from rich.console import Console
10+
from agents import Runner, custom_span, gen_trace_id, trace, RunConfig
11+
from openai_agents.workflows.research_agents.planner_agent import WebSearchPlan, WebSearchItem, new_planner_agent
12+
# from openai_agents.workflows.research_agents.printer import Printer
13+
from openai_agents.workflows.research_agents.search_agent import new_search_agent
14+
from openai_agents.workflows.research_agents.writer_agent import ReportData, new_writer_agent
1415

1516

1617
class ResearchManager:
1718
def __init__(self):
18-
self.console = Console()
19-
self.printer = Printer(self.console)
19+
# self.console = Console()
20+
# self.printer = Printer(self.console)
2021
self.run_config = RunConfig()
2122
self.search_agent = new_search_agent()
2223
self.planner_agent = new_planner_agent()
@@ -26,27 +27,27 @@ def __init__(self):
2627
async def run(self, query: str) -> str:
2728
trace_id = gen_trace_id()
2829
with trace("Research trace", trace_id=trace_id):
29-
self.printer.update_item(
30-
"trace_id",
31-
f"View trace: https://platform.openai.com/traces/trace?trace_id={trace_id}",
32-
is_done=True,
33-
hide_checkmark=True,
34-
)
35-
36-
self.printer.update_item(
37-
"starting",
38-
"Starting research...",
39-
is_done=True,
40-
hide_checkmark=True,
41-
)
30+
# self.printer.update_item(
31+
# "trace_id",
32+
# f"View trace: https://platform.openai.com/traces/trace?trace_id={trace_id}",
33+
# is_done=True,
34+
# hide_checkmark=True,
35+
# )
36+
37+
# self.printer.update_item(
38+
# "starting",
39+
# "Starting research...",
40+
# is_done=True,
41+
# hide_checkmark=True,
42+
# )
4243
search_plan = await self._plan_searches(query)
4344
search_results = await self._perform_searches(search_plan)
4445
report = await self._write_report(query, search_results)
4546

4647
final_report = f"Report summary\n\n{report.short_summary}"
47-
self.printer.update_item("final_report", final_report, is_done=True)
48+
# self.printer.update_item("final_report", final_report, is_done=True)
4849

49-
self.printer.end()
50+
# self.printer.end()
5051

5152
print("\n\n=====REPORT=====\n\n")
5253
print(f"Report: {report.markdown_report}")
@@ -57,23 +58,23 @@ async def run(self, query: str) -> str:
5758

5859

5960
async def _plan_searches(self, query: str) -> WebSearchPlan:
60-
self.printer.update_item("planning", "Planning searches...")
61+
# self.printer.update_item("planning", "Planning searches...")
6162
result = await Runner.run(
6263
self.planner_agent,
6364
f"Query: {query}",
6465
run_config=self.run_config,
6566
)
66-
self.printer.update_item(
67-
"planning",
68-
f"Will perform {len(result.final_output.searches)} searches",
69-
is_done=True,
70-
)
67+
# self.printer.update_item(
68+
# "planning",
69+
# f"Will perform {len(result.final_output.searches)} searches",
70+
# is_done=True,
71+
# )
7172
return result.final_output_as(WebSearchPlan)
7273

7374

7475
async def _perform_searches(self, search_plan: WebSearchPlan) -> list[str]:
7576
with custom_span("Search the web"):
76-
self.printer.update_item("searching", "Searching...")
77+
# self.printer.update_item("searching", "Searching...")
7778
num_completed = 0
7879
tasks = [asyncio.create_task(self._search(item)) for item in search_plan.searches]
7980
results = []
@@ -82,10 +83,10 @@ async def _perform_searches(self, search_plan: WebSearchPlan) -> list[str]:
8283
if result is not None:
8384
results.append(result)
8485
num_completed += 1
85-
self.printer.update_item(
86-
"searching", f"Searching... {num_completed}/{len(tasks)} completed"
87-
)
88-
self.printer.mark_item_done("searching")
86+
# self.printer.update_item(
87+
# "searching", f"Searching... {num_completed}/{len(tasks)} completed"
88+
# )
89+
# self.printer.mark_item_done("searching")
8990
return results
9091

9192

@@ -103,7 +104,7 @@ async def _search(self, item: WebSearchItem) -> str | None:
103104

104105

105106
async def _write_report(self, query: str, search_results: list[str]) -> ReportData:
106-
self.printer.update_item("writing", "Thinking about report...")
107+
# self.printer.update_item("writing", "Thinking about report...")
107108
input = f"Original query: {query}\nSummarized search results: {search_results}"
108109
result = await Runner.run(
109110
self.writer_agent,
@@ -128,5 +129,5 @@ async def _write_report(self, query: str, search_results: list[str]) -> ReportDa
128129
# next_message += 1
129130
# last_update = time.time()
130131

131-
self.printer.mark_item_done("writing")
132+
# self.printer.mark_item_done("writing")
132133
return result.final_output_as(ReportData)

0 commit comments

Comments
 (0)