Skip to content

Commit 265515e

Browse files
danielmillerpclaude
andcommitted
feat(openai-agents): stream tool-call argument deltas to the UI
The TemporalStreamingModel already parsed the Responses API function_call_arguments deltas but only accumulated them — tool calls reached the UI all-at-once via the hooks layer. Now, as the model generates a tool call, the model layer opens a ToolRequestContent streaming context per call (keyed by output_index) and pushes ToolRequestDelta chunks as the arguments arrive, closing on item-done (mirrors the existing text-delta path). on_tool_start no longer emits a duplicate ToolRequestContent (the model layer streams it live now); on_tool_end still emits the ToolResponseContent result. The returned ModelResponse/output_items/usage assembly is unchanged. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
1 parent e1b31d9 commit 265515e

2 files changed

Lines changed: 79 additions & 32 deletions

File tree

src/agentex/lib/core/temporal/plugins/openai_agents/hooks/hooks.py

Lines changed: 11 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from agents.tool_context import ToolContext
1414

1515
from agentex.types.text_content import TextContent
16-
from agentex.types.task_message_content import ToolRequestContent, ToolResponseContent
16+
from agentex.types.task_message_content import ToolResponseContent
1717
from agentex.lib.core.observability.llm_metrics_hooks import LLMMetricsHooks
1818
from agentex.lib.core.temporal.plugins.openai_agents.hooks.activities import stream_lifecycle_content
1919

