Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
150 changes: 138 additions & 12 deletions python/packages/ag-ui/agent_framework_ag_ui/_run_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,12 @@
from ag_ui.core import (
BaseEvent,
CustomEvent,
ReasoningEncryptedValueEvent,
ReasoningEndEvent,
ReasoningMessageContentEvent,
ReasoningMessageEndEvent,
ReasoningMessageStartEvent,
ReasoningStartEvent,
RunFinishedEvent,
StateSnapshotEvent,
TextMessageContentEvent,
Expand Down Expand Up @@ -224,27 +230,28 @@ def _emit_tool_call(
return events


def _emit_tool_result(
content: Content,
def _emit_tool_result_common(
call_id: str,
raw_result: Any,
flow: FlowState,
predictive_handler: PredictiveStateHandler | None = None,
) -> list[BaseEvent]:
"""Emit ToolCallResult events for function_result content."""
events: list[BaseEvent] = []
"""Shared helper for emitting ToolCallEnd + ToolCallResult events and performing FlowState cleanup.

if not content.call_id:
return events
Both ``_emit_tool_result`` (standard function results) and ``_emit_mcp_tool_result``
(MCP server tool results) delegate to this function.
"""
events: list[BaseEvent] = []

events.append(ToolCallEndEvent(tool_call_id=content.call_id))
flow.tool_calls_ended.add(content.call_id)
events.append(ToolCallEndEvent(tool_call_id=call_id))
flow.tool_calls_ended.add(call_id)

raw_result = content.result if content.result is not None else ""
result_content = raw_result if isinstance(raw_result, str) else json.dumps(make_json_safe(raw_result))
message_id = generate_event_id()
events.append(
ToolCallResultEvent(
message_id=message_id,
tool_call_id=content.call_id,
tool_call_id=call_id,
content=result_content,
role="tool",
)
Expand All @@ -254,7 +261,7 @@ def _emit_tool_result(
{
"id": message_id,
"role": "tool",
"toolCallId": content.call_id,
"toolCallId": call_id,
"content": result_content,
}
)
Expand All @@ -268,14 +275,26 @@ def _emit_tool_result(
flow.tool_call_name = None

if flow.message_id:
logger.debug("Closing text message (issue #3568 fix): message_id=%s", flow.message_id)
logger.debug("Closing text message: message_id=%s", flow.message_id)
events.append(TextMessageEndEvent(message_id=flow.message_id))
flow.message_id = None
flow.accumulated_text = ""

return events


def _emit_tool_result(
content: Content,
flow: FlowState,
predictive_handler: PredictiveStateHandler | None = None,
) -> list[BaseEvent]:
"""Emit ToolCallResult events for function_result content."""
if not content.call_id:
return []
raw_result = content.result if content.result is not None else ""
return _emit_tool_result_common(content.call_id, raw_result, flow, predictive_handler)


def _emit_approval_request(
content: Content,
flow: FlowState,
Expand Down Expand Up @@ -381,6 +400,107 @@ def _emit_oauth_consent(content: Content) -> list[BaseEvent]:
)


def _emit_mcp_tool_call(content: Content, flow: FlowState) -> list[BaseEvent]:
"""Emit ToolCall start/args events for MCP server tool call content.

MCP tool calls arrive as complete items (not streamed deltas), so we emit a
``ToolCallStartEvent`` (and, when arguments are present, a ``ToolCallArgsEvent``)
immediately. This maps MCP-specific fields (tool_name, server_name) to the
same AG-UI ToolCall* events used by regular function calls, making MCP tool
execution visible to AG-UI consumers. Completion/end events are handled
separately by ``_emit_mcp_tool_result``.
"""
events: list[BaseEvent] = []

tool_call_id = content.call_id or generate_event_id()
tool_name = content.tool_name or "mcp_tool"

display_name = tool_name

events.append(
ToolCallStartEvent(
tool_call_id=tool_call_id,
tool_call_name=display_name,
parent_message_id=flow.message_id,
)
)

# Serialize arguments
args_str = ""
if content.arguments:
args_str = (
content.arguments if isinstance(content.arguments, str) else json.dumps(make_json_safe(content.arguments))
)
events.append(ToolCallArgsEvent(tool_call_id=tool_call_id, delta=args_str))

# Track in flow state for MESSAGES_SNAPSHOT
tool_entry = {
"id": tool_call_id,
"type": "function",
"function": {"name": display_name, "arguments": args_str},
}
flow.pending_tool_calls.append(tool_entry)
flow.tool_calls_by_id[tool_call_id] = tool_entry

return events


def _emit_mcp_tool_result(
content: Content, flow: FlowState, predictive_handler: PredictiveStateHandler | None = None
) -> list[BaseEvent]:
"""Emit ToolCallResult events for MCP server tool result content.

Delegates to the shared _emit_tool_result_common helper using content.output
(the MCP-specific result field) instead of content.result.
"""
if not content.call_id:
logger.warning("MCP tool result content missing call_id, skipping")
return []
raw_output = content.output if content.output is not None else ""
return _emit_tool_result_common(content.call_id, raw_output, flow, predictive_handler)


def _emit_text_reasoning(content: Content) -> list[BaseEvent]:
"""Emit AG-UI reasoning events for text_reasoning content.

Uses the protocol-defined reasoning event types so that AG-UI consumers
such as CopilotKit can render reasoning natively.

Only ``content.text`` is used for the visible reasoning message. If
``content.protected_data`` is present it is emitted as a
``ReasoningEncryptedValueEvent`` so that consumers can persist encrypted
reasoning for state continuity without conflating it with display text.
"""
text = content.text or ""
if not text and content.protected_data is None:
return []

message_id = content.id or generate_event_id()

events: list[BaseEvent] = [
ReasoningStartEvent(message_id=message_id),
ReasoningMessageStartEvent(message_id=message_id, role="assistant"),
]

if text:
events.append(ReasoningMessageContentEvent(message_id=message_id, delta=text))

events.append(ReasoningMessageEndEvent(message_id=message_id))

if content.protected_data is not None:
events.append(
ReasoningEncryptedValueEvent(
subtype="message",
entity_id=message_id,
encrypted_value=content.protected_data,
)
)

events.append(ReasoningEndEvent(message_id=message_id))

return events


def _emit_content(
content: Any,
flow: FlowState,
Expand All @@ -402,5 +522,11 @@ def _emit_content(
return _emit_usage(content)
if content_type == "oauth_consent_request":
return _emit_oauth_consent(content)
if content_type == "mcp_server_tool_call":
return _emit_mcp_tool_call(content, flow)
if content_type == "mcp_server_tool_result":
return _emit_mcp_tool_result(content, flow, predictive_handler)
if content_type == "text_reasoning":
return _emit_text_reasoning(content)
logger.debug("Skipping unsupported content type in AG-UI emitter: %s", content_type)
return []
131 changes: 131 additions & 0 deletions python/packages/ag-ui/tests/ag_ui/test_http_round_trip.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,3 +213,134 @@ def test_sse_response_headers() -> None:

assert response.headers["content-type"] == "text/event-stream; charset=utf-8"
assert response.headers.get("cache-control") == "no-cache"


# ── MCP tool call SSE round-trip ──


def test_mcp_tool_call_sse_round_trip() -> None:
"""MCP tool call + result events survive SSE encoding/parsing round-trip."""
app = _build_app_with_agent(
[
AgentResponseUpdate(
contents=[
Content.from_mcp_server_tool_call(
call_id="mcp-1",
tool_name="search",
server_name="brave",
arguments={"query": "weather"},
)
],
role="assistant",
),
AgentResponseUpdate(
contents=[
Content.from_mcp_server_tool_result(
call_id="mcp-1",
output={"results": ["sunny"]},
)
],
role="assistant",
),
AgentResponseUpdate(
contents=[Content.from_text(text="It's sunny!")],
role="assistant",
),
]
)
client = TestClient(app)
response = client.post("/", json=USER_PAYLOAD)

assert response.status_code == 200
stream = parse_sse_to_event_stream(response.content)
stream.assert_bookends()
stream.assert_tool_calls_balanced()
stream.assert_text_messages_balanced()
stream.assert_no_run_error()

# Verify MCP tool call details survive SSE encoding
start = stream.first("TOOL_CALL_START")
assert start.tool_call_name == "search"
assert start.tool_call_id == "mcp-1"

# Verify the result came through
result = stream.first("TOOL_CALL_RESULT")
assert "sunny" in result.content


# ── Text reasoning SSE round-trip ──


def test_text_reasoning_sse_round_trip() -> None:
"""Text reasoning events survive SSE encoding/parsing round-trip."""
app = _build_app_with_agent(
[
AgentResponseUpdate(
contents=[
Content.from_text_reasoning(
id="reason-1",
text="The user wants weather info, I should use a tool.",
)
],
role="assistant",
),
AgentResponseUpdate(
contents=[Content.from_text(text="Let me check the weather.")],
role="assistant",
),
]
)
client = TestClient(app)
response = client.post("/", json=USER_PAYLOAD)

assert response.status_code == 200
stream = parse_sse_to_event_stream(response.content)
stream.assert_bookends()
stream.assert_text_messages_balanced()
stream.assert_no_run_error()
stream.assert_has_type("REASONING_START")
stream.assert_has_type("REASONING_MESSAGE_CONTENT")
stream.assert_has_type("REASONING_END")

# Verify reasoning content survives SSE encoding
raw_events = parse_sse_response(response.content)
reasoning_content = [e for e in raw_events if e["type"] == "REASONING_MESSAGE_CONTENT"]
assert len(reasoning_content) == 1
assert "weather" in reasoning_content[0]["delta"]


def test_text_reasoning_with_encrypted_value_sse_round_trip() -> None:
"""Reasoning with protected_data emits ReasoningEncryptedValue through SSE."""
app = _build_app_with_agent(
[
AgentResponseUpdate(
contents=[
Content.from_text_reasoning(
id="reason-enc",
text="visible reasoning",
protected_data="encrypted-payload-abc123",
)
],
role="assistant",
),
AgentResponseUpdate(
contents=[Content.from_text(text="Done.")],
role="assistant",
),
]
)
client = TestClient(app)
response = client.post("/", json=USER_PAYLOAD)

assert response.status_code == 200
stream = parse_sse_to_event_stream(response.content)
stream.assert_bookends()
stream.assert_no_run_error()
stream.assert_has_type("REASONING_ENCRYPTED_VALUE")

raw_events = parse_sse_response(response.content)
encrypted = [e for e in raw_events if e["type"] == "REASONING_ENCRYPTED_VALUE"]
assert len(encrypted) == 1
assert encrypted[0]["encryptedValue"] == "encrypted-payload-abc123"
assert encrypted[0]["entityId"] == "reason-enc"
assert encrypted[0]["subtype"] == "message"
Loading
Loading