Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 83 additions & 11 deletions src/agents/realtime/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from pydantic import BaseModel
from typing_extensions import assert_never

from .._tool_identity import get_function_tool_lookup_key_for_tool
from .._tool_identity import get_function_tool_lookup_key_for_tool, get_tool_trace_name_for_tool
from ..agent import Agent
from ..exceptions import UserError
from ..handoffs import Handoff
Expand All @@ -20,6 +20,8 @@
from ..run_context import RunContextWrapper, TContext
from ..tool import DEFAULT_APPROVAL_REJECTION_MESSAGE, FunctionTool, invoke_function_tool
from ..tool_context import ToolContext
from ..tracing import Span, agent_span, function_span, handoff_span
from ..tracing.span_data import AgentSpanData
from ..util._approvals import evaluate_needs_approval_setting
from .agent import RealtimeAgent
from .config import RealtimeRunConfig, RealtimeSessionModelSettings, RealtimeUserInput
Expand Down Expand Up @@ -161,6 +163,9 @@ def __init__(
self._guardrail_tasks: set[asyncio.Task[Any]] = set()
self._tool_call_tasks: set[asyncio.Task[Any]] = set()
self._async_tool_calls: bool = bool(self._run_config.get("async_tool_calls", True))
self._current_agent_span: Span[AgentSpanData] | None = None
self._current_agent_trace_tools: list[str] | None = None
self._current_agent_trace_handoffs: list[str] | None = None

@property
def model(self) -> RealtimeModel:
Expand Down Expand Up @@ -395,6 +400,7 @@ async def on_event(self, event: RealtimeModelEvent) -> None:
elif event.type == "connection_status":
pass
elif event.type == "turn_started":
self._start_agent_span()
await self._put_event(
RealtimeAgentStartEvent(
agent=self._current_agent,
Expand All @@ -412,6 +418,7 @@ async def on_event(self, event: RealtimeModelEvent) -> None:
info=self._event_info,
)
)
self._finish_agent_span(reset_current=True)
elif event.type == "exception":
# Store the exception to be raised in __aiter__
self._stored_exception = event.exception
Expand Down Expand Up @@ -646,11 +653,20 @@ async def _handle_tool_call(
tool_arguments=event.arguments,
agent=agent,
)
result = await invoke_function_tool(
function_tool=func_tool,
context=tool_context,
arguments=event.arguments,
)
trace_tool_name = get_tool_trace_name_for_tool(func_tool) or func_tool.name
if self._tracing_enabled():
with function_span(trace_tool_name):
result = await invoke_function_tool(
function_tool=func_tool,
context=tool_context,
arguments=event.arguments,
)
else:
result = await invoke_function_tool(
function_tool=func_tool,
context=tool_context,
arguments=event.arguments,
)

await self._model.send_event(
RealtimeModelSendToolOutput(
Expand Down Expand Up @@ -681,11 +697,20 @@ async def _handle_tool_call(
)

# Execute the handoff to get the new agent
result = await handoff.on_invoke_handoff(self._context_wrapper, event.arguments)
if not isinstance(result, RealtimeAgent):
raise UserError(
f"Handoff {handoff.tool_name} returned invalid result: {type(result)}"
)
if self._tracing_enabled():
with handoff_span(from_agent=agent.name) as span:
result = await handoff.on_invoke_handoff(self._context_wrapper, event.arguments)
if not isinstance(result, RealtimeAgent):
raise UserError(
f"Handoff {handoff.tool_name} returned invalid result: {type(result)}"
)
span.span_data.to_agent = result.name
else:
result = await handoff.on_invoke_handoff(self._context_wrapper, event.arguments)
if not isinstance(result, RealtimeAgent):
raise UserError(
f"Handoff {handoff.tool_name} returned invalid result: {type(result)}"
)

# Store previous agent for event
previous_agent = agent
Expand Down Expand Up @@ -1039,11 +1064,38 @@ def _cleanup_tool_call_tasks(self) -> None:
task.cancel()
self._tool_call_tasks.clear()

def _tracing_enabled(self) -> bool:
    """Report whether this session should emit tracing spans.

    Returns:
        ``False`` when the run config holds a truthy ``tracing_disabled`` entry,
        ``True`` otherwise (including when the key is absent).
    """
    tracing_disabled = self._run_config.get("tracing_disabled", False)
    return not tracing_disabled

def _start_agent_span(self) -> None:
    """Open an agent span for the current agent and track it on the session.

    Does nothing when tracing is disabled. If a span is already being tracked, it
    is finished (with ``reset_current=True``) before the new one is created. The
    new span is built from the current agent's name and the cached trace tool and
    handoff names, started with ``mark_as_current=True``, and stored in
    ``self._current_agent_span``.
    """
    if self._tracing_enabled():
        if self._current_agent_span is not None:
            # Only one agent span is tracked at a time; close the previous one first.
            self._finish_agent_span(reset_current=True)

        turn_span = agent_span(
            name=self._current_agent.name,
            handoffs=self._current_agent_trace_handoffs,
            tools=self._current_agent_trace_tools,
            output_type="str",
        )
        turn_span.start(mark_as_current=True)
        self._current_agent_span = turn_span

def _finish_agent_span(self, *, reset_current: bool) -> None:
    """Finish the tracked agent span, if any, and stop tracking it.

    Args:
        reset_current: Forwarded unchanged to the span's ``finish`` call.

    Raises:
        Whatever the span's ``finish`` raises; the error is not swallowed.
    """
    active_span = self._current_agent_span
    if active_span is None:
        return

    try:
        active_span.finish(reset_current=reset_current)
    finally:
        # Always drop the reference, even when finish() fails. Otherwise the next
        # _start_agent_span() would see a stale span and try to finish it again.
        # NOTE(review): finish(reset_current=True) presumably restores a context
        # variable and may fail when invoked from a different context - confirm.
        self._current_agent_span = None

async def _cleanup(self) -> None:
"""Clean up all resources and mark session as closed."""
# Cancel and cleanup guardrail tasks
self._cleanup_guardrail_tasks()
self._cleanup_tool_call_tasks()
self._finish_agent_span(reset_current=False)

# Remove ourselves as a listener
self._model.remove_listener(self)
Expand Down Expand Up @@ -1076,6 +1128,17 @@ async def _get_updated_model_settings_from_agent(
updated_settings["instructions"] = instructions or ""
updated_settings["tools"] = tools or []
updated_settings["handoffs"] = handoffs or []
if agent is self._current_agent:
self._current_agent_trace_tools = [
tool_name
for tool in tools
if (tool_name := get_tool_trace_name_for_tool(tool)) is not None
]
self._current_agent_trace_handoffs = [
trace_name
for handoff in handoffs
if (trace_name := self._get_handoff_trace_name(handoff)) is not None
]

# Apply starting settings (from model config) next
if starting_settings:
Expand Down Expand Up @@ -1110,3 +1173,12 @@ async def _check_handoff_enabled(handoff_obj: Handoff[Any, RealtimeAgent[Any]])
results = await asyncio.gather(*(_check_handoff_enabled(h) for h in handoffs))
enabled = [h for h, ok in zip(handoffs, results, strict=False) if ok]
return enabled

@staticmethod
def _get_handoff_trace_name(handoff: Handoff[Any, Any]) -> str | None:
    """Pick the name used to represent a handoff in trace data.

    The handoff's ``agent_name`` is preferred and ``tool_name`` is the fallback.
    An attribute is accepted only when it is a non-empty string.

    Args:
        handoff: The handoff object to inspect; missing attributes are tolerated.

    Returns:
        The first acceptable name, or ``None`` when neither attribute qualifies.
    """
    for attribute in ("agent_name", "tool_name"):
        candidate = getattr(handoff, attribute, None)
        if isinstance(candidate, str) and candidate:
            return candidate
    return None
130 changes: 130 additions & 0 deletions tests/realtime/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,17 @@
from agents.run_context import RunContextWrapper
from agents.tool import FunctionTool
from agents.tool_context import ToolContext
from agents.tracing import trace
from tests.testing_processor import SPAN_PROCESSOR_TESTING


def _export_finished_spans() -> list[dict[str, Any]]:
    """Return the exported form of every span held by the testing span processor.

    Spans are taken in the processor's order (empty ones included); any span whose
    ``export()`` yields ``None`` is left out of the result.
    """
    return [
        exported
        for span in SPAN_PROCESSOR_TESTING.get_ordered_spans(including_empty=True)
        if (exported := span.export()) is not None
    ]


class _DummyModel(RealtimeModel):
Expand Down Expand Up @@ -127,6 +138,124 @@ async def consume():
await consumer


@pytest.mark.asyncio
async def test_realtime_turns_create_agent_spans():
    """A started and ended turn exports an agent span describing the agent."""
    dummy_model = _DummyModel()

    async def invoke_tool(_ctx: ToolContext[Any], _arguments: str) -> str:
        return "tool result"

    lookup_tool = FunctionTool(
        name="lookup",
        description="lookup",
        params_json_schema={"type": "object", "properties": {}},
        on_invoke_tool=invoke_tool,
    )
    target_agent = RealtimeAgent(name="target")
    handoff_to_target = Handoff(
        tool_name="transfer_to_target",
        tool_description="transfer",
        input_json_schema={},
        on_invoke_handoff=AsyncMock(return_value=target_agent),
        input_filter=None,
        agent_name=target_agent.name,
        is_enabled=True,
    )
    source_agent = RealtimeAgent(
        name="agent",
        tools=[lookup_tool],
        handoffs=[handoff_to_target],
    )
    session = RealtimeSession(dummy_model, source_agent, None)
    # Refresh the settings first so the session caches the trace tool/handoff names.
    await session._get_updated_model_settings_from_agent(None, source_agent)

    with trace("RealtimeAgent Test"):
        await session.on_event(RealtimeModelTurnStartedEvent())
        await session.on_event(RealtimeModelTurnEndedEvent())

    agent_span_data = next(
        exported["span_data"]
        for exported in _export_finished_spans()
        if exported["span_data"]["type"] == "agent"
    )

    assert agent_span_data["name"] == "agent"
    assert agent_span_data["tools"] == ["lookup"]
    assert agent_span_data["handoffs"] == ["target"]
    assert agent_span_data["output_type"] == "str"


@pytest.mark.asyncio
async def test_realtime_tracing_disabled_skips_agent_spans():
    """With ``tracing_disabled`` in the run config, a full turn records no spans."""
    session = RealtimeSession(
        _DummyModel(),
        RealtimeAgent(name="agent"),
        None,
        run_config={"tracing_disabled": True},
    )

    with trace("RealtimeAgent Test"):
        for turn_event in (RealtimeModelTurnStartedEvent(), RealtimeModelTurnEndedEvent()):
            await session.on_event(turn_event)

    recorded_spans = SPAN_PROCESSOR_TESTING.get_ordered_spans(including_empty=True)
    assert recorded_spans == []


@pytest.mark.asyncio
async def test_realtime_function_tool_calls_create_function_spans():
    """Handling a function tool call inside a turn exports a function span named after it."""
    dummy_model = _DummyModel()

    async def invoke_tool(_ctx: ToolContext[Any], _arguments: str) -> str:
        return "tool result"

    lookup_tool = FunctionTool(
        name="lookup",
        description="lookup",
        params_json_schema={"type": "object", "properties": {}},
        on_invoke_tool=invoke_tool,
    )
    source_agent = RealtimeAgent(name="agent", tools=[lookup_tool])
    session = RealtimeSession(
        dummy_model,
        source_agent,
        None,
        run_config={"async_tool_calls": False},
    )
    await session._get_updated_model_settings_from_agent(None, source_agent)

    tool_call = RealtimeModelToolCallEvent(
        name="lookup",
        call_id="call_1",
        arguments='{"q": "x"}',
    )
    with trace("RealtimeAgent Test"):
        await session.on_event(RealtimeModelTurnStartedEvent())
        await session._handle_tool_call(tool_call)
        await session.on_event(RealtimeModelTurnEndedEvent())

    function_span_data = next(
        exported["span_data"]
        for exported in _export_finished_spans()
        if exported["span_data"]["type"] == "function"
    )

    assert function_span_data["name"] == "lookup"


@pytest.mark.asyncio
async def test_realtime_handoffs_create_handoff_spans():
    """Handling a handoff tool call exports a handoff span with source and target agents."""
    dummy_model = _DummyModel()
    target_agent = RealtimeAgent(name="target")
    handoff_to_target = Handoff(
        tool_name="transfer_to_target",
        tool_description="transfer",
        input_json_schema={},
        on_invoke_handoff=AsyncMock(return_value=target_agent),
        input_filter=None,
        agent_name=target_agent.name,
        is_enabled=True,
    )
    source_agent = RealtimeAgent(name="agent", handoffs=[handoff_to_target])
    session = RealtimeSession(
        dummy_model,
        source_agent,
        None,
        run_config={"async_tool_calls": False},
    )
    await session._get_updated_model_settings_from_agent(None, source_agent)

    handoff_call = RealtimeModelToolCallEvent(
        name="transfer_to_target",
        call_id="call_1",
        arguments="{}",
    )
    with trace("RealtimeAgent Test"):
        await session.on_event(RealtimeModelTurnStartedEvent())
        await session._handle_tool_call(handoff_call)
        await session.on_event(RealtimeModelTurnEndedEvent())

    handoff_span_data = next(
        exported["span_data"]
        for exported in _export_finished_spans()
        if exported["span_data"]["type"] == "handoff"
    )

    assert handoff_span_data["from_agent"] == "agent"
    assert handoff_span_data["to_agent"] == "target"


@pytest.mark.asyncio
async def test_transcription_completed_adds_new_user_item():
model = _DummyModel()
Expand Down Expand Up @@ -300,6 +429,7 @@ async def close(self):
@pytest.fixture
def mock_agent():
agent = Mock(spec=RealtimeAgent)
agent.name = "agent"
agent.get_all_tools = AsyncMock(return_value=[])

type(agent).handoffs = PropertyMock(return_value=[])
Expand Down
Loading