Skip to content
Open
53 changes: 37 additions & 16 deletions livekit-agents/livekit/agents/voice/agent_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@
import asyncio
import copy
import time
from collections.abc import AsyncIterable, Callable, Sequence
from collections.abc import AsyncIterable, Callable, Mapping, Sequence
from contextlib import AbstractContextManager, nullcontext
from contextvars import Token
from dataclasses import dataclass
from types import TracebackType
from typing import (
TYPE_CHECKING,
Any,
Generic,
Literal,
Protocol,
Expand All @@ -18,10 +19,13 @@
runtime_checkable,
)

from google.protobuf.json_format import ParseDict
from google.protobuf.struct_pb2 import Struct
from opentelemetry import context as otel_context, trace
from typing_extensions import TypedDict

from livekit import rtc
from livekit.protocol.agent_pb import agent_session as agent_pb

from .. import cli, inference, llm, stt, tts, utils, vad
from .._exceptions import APIError
Expand Down Expand Up @@ -376,9 +380,9 @@ def __init__(
else DEFAULT_TTS_TEXT_TRANSFORMS
),
ivr_detection=ivr_detection,
use_tts_aligned_transcript=use_tts_aligned_transcript
if is_given(use_tts_aligned_transcript)
else None,
use_tts_aligned_transcript=(
use_tts_aligned_transcript if is_given(use_tts_aligned_transcript) else None
),
aec_warmup_duration=aec_warmup_duration,
session_close_transcript_timeout=session_close_transcript_timeout,
)
Expand Down Expand Up @@ -885,7 +889,7 @@ def _close_soon(
*,
reason: CloseReason,
drain: bool = False,
error: llm.LLMError | stt.STTError | tts.TTSError | llm.RealtimeModelError | None = None,
error: (llm.LLMError | stt.STTError | tts.TTSError | llm.RealtimeModelError | None) = None,
) -> None:
if self._closing_task:
return
Expand All @@ -902,12 +906,14 @@ async def _aclose_impl(
*,
reason: CloseReason,
drain: bool = False,
error: llm.LLMError
| stt.STTError
| tts.TTSError
| llm.RealtimeModelError
| inference.InterruptionDetectionError
| None = None,
error: (
llm.LLMError
| stt.STTError
| tts.TTSError
| llm.RealtimeModelError
| inference.InterruptionDetectionError
| None
) = None,
) -> None:
if self._root_span_context:
# make `activity.drain` and `on_exit` under the root span
Expand Down Expand Up @@ -1098,7 +1104,10 @@ async def _start_ivr_detection(self, transcript: str | None = None) -> None:
self._tools.extend(self._ivr_activity.tools)
await self._ivr_activity.start()
if transcript is not None:
logger.debug("IVR detection started with transcript", extra={"transcript": transcript})
logger.debug(
"IVR detection started with transcript",
extra={"transcript": transcript},
)
self._ivr_activity._on_user_input_transcribed(
UserInputTranscribedEvent(transcript=transcript, is_final=True)
)
Expand Down Expand Up @@ -1215,6 +1224,14 @@ def interrupt(self, *, force: bool = False) -> asyncio.Future[None]:

return self._activity.interrupt(force=force)

def _emit_debug_message(self, payload: Mapping[str, Any]) -> None:
    """:meta private: internal — emit a debug/trace payload to the debugger/recorder.

    The payload is converted into a protobuf ``Struct`` and emitted as a
    ``"debug_message"`` event carrying an ``agent_pb.DebugMessage``.

    Args:
        payload: JSON-like mapping to convert. Values must be representable
            in a protobuf ``Struct``; conversion errors raised by
            ``ParseDict`` propagate to the caller.
    """
    st = Struct()
    # ParseDict only accepts a real ``dict`` when the target is a Struct, so
    # copy the top level to honour the ``Mapping`` annotation (e.g. a
    # MappingProxyType). NOTE(review): nested non-dict mappings are still
    # passed through unchanged — confirm callers only nest plain dicts.
    ParseDict(dict(payload), st)
    # super().emit bypasses AgentSession.emit's narrowed AgentEvent type;
    # debug messages ride the proto, not the Pydantic event union.
    super().emit("debug_message", agent_pb.DebugMessage(payload=st))

