Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 42 additions & 17 deletions python/packages/ag-ui/agent_framework_ag_ui/_agent_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
TextMessageStartEvent,
ToolCallArgsEvent,
ToolCallEndEvent,
ToolCallResultEvent,
ToolCallStartEvent,
)
from agent_framework import (
Expand Down Expand Up @@ -369,6 +370,24 @@ def _handle_step_based_approval(messages: list[Any]) -> list[BaseEvent]:
return events


def _make_approval_tool_result_events(resolved_approval_results: list[Content]) -> list[ToolCallResultEvent]:
    """Turn resolved approval results into TOOL_CALL_RESULT events.

    Args:
        resolved_approval_results: Content items produced while resolving
            approval responses; each is read for ``call_id`` and ``result``.

    Returns:
        One ``ToolCallResultEvent`` per item that has a truthy ``call_id``,
        in input order. Items without a ``call_id`` are skipped. A ``None``
        result becomes an empty string, a string result is passed through
        unchanged, and any other result is JSON-encoded after being passed
        through ``make_json_safe``.
    """
    tool_result_events: list[ToolCallResultEvent] = []
    for item in resolved_approval_results:
        # Without a call id the event cannot be tied to a tool call; skip it.
        if not item.call_id:
            continue

        payload = "" if item.result is None else item.result
        if not isinstance(payload, str):
            # Event content must be a string, so serialize structured results.
            payload = json.dumps(make_json_safe(payload))

        tool_result_events.append(
            ToolCallResultEvent(
                message_id=generate_event_id(),
                tool_call_id=item.call_id,
                content=payload,
                role="tool",
            )
        )
    return tool_result_events


def _evict_oldest_approvals(registry: dict[str, str], max_size: int = 10_000) -> None:
"""Evict the oldest entries from the pending-approvals registry (LRU).

Expand All @@ -391,7 +410,7 @@ async def _resolve_approval_responses(
run_kwargs: dict[str, Any],
pending_approvals: dict[str, str] | None = None,
thread_id: str = "",
) -> None:
) -> list[Content]:
"""Execute approved function calls and replace approval content with results.

This modifies the messages list in place, replacing function_approval_response
Expand All @@ -407,10 +426,16 @@ async def _resolve_approval_responses(
When provided, every approval response is validated against this
registry to prevent bypass, function name spoofing, and replay.
thread_id: The conversation thread ID used to scope registry keys.

Returns:
List of approved function_result Content objects only (empty if no
approvals). Rejection results are written into the message history
but are *not* included in the return value because they should not
be emitted as TOOL_CALL_RESULT events.
"""
fcc_todo = _collect_approval_responses(messages)
if not fcc_todo:
return
return []

approved_responses = [resp for resp in fcc_todo.values() if resp.approved]
rejected_responses = [resp for resp in fcc_todo.values() if not resp.approved]
Expand Down Expand Up @@ -493,38 +518,32 @@ async def _resolve_approval_responses(
logger.exception("Failed to execute approved tool calls; injecting error results: %s", e)
approved_function_results = []

# Build normalized results for approved responses
normalized_results: list[Content] = []
# Build results for approved responses (used for TOOL_CALL_RESULT event emission)
approved_results: list[Content] = []
for idx, approval in enumerate(approved_responses):
if (
idx < len(approved_function_results)
and getattr(approved_function_results[idx], "type", None) == "function_result"
):
normalized_results.append(approved_function_results[idx])
approved_results.append(approved_function_results[idx])
continue
# Get call_id from function_call if present, otherwise use approval.id
func_call = approval.function_call
call_id = (func_call.call_id if func_call else None) or approval.id or ""
normalized_results.append(
approved_results.append(
Content.from_function_result(call_id=call_id, result="Error: Tool call invocation failed.")
)

# Build rejection results
for rejection in rejected_responses:
func_call = rejection.function_call
call_id = (func_call.call_id if func_call else None) or rejection.id or ""
normalized_results.append(
Content.from_function_result(call_id=call_id, result="Error: Tool call invocation was rejected by user.")
)

_replace_approval_contents_with_results(messages, fcc_todo, normalized_results) # type: ignore
_replace_approval_contents_with_results(messages, fcc_todo, approved_results) # type: ignore

# Post-process: Convert user messages with function_result content to proper tool messages.
# After _replace_approval_contents_with_results, approved tool calls have their results
# placed in user messages. OpenAI requires tool results to be in role="tool" messages.
# This transformation ensures the message history is valid for the LLM provider.
_convert_approval_results_to_tool_messages(messages)

return approved_results


def _convert_approval_results_to_tool_messages(messages: list[Message]) -> None:
"""Convert function_result content in user messages to proper tool messages.
Expand Down Expand Up @@ -787,7 +806,9 @@ async def run_agent_stream(
# Resolve approval responses (execute approved tools, replace approvals with results)
# This must happen before running the agent so it sees the tool results
tools_for_execution = tools if tools is not None else server_tools
await _resolve_approval_responses(messages, tools_for_execution, agent, run_kwargs, pending_approvals, thread_id)
resolved_approval_results = await _resolve_approval_responses(
messages, tools_for_execution, agent, run_kwargs, pending_approvals, thread_id
)

# Defense-in-depth: replace approval payloads in snapshot with actual tool results
# so CopilotKit does not re-send stale approval content on subsequent turns.
Expand Down Expand Up @@ -851,6 +872,9 @@ async def run_agent_stream(
yield StateSnapshotEvent(snapshot=flow.current_state)
run_started_emitted = True

for event in _make_approval_tool_result_events(resolved_approval_results):
yield event

# Feature #4: Detect tool-only messages (no text content)
# Emit TextMessageStartEvent to create message context for tool calls
if not flow.message_id and _has_only_tool_calls(update.contents):
Expand Down Expand Up @@ -905,7 +929,8 @@ async def run_agent_stream(
if state_schema and flow.current_state:
yield StateSnapshotEvent(snapshot=flow.current_state)

# Process structured output if response_format is set
for event in _make_approval_tool_result_events(resolved_approval_results):
yield event
if response_format is not None and all_updates:
from agent_framework import AgentResponse
from pydantic import BaseModel
Expand Down
Loading
Loading