@@ -106,44 +106,25 @@ async def on_agent_end(self, context: RunContextWrapper, agent: Agent, output: A
106106

107107
@override
108108
async def on_tool_start(self, context: RunContextWrapper, agent: Agent, tool: Tool) -> None: # noqa: ARG002
109-
"""Stream tool request when a tool starts execution.
109+
"""Called when a tool starts execution.
110110
111-
Extracts the tool_call_id and tool_arguments from the context and streams a
112-
ToolRequestContent message to the UI showing that the tool is about to execute.
111+
The tool request (ToolRequestContent) is now streamed live by the model
112+
layer (TemporalStreamingModel) as the function-call arguments arrive over
113+
the Responses API stream. Emitting it again here would double-render the
114+
tool request in the UI, so this hook no longer streams a ToolRequestContent.
115+
The tool *result* is still streamed by on_tool_end (ToolResponseContent),
116+
which the model stream does not produce.
113117
114118
Args:
115119
context: The run context wrapper (will be a ToolContext with tool_call_id and tool_arguments)
116120
agent: The agent executing the tool
117121
tool: The tool being executed
118122
"""
119-
import json
120-
121123
tool_context = context if isinstance(context, ToolContext) else None
122124
tool_call_id = tool_context.tool_call_id if tool_context else f"call_{id(tool)}"
123-
124-
# Extract tool arguments from context
125-
tool_arguments = {}
126-
if tool_context and hasattr(tool_context, 'tool_arguments'):
127-
try:
128-
# tool_arguments is a JSON string, parse it
129-
tool_arguments = json.loads(tool_context.tool_arguments)
130-
except (json.JSONDecodeError, TypeError):
131-
# If parsing fails, log and use empty dict
132-
logger.warning(f"Failed to parse tool arguments: {tool_context.tool_arguments}")
133-
tool_arguments = {}
134-
135-
await workflow.execute_activity(
136-
stream_lifecycle_content,
137-
args=[
138-
self.task_id,
139-
ToolRequestContent(
140-
author="agent",
141-
tool_call_id=tool_call_id,
142-
name=tool.name,
143-
arguments=tool_arguments,
144-
).model_dump(),
145-
],
146-
start_to_close_timeout=self.timeout,
125+
logger.debug(
126+
f"[TemporalStreamingHooks] Tool '{tool.name}' started (tool_call_id={tool_call_id}); "
127+
"tool request is streamed live by the model layer, not re-emitted here."
147128
)
148129

149130
@override

src/agentex/lib/core/temporal/plugins/openai_agents/models/temporal_streaming_model.py

Lines changed: 68 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,10 @@
6464
from agentex.lib.utils.logging import make_logger
6565
from agentex.lib.core.tracing.tracer import AsyncTracer
6666
from agentex.types.task_message_delta import TextDelta, ReasoningContentDelta, ReasoningSummaryDelta
67+
from agentex.types.tool_request_delta import ToolRequestDelta
6768
from agentex.types.task_message_update import StreamTaskMessageFull, StreamTaskMessageDelta
6869
from agentex.types.task_message_content import TextContent, ReasoningContent
70+
from agentex.types.tool_request_content import ToolRequestContent
6971
from agentex.lib.adk.utils._modules.client import create_async_agentex_client
7072
from agentex.lib.core.temporal.plugins.openai_agents.interceptors.context_interceptor import (
7173
streaming_task_id,
@@ -671,6 +673,7 @@ async def get_response(
671673

672674
# Process events from the Responses API stream
673675
function_calls_in_progress = {} # Track function calls being streamed
676+
tool_call_contexts: dict[int, Any] = {} # Open streaming contexts per function call
674677

675678
async for event in stream:
676679
event_count += 1
@@ -723,14 +726,29 @@ async def get_response(
723726
).__aenter__()
724727
elif item and getattr(item, 'type', None) == 'function_call':
725728
# Track the function call being streamed
729+
call_id = getattr(item, 'call_id', '')
730+
name = getattr(item, 'name', '')
726731
function_calls_in_progress[output_index] = {
727732
'id': getattr(item, 'id', ''),
728-
'call_id': getattr(item, 'call_id', ''),
729-
'name': getattr(item, 'name', ''),
733+
'call_id': call_id,
734+
'name': name,
730735
'arguments': getattr(item, 'arguments', ''),
731736
}
732737
logger.debug(f"[TemporalStreamingModel] Starting function call: {item.name}")
733738

739+
# Open a streaming context so tool-call args stream live to the UI
740+
tool_ctx = await adk.streaming.streaming_task_message_context(
741+
task_id=task_id,
742+
initial_content=ToolRequestContent(
743+
author="agent",
744+
tool_call_id=call_id,
745+
name=name,
746+
arguments={},
747+
),
748+
streaming_mode=self.streaming_mode,
749+
).__aenter__()
750+
tool_call_contexts[output_index] = tool_ctx
751+
734752
elif item and getattr(item, 'type', None) == 'message':
735753
# Track the message being streamed
736754
streaming_context = await adk.streaming.streaming_task_message_context(
@@ -752,6 +770,25 @@ async def get_response(
752770
function_calls_in_progress[output_index]['arguments'] += delta
753771
logger.debug(f"[TemporalStreamingModel] Function call args delta: {delta[:50]}...")
754772

773+
# Stream the args delta live to the UI. The delta event carries
774+
# no call_id/name, so pull them from function_calls_in_progress
775+
# (populated at announce time, keyed by output_index).
776+
ctx = tool_call_contexts.get(output_index)
777+
if ctx is not None and delta:
778+
try:
779+
await ctx.stream_update(StreamTaskMessageDelta(
780+
parent_task_message=ctx.task_message,
781+
delta=ToolRequestDelta(
782+
type="tool_request",
783+
tool_call_id=function_calls_in_progress[output_index]['call_id'],
784+
name=function_calls_in_progress[output_index]['name'],
785+
arguments_delta=delta,
786+
),
787+
type="delta",
788+
))
789+
except Exception as e:
790+
logger.warning(f"Failed to stream tool-call args delta: {e}")
791+
755792
elif isinstance(event, ResponseFunctionCallArgumentsDoneEvent):
756793
# Function call arguments complete
757794
output_index = getattr(event, 'output_index', 0)
@@ -874,6 +911,14 @@ async def get_response(
874911
)
875912
output_items.append(tool_call)
876913

914+
# Close + pop the live-streaming context for this tool call
915+
ctx = tool_call_contexts.pop(output_index, None)
916+
if ctx is not None:
917+
try:
918+
await ctx.close()
919+
except Exception as e:
920+
logger.warning(f"Failed to close tool-call stream context: {e}")
921+
877922
elif isinstance(event, ResponseReasoningSummaryPartAddedEvent):
878923
# New reasoning part/summary started - reset accumulator
879924
part = getattr(event, 'part', None)
@@ -907,6 +952,16 @@ async def get_response(
907952
await streaming_context.close()
908953
streaming_context = None
909954

955+
# Close any tool-call contexts still open (e.g. stream ended without
956+
# a per-item done event). A partial / invalid-JSON args close can
957+
# raise, so guard each one so it never crashes the activity.
958+
for ctx in tool_call_contexts.values():
959+
try:
960+
await ctx.close()
961+
except Exception as e:
962+
logger.warning(f"Failed to close tool-call stream context: {e}")
963+
tool_call_contexts.clear()
964+
910965
# Build the response from output items collected during streaming
911966
# Create output from the items we collected
912967
response_output = []
@@ -1061,6 +1116,17 @@ async def get_response(
10611116

10621117
except Exception as e:
10631118
logger.error(f"Error using Responses API: {e}")
1119+
# Close any tool-call streaming contexts still open so the error
1120+
# path doesn't leak open contexts. A partial / invalid-JSON args
1121+
# close can raise, so guard each one. tool_call_contexts may be
1122+
# unbound if the error fired before the stream loop started.
1123+
for ctx in locals().get("tool_call_contexts", {}).values():
1124+
try:
1125+
await ctx.close()
1126+
except Exception as close_err:
1127+
logger.warning(f"Failed to close tool-call stream context: {close_err}")
1128+
if "tool_call_contexts" in locals():
1129+
tool_call_contexts.clear()
10641130
# LLMMetricsHooks.on_llm_end doesn't fire on error, so emit the
10651131
# failure counter here. Best-effort so the typed LLM exception
10661132
# always propagates intact for retry / circuit-breaker logic.

0 commit comments

Comments
 (0)