def clear_user_turn(self) -> None:
# clear the transcription or input audio buffer of the user turn
if self._activity is None:
Expand Down Expand Up @@ -1323,7 +1340,8 @@ async def _update_activity(
await activity.aclose()
elif previous_activity == "pause":
reuse_resources = await activity.pause(
blocked_tasks=blocked_tasks or [], new_activity=self._next_activity
blocked_tasks=blocked_tasks or [],
new_activity=self._next_activity,
)

if self._closing and new_activity == "start":
Expand All @@ -1343,17 +1361,20 @@ async def _update_activity(

run_state = self._global_run_state
handoff_item = AgentHandoff(
old_agent_id=previous_activity_v.agent.id if previous_activity_v else None,
old_agent_id=(previous_activity_v.agent.id if previous_activity_v else None),
new_agent_id=self._activity.agent.id,
)
if run_state:
run_state._agent_handoff(
item=handoff_item,
old_agent=previous_activity_v.agent if previous_activity_v else None,
old_agent=(previous_activity_v.agent if previous_activity_v else None),
new_agent=self._activity.agent,
)
self._chat_ctx.insert(handoff_item)
self.emit("conversation_item_added", ConversationItemAddedEvent(item=handoff_item))
self.emit(
"conversation_item_added",
ConversationItemAddedEvent(item=handoff_item),
)

if new_activity == "start":
await self._activity.start(reuse_resources=reuse_resources)
Expand Down
1 change: 1 addition & 0 deletions livekit-agents/livekit/agents/voice/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ async def wait_for_playout(self) -> None:
"speech_created",
"error",
"close",
"debug_message",
]

UserState = Literal["speaking", "listening", "away"]
Expand Down
5 changes: 5 additions & 0 deletions livekit-agents/livekit/agents/voice/remote_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,7 @@ def register_session(self, session: AgentSession) -> None:
session.on("session_usage_updated", self._on_session_usage_updated)
session.on("overlapping_speech", self._on_overlapping_speech)
session.on("error", self._on_error)
session.on("debug_message", self._on_debug_message)

def register_text_input(self, text_input_cb: TextInputCallback) -> None:
self._text_input_cb = text_input_cb
Expand All @@ -401,6 +402,7 @@ async def aclose(self) -> None:
self._session.off("session_usage_updated", self._on_session_usage_updated)
self._session.off("overlapping_speech", self._on_overlapping_speech)
self._session.off("error", self._on_error)
self._session.off("debug_message", self._on_debug_message)

if self._recv_task:
await utils.aio.cancel_and_wait(self._recv_task)
Expand Down Expand Up @@ -574,6 +576,9 @@ def _on_error(self, event: ErrorEvent) -> None:
)
)

def _on_debug_message(self, event: agent_pb.DebugMessage) -> None:
    """Wrap a debug message in an ``AgentSessionEvent`` and hand it to ``_send_event``."""
    session_event = agent_pb.AgentSessionEvent(debug_message=event)
    self._send_event(session_event)

async def _handle_request_safe(self, req: agent_pb.SessionRequest) -> None:
try:
await self._handle_request(req)
Expand Down
2 changes: 1 addition & 1 deletion livekit-agents/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ dependencies = [
"certifi>=2025.6.15",
"livekit==1.1.8",
"livekit-api>=1.0.7,<2",
"livekit-protocol>=1.1.8,<2",
"livekit-protocol>=1.1.10,<2",
"livekit-blingfire~=1.1,<2",
"protobuf>=3",
"pyjwt>=2.0",
Expand Down
4 changes: 2 additions & 2 deletions tests/test_session_host.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ def mock_session(self) -> MagicMock:
def test_register_session(self, transport: InMemoryTransport, mock_session: MagicMock) -> None:
host = SessionHost(transport)
host.register_session(mock_session)
assert mock_session.on.call_count == 8
assert mock_session.on.call_count == 9

@pytest.mark.asyncio
async def test_agent_state_changed(self, transport: InMemoryTransport) -> None:
Expand Down Expand Up @@ -589,4 +589,4 @@ async def test_aclose_unregisters_events(self, transport: InMemoryTransport) ->
host.register_session(session)
await host.start()
await host.aclose()
assert session.off.call_count == 8
assert session.off.call_count == 9
Loading