6 changes: 6 additions & 0 deletions src/google/adk/agents/run_config.py
@@ -82,6 +82,12 @@ class RunConfig(BaseModel):
session_resumption: Optional[types.SessionResumptionConfig] = None
"""Configures session resumption mechanism. Only support transparent session resumption mode now."""

enable_saving_live_blob: bool = False
"""Saves live video and audio data to session and artifact service.

Right now, only audio is supported.
"""

max_llm_calls: int = 500
"""
A limit on the total number of llm calls for a given run.
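For context, a minimal sketch of how a caller might opt into the new flag. This assumes the usual `RunConfig`/`StreamingMode` surface from `google.adk.agents.run_config`; everything except `enable_saving_live_blob` is pre-existing ADK configuration, not part of this change:

```python
from google.adk.agents.run_config import RunConfig, StreamingMode

# Sketch only: enable persistence of live audio for a bidi (live) run.
# enable_saving_live_blob is the field added in this PR; the other fields
# are assumed to be the existing RunConfig surface.
run_config = RunConfig(
    streaming_mode=StreamingMode.BIDI,
    enable_saving_live_blob=True,  # cache and flush live audio to the artifact service
)
# Pass run_config to the runner's run_live(...) call; when the flag is off
# (the default), no audio blobs are cached or persisted.
```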
25 changes: 9 additions & 16 deletions src/google/adk/flows/llm_flows/audio_cache_manager.py
@@ -87,7 +87,7 @@ async def flush_caches(
flush_user_audio: bool = True,
flush_model_audio: bool = True,
) -> None:
"""Flush audio caches to session and artifact services.
"""Flush audio caches to artifact services.

The multimodal data is saved in the artifact service as an audio file. A
reference to the file data is added to the session as an event.
@@ -103,32 +103,30 @@ async def flush_caches(
flush_model_audio: Whether to flush the output (model) audio cache.
"""
if flush_user_audio and invocation_context.input_realtime_cache:
success = await self._flush_cache_to_services(
invocation_context.input_realtime_cache = []
logger.debug('Flushed input audio cache')
return await self._flush_cache_to_services(
invocation_context,
invocation_context.input_realtime_cache,
'input_audio',
)
if success:
invocation_context.input_realtime_cache = []
logger.debug('Flushed input audio cache')

if flush_model_audio and invocation_context.output_realtime_cache:
success = await self._flush_cache_to_services(
invocation_context.output_realtime_cache = []
logger.debug('Flushed output audio cache')
return await self._flush_cache_to_services(
invocation_context,
invocation_context.output_realtime_cache,
'output_audio',
)
if success:
invocation_context.output_realtime_cache = []
logger.debug('Flushed output audio cache')

async def _flush_cache_to_services(
self,
invocation_context: InvocationContext,
audio_cache: list[RealtimeCacheEntry],
cache_type: str,
) -> bool:
"""Flush a list of audio cache entries to session and artifact services.
"""Flush a list of audio cache entries to artifact services.

The artifact service stores the actual blob. The session stores the
reference to the stored blob.
@@ -192,19 +190,14 @@ async def _flush_cache_to_services(
timestamp=audio_cache[0].timestamp,
)

# Add to session
await invocation_context.session_service.append_event(
invocation_context.session, audio_event
)

logger.debug(
'Successfully flushed %s cache: %d chunks, %d bytes, saved as %s',
cache_type,
len(audio_cache),
len(combined_audio_data),
filename,
)
return True
return audio_event

except Exception as e:
logger.error('Failed to flush %s cache: %s', cache_type, e)
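The behavioural shift in this file is that flushing no longer writes to the session directly: `_flush_cache_to_services` builds the audio-file-reference event and returns it, and `flush_caches` propagates that return value to the caller. A hedged caller-side sketch of the new contract (`audio_cache_manager` here stands in for the flow's `self.audio_cache_manager`; this fragment would live inside the flow's async generator):

```python
# Sketch, not code from this PR: how a live flow is now expected to consume
# the flush result. Previously the manager called session_service.append_event
# itself; now it returns the event so the flow can yield it downstream.
audio_event = await audio_cache_manager.flush_caches(
    invocation_context,
    flush_user_audio=False,   # e.g. on interruption, only model audio is flushed
    flush_model_audio=True,
)
if audio_event:
    # Yielded from the flow's async generator; persisting the event (and with
    # it the artifact file reference) is handled further up the stack.
    yield audio_event
```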
145 changes: 83 additions & 62 deletions src/google/adk/flows/llm_flows/base_llm_flow.py
@@ -69,6 +69,42 @@
DEFAULT_ENABLE_CACHE_STATISTICS = False


def _get_audio_transcription_from_session(
invocation_context: InvocationContext,
) -> list[types.Content]:
"""Get audio and transcription content from session events.

Collects audio file references and transcription text from session events
to reconstruct the conversation history including multimodal content.
Args:
invocation_context: The invocation context containing session data.
Returns:
A list of Content objects containing audio files and transcriptions.
"""
contents = []

for event in invocation_context.session.events:
# Collect transcription text events
if hasattr(event, 'input_transcription') and event.input_transcription:
contents.append(
types.Content(
role='user',
parts=[types.Part.from_text(text=event.input_transcription.text)],
)
)

if hasattr(event, 'output_transcription') and event.output_transcription:
contents.append(
types.Content(
role='model',
parts=[
types.Part.from_text(text=event.output_transcription.text)
],
)
)
return contents


class BaseLlmFlow(ABC):
"""A basic flow that calls the LLM in a loop until a final response is generated.

@@ -129,25 +165,12 @@ async def run_live(
if llm_request.contents:
# Sends the conversation history to the model.
with tracer.start_as_current_span('send_data'):
if invocation_context.transcription_cache:
from . import audio_transcriber

audio_transcriber = audio_transcriber.AudioTranscriber(
init_client=True
if invocation_context.run_config.input_audio_transcription
is None
else False
)
contents = audio_transcriber.transcribe_file(invocation_context)
logger.debug('Sending history to model: %s', contents)
await llm_connection.send_history(contents)
invocation_context.transcription_cache = None
trace_send_data(invocation_context, event_id, contents)
else:
await llm_connection.send_history(llm_request.contents)
trace_send_data(
invocation_context, event_id, llm_request.contents
)
# Combine regular contents with audio/transcription from session
logger.debug('Sending history to model: %s', llm_request.contents)
await llm_connection.send_history(llm_request.contents)
trace_send_data(
invocation_context, event_id, llm_request.contents
)

send_task = asyncio.create_task(
self._send_to_model(llm_connection, invocation_context)
@@ -324,22 +347,6 @@ def get_author_for_event(llm_response):
author=get_author_for_event(llm_response),
)

# Handle transcription events ONCE per llm_response, outside the event loop
if llm_response.input_transcription:
await self.transcription_manager.handle_input_transcription(
invocation_context, llm_response.input_transcription
)

if llm_response.output_transcription:
await self.transcription_manager.handle_output_transcription(
invocation_context, llm_response.output_transcription
)

# Flush audio caches based on control events using configurable settings
await self._handle_control_event_flush(
invocation_context, llm_response
)

async with Aclosing(
self._postprocess_live(
invocation_context,
Expand All @@ -349,28 +356,11 @@ def get_author_for_event(llm_response):
)
) as agen:
async for event in agen:
if (
event.content
and event.content.parts
and event.content.parts[0].inline_data is None
and not event.partial
):
# This can be either user data or transcription data.
# when output transcription enabled, it will contain model's
# transcription.
# when input transcription enabled, it will contain user
# transcription.
if not invocation_context.transcription_cache:
invocation_context.transcription_cache = []
invocation_context.transcription_cache.append(
TranscriptionEntry(
role=event.content.role, data=event.content
)
)
# Cache output audio chunks from model responses
# TODO: support video data
if (
event.content
invocation_context.run_config.enable_saving_live_blob
and event.content
and event.content.parts
and event.content.parts[0].inline_data
and event.content.parts[0].inline_data.mime_type.startswith(
@@ -578,6 +568,36 @@ async def _postprocess_live(
):
return

# Handle transcription events ONCE per llm_response, outside the event loop
if llm_response.input_transcription:
input_transcription_event = (
await self.transcription_manager.handle_input_transcription(
invocation_context, llm_response.input_transcription
)
)
yield input_transcription_event
return

if llm_response.output_transcription:
output_transcription_event = (
await self.transcription_manager.handle_output_transcription(
invocation_context, llm_response.output_transcription
)
)
yield output_transcription_event
return

# Flush audio caches based on control events using configurable settings
if invocation_context.run_config.enable_saving_live_blob:
_handle_control_event_flush_event = (
await self._handle_control_event_flush(
invocation_context, llm_response
)
)
if _handle_control_event_flush_event:
yield _handle_control_event_flush_event
return

# Builds the event.
model_response_event = self._finalize_model_response_event(
llm_request, llm_response, model_response_event
@@ -871,33 +891,34 @@ async def _handle_control_event_flush(
invocation_context: The invocation context containing audio caches.
llm_response: The LLM response containing control event information.
"""

# Log cache statistics if enabled
if DEFAULT_ENABLE_CACHE_STATISTICS:
stats = self.audio_cache_manager.get_cache_stats(invocation_context)
logger.debug('Audio cache stats: %s', stats)

if llm_response.interrupted:
# user interrupts so the model will stop. we can flush model audio here
await self.audio_cache_manager.flush_caches(
return await self.audio_cache_manager.flush_caches(
invocation_context,
flush_user_audio=False,
flush_model_audio=True,
)
elif llm_response.turn_complete:
# turn completes so we can flush both user and model
await self.audio_cache_manager.flush_caches(
return await self.audio_cache_manager.flush_caches(
invocation_context,
flush_user_audio=True,
flush_model_audio=True,
)
elif getattr(llm_response, 'generation_complete', False):
# model generation complete so we can flush model audio
await self.audio_cache_manager.flush_caches(
return await self.audio_cache_manager.flush_caches(
invocation_context,
flush_user_audio=False,
flush_model_audio=True,
)

# Log cache statistics if enabled
if DEFAULT_ENABLE_CACHE_STATISTICS:
stats = self.audio_cache_manager.get_cache_stats(invocation_context)
logger.debug('Audio cache stats: %s', stats)

async def _run_and_handle_error(
self,
response_generator: AsyncGenerator[LlmResponse, None],
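One loose end worth noting: `_get_audio_transcription_from_session` is added at module level but not yet called in the hunks shown, and the new `# Combine regular contents with audio/transcription from session` comment in `run_live` still sends `llm_request.contents` unchanged. A hedged sketch of what that combination might look like if it were wired in (purely illustrative, not part of the diff):

```python
# Illustrative only: merge transcripts reconstructed from session events into
# the history pushed over the live connection. The diff as shown does not do
# this yet; it sends llm_request.contents as-is.
session_history = _get_audio_transcription_from_session(invocation_context)
combined_contents = (llm_request.contents or []) + session_history
await llm_connection.send_history(combined_contents)
```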