Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion python/packages/ag-ui/agent_framework_ag_ui/_agent_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -684,6 +684,10 @@ def _build_messages_snapshot(
}
)

# Add reasoning messages so frontends that reconcile state from
# MESSAGES_SNAPSHOT retain reasoning content after streaming ends.
all_messages.extend(flow.reasoning_messages)

return MessagesSnapshotEvent(messages=all_messages) # type: ignore[arg-type]


Expand Down Expand Up @@ -1061,7 +1065,9 @@ async def run_agent_stream(

# Emit MessagesSnapshotEvent if we have tool calls or results
# Feature #5: Suppress intermediate snapshots for predictive tools without confirmation
should_emit_snapshot = flow.pending_tool_calls or flow.tool_results or flow.accumulated_text
should_emit_snapshot = (
flow.pending_tool_calls or flow.tool_results or flow.accumulated_text or flow.reasoning_messages
)
if should_emit_snapshot:
# Check if we should suppress for predictive tool
last_tool_name = None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -604,6 +604,10 @@ def _filter_modified_args(
# Handle standard tool result messages early (role="tool") to preserve provider invariants
# This path maps AG‑UI tool messages to function_result content with the correct tool_call_id
role_str = normalize_agui_role(msg.get("role", "user"))
if role_str == "reasoning":
# Reasoning messages are UI-only state carried in MESSAGES_SNAPSHOT.
# They should not be forwarded to the LLM provider.
continue
if role_str == "tool":
# Prefer explicit tool_call_id fields; fall back to backend fields only if necessary
tool_call_id = msg.get("tool_call_id") or msg.get("toolCallId")
Expand Down Expand Up @@ -1020,6 +1024,11 @@ def agui_messages_to_snapshot_format(messages: list[dict[str, Any]]) -> list[dic
elif "toolCallId" not in normalized_msg:
normalized_msg["toolCallId"] = ""

# Normalize encrypted_value to encryptedValue for reasoning messages
if normalized_msg.get("role") == "reasoning" and "encrypted_value" in normalized_msg:
normalized_msg["encryptedValue"] = normalized_msg["encrypted_value"]
del normalized_msg["encrypted_value"]

result.append(normalized_msg)

return result
40 changes: 38 additions & 2 deletions python/packages/ag-ui/agent_framework_ag_ui/_run_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,8 @@ class FlowState:
tool_results: list[dict[str, Any]] = field(default_factory=list) # pyright: ignore[reportUnknownVariableType]
tool_calls_ended: set[str] = field(default_factory=set) # pyright: ignore[reportUnknownVariableType]
interrupts: list[dict[str, Any]] = field(default_factory=list) # pyright: ignore[reportUnknownVariableType]
reasoning_messages: list[dict[str, Any]] = field(default_factory=list) # pyright: ignore[reportUnknownVariableType]
accumulated_reasoning: dict[str, str] = field(default_factory=dict) # pyright: ignore[reportUnknownVariableType]

def get_tool_name(self, call_id: str | None) -> str | None:
"""Get tool name by call ID."""
Expand Down Expand Up @@ -460,7 +462,7 @@ def _emit_mcp_tool_result(
return _emit_tool_result_common(content.call_id, raw_output, flow, predictive_handler)


def _emit_text_reasoning(content: Content) -> list[BaseEvent]:
def _emit_text_reasoning(content: Content, flow: FlowState | None = None) -> list[BaseEvent]:
"""Emit AG-UI reasoning events for text_reasoning content.

Uses the protocol-defined reasoning event types so that AG-UI consumers
Expand All @@ -470,6 +472,10 @@ def _emit_text_reasoning(content: Content) -> list[BaseEvent]:
``content.protected_data`` is present it is emitted as a
``ReasoningEncryptedValueEvent`` so that consumers can persist encrypted
reasoning for state continuity without conflating it with display text.

When *flow* is provided the reasoning message is persisted into
``flow.reasoning_messages`` so that ``_build_messages_snapshot`` can
include it in the final ``MESSAGES_SNAPSHOT``.
"""
text = content.text or ""
if not text and content.protected_data is None:
Expand Down Expand Up @@ -498,6 +504,36 @@ def _emit_text_reasoning(content: Content) -> list[BaseEvent]:

events.append(ReasoningEndEvent(message_id=message_id))

# Persist reasoning into flow state for MESSAGES_SNAPSHOT.
# Accumulate reasoning text per message_id, similar to flow.accumulated_text,
# so that incremental deltas build the full reasoning string.
if flow is not None:
if text:
previous_text = flow.accumulated_reasoning.get(message_id, "")
flow.accumulated_reasoning[message_id] = previous_text + text
full_text = flow.accumulated_reasoning.get(message_id, text or "")

# Update existing reasoning entry for this message_id if present; otherwise append a new one.
existing_entry: dict[str, Any] | None = None
for entry in flow.reasoning_messages:
if isinstance(entry, dict) and entry.get("id") == message_id:
existing_entry = entry
break

if existing_entry is None:
reasoning_entry: dict[str, Any] = {
"id": message_id,
"role": "reasoning",
"content": full_text,
}
if content.protected_data is not None:
reasoning_entry["encryptedValue"] = content.protected_data
flow.reasoning_messages.append(reasoning_entry)
else:
existing_entry["content"] = full_text
if content.protected_data is not None:
existing_entry["encryptedValue"] = content.protected_data

return events


Expand Down Expand Up @@ -527,6 +563,6 @@ def _emit_content(
if content_type == "mcp_server_tool_result":
return _emit_mcp_tool_result(content, flow, predictive_handler)
if content_type == "text_reasoning":
return _emit_text_reasoning(content)
return _emit_text_reasoning(content, flow)
logger.debug("Skipping unsupported content type in AG-UI emitter: %s", content_type)
return []
4 changes: 2 additions & 2 deletions python/packages/ag-ui/agent_framework_ag_ui/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
"system": "system",
}

ALLOWED_AGUI_ROLES: set[str] = {"user", "assistant", "system", "tool"}
ALLOWED_AGUI_ROLES: set[str] = {"user", "assistant", "system", "tool", "reasoning"}


def generate_event_id() -> str:
Expand Down Expand Up @@ -82,7 +82,7 @@ def normalize_agui_role(raw_role: Any) -> str:
raw_role: Raw role value from AG-UI message

Returns:
Normalized role string (user, assistant, system, or tool)
Normalized role string (user, assistant, system, tool, or reasoning)
"""
if not isinstance(raw_role, str):
return "user"
Expand Down
91 changes: 91 additions & 0 deletions python/packages/ag-ui/tests/ag_ui/test_message_adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -1669,3 +1669,94 @@ def test_agui_fresh_approval_is_still_processed():
assert len(approval_contents) == 1, "Fresh approval should produce function_approval_response"
assert approval_contents[0].approved is True
assert approval_contents[0].function_call.name == "get_datetime"


class TestReasoningRoundTrip:
"""Tests for reasoning message handling in inbound/outbound adapters."""

def test_reasoning_skipped_on_inbound(self):
"""Reasoning messages from prior snapshot are not forwarded to the LLM."""
messages_input = [
{"id": "u1", "role": "user", "content": "Hello"},
{"id": "r1", "role": "reasoning", "content": "Thinking..."},
{"id": "a1", "role": "assistant", "content": "Hi there"},
]

result = agui_messages_to_agent_framework(messages_input)

roles = [m.role if hasattr(m.role, "value") else str(m.role) for m in result]
assert "reasoning" not in roles
assert len(result) == 2

def test_reasoning_preserved_in_snapshot_format(self):
"""Reasoning messages retain their role through snapshot normalization."""
messages_input = [
{"id": "u1", "role": "user", "content": "Hello"},
{"id": "r1", "role": "reasoning", "content": "Thinking about this..."},
{"id": "a1", "role": "assistant", "content": "Answer"},
]

result = agui_messages_to_snapshot_format(messages_input)

reasoning_msgs = [m for m in result if m.get("role") == "reasoning"]
assert len(reasoning_msgs) == 1
assert reasoning_msgs[0]["content"] == "Thinking about this..."

def test_reasoning_with_encrypted_value_in_snapshot_format(self):
"""Reasoning with encryptedValue passes through snapshot normalization."""
messages_input = [
{
"id": "r1",
"role": "reasoning",
"content": "visible",
"encryptedValue": "secret-data",
},
]

result = agui_messages_to_snapshot_format(messages_input)

assert len(result) == 1
assert result[0]["role"] == "reasoning"
assert result[0]["encryptedValue"] == "secret-data"

def test_reasoning_encrypted_value_snake_case_normalized(self):
"""Snake-case encrypted_value is normalized to encryptedValue in snapshot format."""
messages_input = [
{
"id": "r1",
"role": "reasoning",
"content": "visible",
"encrypted_value": "snake-case-data",
},
]

result = agui_messages_to_snapshot_format(messages_input)

assert len(result) == 1
assert result[0]["encryptedValue"] == "snake-case-data"
assert "encrypted_value" not in result[0]

def test_multi_turn_with_reasoning_in_prior_snapshot(self):
"""Second turn with reasoning from prior snapshot does not corrupt messages."""
messages_input = [
{"id": "u1", "role": "user", "content": "First question"},
{"id": "r1", "role": "reasoning", "content": "Prior reasoning"},
{"id": "a1", "role": "assistant", "content": "First answer"},
{"id": "u2", "role": "user", "content": "Follow-up question"},
]

result = agui_messages_to_agent_framework(messages_input)

roles = [m.role if hasattr(m.role, "value") else str(m.role) for m in result]
# Reasoning is filtered out, other messages preserved in order
assert roles == ["user", "assistant", "user"]
# Content not corrupted
texts = []
for m in result:
for c in m.contents or []:
if hasattr(c, "text") and c.text:
texts.append(c.text)
assert "First question" in texts
assert "First answer" in texts
assert "Follow-up question" in texts
assert "Prior reasoning" not in texts
155 changes: 155 additions & 0 deletions python/packages/ag-ui/tests/ag_ui/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -1346,3 +1346,158 @@ def test_routes_text_reasoning(self):

assert len(events) == 5
assert isinstance(events[0], ReasoningStartEvent)


class TestReasoningInSnapshot:
"""Tests for reasoning message inclusion in MESSAGES_SNAPSHOT."""

def test_reasoning_persisted_to_flow_state(self):
"""_emit_text_reasoning with flow persists reasoning into flow.reasoning_messages."""
flow = FlowState()
content = Content.from_text_reasoning(
id="reason_persist",
text="Let me think step by step.",
)

_emit_text_reasoning(content, flow)

assert len(flow.reasoning_messages) == 1
assert flow.reasoning_messages[0]["id"] == "reason_persist"
assert flow.reasoning_messages[0]["role"] == "reasoning"
assert flow.reasoning_messages[0]["content"] == "Let me think step by step."
assert "encryptedValue" not in flow.reasoning_messages[0]

def test_reasoning_with_encrypted_value_persisted(self):
"""Reasoning with protected_data preserves encryptedValue in flow state."""
flow = FlowState()
content = Content.from_text_reasoning(
id="reason_enc",
text="visible reasoning",
protected_data="encrypted-data-123",
)

_emit_text_reasoning(content, flow)

assert len(flow.reasoning_messages) == 1
assert flow.reasoning_messages[0]["encryptedValue"] == "encrypted-data-123"

def test_snapshot_includes_reasoning(self):
"""_build_messages_snapshot includes reasoning messages from flow state."""
from agent_framework_ag_ui._agent_run import _build_messages_snapshot

flow = FlowState()
flow.accumulated_text = "Here is my answer."
flow.reasoning_messages = [
{"id": "r1", "role": "reasoning", "content": "Thinking..."},
]

snapshot = _build_messages_snapshot(flow, [])

roles = [m.get("role") if isinstance(m, dict) else getattr(m, "role", None) for m in snapshot.messages]
assert "reasoning" in roles

def test_snapshot_preserves_reasoning_encrypted_value(self):
"""Snapshot reasoning with encryptedValue is preserved end-to-end."""
from agent_framework_ag_ui._agent_run import _build_messages_snapshot

flow = FlowState()
content = Content.from_text_reasoning(
id="reason_e2e",
text="visible",
protected_data="secret-data",
)
_emit_text_reasoning(content, flow)

text_content = Content.from_text("Final answer.")
_emit_text(text_content, flow)

snapshot = _build_messages_snapshot(flow, [])

reasoning_msgs = [
m
for m in snapshot.messages
if (m.get("role") if isinstance(m, dict) else getattr(m, "role", None)) == "reasoning"
]
assert len(reasoning_msgs) == 1
msg = reasoning_msgs[0]
if isinstance(msg, dict):
assert msg["content"] == "visible"
assert msg["encryptedValue"] == "secret-data"

def test_emit_content_routes_reasoning_with_flow(self):
"""_emit_content passes flow to _emit_text_reasoning for persistence."""
flow = FlowState()
content = Content.from_text_reasoning(text="routed reasoning")

_emit_content(content, flow)

assert len(flow.reasoning_messages) == 1
assert flow.reasoning_messages[0]["content"] == "routed reasoning"

def test_reasoning_without_flow_does_not_error(self):
"""Calling _emit_text_reasoning without flow still works (backward compat)."""
content = Content.from_text_reasoning(text="no flow")

events = _emit_text_reasoning(content)

assert len(events) == 5
assert isinstance(events[0], ReasoningStartEvent)

def test_snapshot_reasoning_ordering(self):
"""Reasoning messages appear after assistant text in snapshot."""
from agent_framework_ag_ui._agent_run import _build_messages_snapshot

flow = FlowState()
reasoning_content = Content.from_text_reasoning(id="r1", text="Thinking...")
_emit_text_reasoning(reasoning_content, flow)

text_content = Content.from_text("Answer")
_emit_text(text_content, flow)

snapshot = _build_messages_snapshot(flow, [{"id": "u1", "role": "user", "content": "Hi"}])

# user -> assistant text -> reasoning
assert len(snapshot.messages) == 3
roles = [m.get("role") if isinstance(m, dict) else getattr(m, "role", None) for m in snapshot.messages]
assert roles == ["user", "assistant", "reasoning"]

def test_reasoning_accumulates_incremental_deltas(self):
"""Multiple reasoning deltas with the same id accumulate into one entry."""
flow = FlowState()
content1 = Content.from_text_reasoning(id="reason_inc", text="First ")
content2 = Content.from_text_reasoning(id="reason_inc", text="second ")
content3 = Content.from_text_reasoning(id="reason_inc", text="third.")

_emit_text_reasoning(content1, flow)
_emit_text_reasoning(content2, flow)
_emit_text_reasoning(content3, flow)

assert len(flow.reasoning_messages) == 1
assert flow.reasoning_messages[0]["id"] == "reason_inc"
assert flow.reasoning_messages[0]["content"] == "First second third."

def test_reasoning_accumulates_distinct_message_ids(self):
"""Reasoning entries with different ids are stored separately."""
flow = FlowState()
content_a = Content.from_text_reasoning(id="a", text="alpha")
content_b = Content.from_text_reasoning(id="b", text="beta")

_emit_text_reasoning(content_a, flow)
_emit_text_reasoning(content_b, flow)

assert len(flow.reasoning_messages) == 2
assert flow.reasoning_messages[0]["content"] == "alpha"
assert flow.reasoning_messages[1]["content"] == "beta"

def test_reasoning_encrypted_value_updated_on_later_delta(self):
"""encryptedValue is set even when it arrives with a later delta."""
flow = FlowState()
content1 = Content.from_text_reasoning(id="enc_late", text="part1 ")
content2 = Content.from_text_reasoning(id="enc_late", text="part2", protected_data="encrypted-payload")

_emit_text_reasoning(content1, flow)
_emit_text_reasoning(content2, flow)

assert len(flow.reasoning_messages) == 1
assert flow.reasoning_messages[0]["content"] == "part1 part2"
assert flow.reasoning_messages[0]["encryptedValue"] == "encrypted-payload"
1 change: 1 addition & 0 deletions python/packages/ag-ui/tests/ag_ui/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,7 @@ def test_normalize_agui_role_valid():
assert normalize_agui_role("assistant") == "assistant"
assert normalize_agui_role("system") == "system"
assert normalize_agui_role("tool") == "tool"
assert normalize_agui_role("reasoning") == "reasoning"


def test_normalize_agui_role_invalid():
Expand Down
Loading