32 changes: 9 additions & 23 deletions src/google/adk/agents/sequential_agent.py
@@ -14,6 +14,7 @@

"""Sequential agent implementation."""


from __future__ import annotations

import logging
@@ -121,38 +122,23 @@ async def _run_live_impl(
) -> AsyncGenerator[Event, None]:
"""Implementation for live SequentialAgent.

    Compared to the non-live case, live agents process a continuous stream of
    audio or video, so there is no way to tell whether a sub-agent has finished
    and control should pass to the next agent. So we introduce a task_completed()
    function that the model can call to signal it has finished the task, letting
    the run move on to the next agent.
In a live run, this agent executes its sub-agents one by one. It relies
on the `generation_complete` event from the underlying model to determine
when a sub-agent has finished its turn. Once a sub-agent's `run_live`
stream concludes (triggered by the `generation_complete` event), the
`SequentialAgent` will proceed to execute the next sub-agent in the
sequence.

Args:
ctx: The invocation context of the agent.
"""
if not self.sub_agents:
return

    # There is no way to know whether live mode is used during the init phase,
    # so we have to set this up here.
for sub_agent in self.sub_agents:
# add tool
def task_completed():
"""
Signals that the agent has successfully completed the user's question
or task.
"""
return 'Task completion signaled.'

if isinstance(sub_agent, LlmAgent):
# Use function name to dedupe.
if task_completed.__name__ not in sub_agent.tools:
sub_agent.tools.append(task_completed)
sub_agent.instruction += f"""If you finished the user's request
according to its description, call the {task_completed.__name__} function
to exit so the next agents can take over. When calling this function,
do not generate any text other than the function call."""

for sub_agent in self.sub_agents:
async with Aclosing(sub_agent.run_live(ctx)) as agen:
async for event in agen:
yield event
if event.generation_complete:
break
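For context, here is a minimal, self-contained sketch of the control flow described in the new docstring: each sub-agent's live stream is consumed until an event reports `generation_complete`, and then the next sub-agent runs. The classes and names below are illustrative stand-ins, not the actual ADK `SequentialAgent`/`LlmAgent` API.

```python
# Sketch only: stand-in types, not the ADK API.
import asyncio
from dataclasses import dataclass
from typing import AsyncGenerator, List


@dataclass
class FakeEvent:
    text: str
    generation_complete: bool = False


async def fake_sub_agent(name: str) -> AsyncGenerator[FakeEvent, None]:
    """Stand-in for sub_agent.run_live(ctx): streams events, then flags completion."""
    yield FakeEvent(f"{name}: partial output")
    yield FakeEvent(f"{name}: final output", generation_complete=True)


async def run_live_sequentially(names: List[str]) -> AsyncGenerator[FakeEvent, None]:
    """Mirrors the new loop: leave a sub-agent once its generation completes."""
    for name in names:
        async for event in fake_sub_agent(name):
            yield event
            if event.generation_complete:
                break  # hand control to the next sub-agent


async def main() -> None:
    async for event in run_live_sequentially(["researcher", "writer"]):
        print(event.text)


asyncio.run(main())
```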
191 changes: 126 additions & 65 deletions src/google/adk/flows/llm_flows/base_llm_flow.py
@@ -18,6 +18,7 @@
import asyncio
import datetime
import inspect
import logging
import time
from typing import AsyncGenerator
from typing import cast
@@ -137,64 +138,77 @@ async def run_live(
invocation_context, event_id, llm_request.contents
)

send_task = asyncio.create_task(
self._send_to_model(llm_connection, invocation_context)
)
event_queue = asyncio.Queue()

async def send_handler():
"""Handles sending user input and generating user text events."""
async for event in self._send_to_model(
llm_connection, invocation_context
):
await event_queue.put(event)

async def receive_handler():
"""Handles receiving model output and generating model events."""
try:
async for event in self._receive_from_model(
llm_connection,
event_id,
invocation_context,
llm_request,
):
await event_queue.put(event)
finally:
# Signal that the receiving process is complete.
await event_queue.put(None)

send_task = asyncio.create_task(send_handler())
receive_task = asyncio.create_task(receive_handler())
tasks = {send_task, receive_task}

try:
        async with Aclosing(
            self._receive_from_model(
                llm_connection,
                event_id,
                invocation_context,
                llm_request,
            )
        ) as agen:
          async for event in agen:
            # Empty event means the queue is closed.
            if not event:
              break
            logger.debug('Receive new event: %s', event)
            yield event
            # send back the function response
            if event.get_function_responses():
              logger.debug(
                  'Sending back last function response event: %s', event
              )
              invocation_context.live_request_queue.send_content(
                  event.content
              )
            if (
                event.content
                and event.content.parts
                and event.content.parts[0].function_response
                and event.content.parts[0].function_response.name
                == 'transfer_to_agent'
            ):
              await asyncio.sleep(DEFAULT_TRANSFER_AGENT_DELAY)
              # Cancel the tasks that belong to the closed connection.
              send_task.cancel()
              await llm_connection.close()
            if (
                event.content
                and event.content.parts
                and event.content.parts[0].function_response
                and event.content.parts[0].function_response.name
                == 'task_completed'
            ):
              # This is used by the sequential agent to signal the end of the current agent.
              await asyncio.sleep(DEFAULT_TASK_COMPLETION_DELAY)
              # Cancel the tasks that belong to the closed connection.
              send_task.cancel()
              return
        while True:
          # Consume events from the unified queue.
          event = await event_queue.get()
          if event is None:  # End of stream signal
            break

          logger.debug('Receive new event: %s', event)
          yield event

          # Forward function responses back to the model.
          if event.get_function_responses():
            logger.debug(
                'Sending back last function response event: %s', event
            )
            invocation_context.live_request_queue.send_content(
                event.content
            )

          if (
              event.content
              and event.content.parts
              and event.content.parts[0].function_response
              and event.content.parts[0].function_response.name
              == 'transfer_to_agent'
          ):
            await asyncio.sleep(DEFAULT_TRANSFER_AGENT_DELAY)
            # Cancel the tasks that belong to the closed connection.
            send_task.cancel()
            await llm_connection.close()
          if event.generation_complete:
            # This is used by the sequential agent to signal the end of the current agent.
            await asyncio.sleep(DEFAULT_TASK_COMPLETION_DELAY)
            # Cancel the tasks that belong to the closed connection.
            send_task.cancel()
            return
finally:
# Clean up
if not send_task.done():
send_task.cancel()
try:
await send_task
except asyncio.CancelledError:
pass
# Clean up all running tasks.
for task in tasks:
if not task.done():
task.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
except (ConnectionClosed, ConnectionClosedOK) as e:
      # When the session times out, it will just close without raising an
      # exception, so this handles the bad cases.
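The reworked run_live above funnels the send and receive coroutines through a single asyncio.Queue and uses None as an end-of-stream sentinel. Below is a stripped-down, self-contained sketch of that producer/consumer shape; the names are invented for illustration and this is not the ADK implementation.

```python
# Sketch only: two producers feed one queue; None marks end of stream.
import asyncio
from typing import AsyncGenerator


async def merged_stream() -> AsyncGenerator[str, None]:
    queue: asyncio.Queue = asyncio.Queue()

    async def send_handler() -> None:
        # Stand-in for forwarding user input and emitting user events.
        for i in range(2):
            await queue.put(f"user-{i}")

    async def receive_handler() -> None:
        # Stand-in for streaming model output; always signals completion.
        try:
            for i in range(3):
                await queue.put(f"model-{i}")
        finally:
            await queue.put(None)  # end-of-stream sentinel

    tasks = {
        asyncio.create_task(send_handler()),
        asyncio.create_task(receive_handler()),
    }
    try:
        while True:
            item = await queue.get()
            if item is None:  # receiving side is finished; stop consuming
                break
            yield item
    finally:
        # Cancel anything still running, mirroring the cleanup in the diff.
        for task in tasks:
            if not task.done():
                task.cancel()
        await asyncio.gather(*tasks, return_exceptions=True)


async def main() -> None:
    async for event in merged_stream():
        print(event)


asyncio.run(main())
```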
@@ -210,8 +224,8 @@ async def _send_to_model(
self,
llm_connection: BaseLlmConnection,
invocation_context: InvocationContext,
):
"""Sends data to model."""
) -> AsyncGenerator[Event, None]:
"""Sends data to model and yields user events for text messages."""
while True:
live_request_queue = invocation_context.live_request_queue
try:
Expand Down Expand Up @@ -263,7 +277,24 @@ async def _send_to_model(

await llm_connection.send_realtime(live_request.blob)

if live_request.content:
# If the request is a user-sent text message, create and yield an event
# so it can be saved to the session history.
if (
live_request.content
and live_request.content.parts
and live_request.content.parts[0].text
):
user_event = Event(
invocation_id=invocation_context.invocation_id,
author='user',
content=live_request.content,
timestamp=time.time(),
)
yield user_event
await llm_connection.send_content(live_request.content)
elif live_request.content:
# Handle other content types, like function responses, without creating
# a user event.
await llm_connection.send_content(live_request.content)

async def _receive_from_model(
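The updated `_send_to_model` only records a user Event when the first part of the incoming content is text; function responses and other content are forwarded without being written to history. A rough standalone sketch of that gate follows, using toy types rather than the ADK `Event`/`Content` classes.

```python
# Sketch only: toy Content/Part types to illustrate the text-message gate.
import time
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class Part:
    text: Optional[str] = None


@dataclass
class Content:
    parts: List[Part] = field(default_factory=list)


def maybe_user_event(content: Optional[Content], invocation_id: str) -> Optional[dict]:
    """Return a user-event dict only for first-part text content, else None."""
    if content and content.parts and content.parts[0].text:
        return {
            "invocation_id": invocation_id,
            "author": "user",
            "content": content,
            "timestamp": time.time(),
        }
    return None


print(maybe_user_event(Content([Part(text="hello")]), "inv-1"))  # event dict
print(maybe_user_event(Content([Part()]), "inv-1"))  # None: no text, no event
```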
@@ -383,16 +414,30 @@ async def _run_one_step_async(
events = invocation_context._get_events(
current_invocation=True, current_branch=True
)

# Long running tool calls should have been handled before this point.
    # If there are still long running tool calls, it means the agent was paused
    # earlier and its branch hasn't been resumed yet.
if (
invocation_context.is_resumable
and events
and len(events) > 1
        # TODO: here we are using the last 2 events to decide whether to pause
        # the invocation. But this is just being optimistic; we should find a
        # way to pause when the long running tool call is followed by more than
        # one text response.
and (
invocation_context.should_pause_invocation(events[-1])
or invocation_context.should_pause_invocation(events[-2])
)
):
return

if (
invocation_context.is_resumable
and events
and events[-1].get_function_calls()
):
# Long running tool calls should have been handled before this point.
      # If there are still long running tool calls, it means the agent was
      # paused earlier and its branch hasn't been resumed yet.
if invocation_context.should_pause_invocation(events[-1]):
return
model_response_event = events[-1]
async with Aclosing(
self._postprocess_handle_function_calls_async(
@@ -438,6 +483,10 @@ async def _preprocess_async(
from ...agents.llm_agent import LlmAgent

agent = invocation_context.agent
if not isinstance(agent, LlmAgent):
raise TypeError(
f'Expected agent to be an LlmAgent, but got {type(agent)}'
)

# Runs processors.
for processor in self.request_processors:
@@ -468,7 +517,7 @@ async def _preprocess_async(
tools = await _convert_tool_union_to_tools(
tool_union,
ReadonlyContext(invocation_context),
llm_request.model,
agent.model,
multiple_tools,
)
for tool in tools:
@@ -563,14 +612,26 @@ async def _postprocess_live(
and not llm_response.turn_complete
and not llm_response.input_transcription
and not llm_response.output_transcription
and not llm_response.generation_complete
):
return

# Handle transcription events ONCE per llm_response, outside the event loop
if llm_response.generation_complete:
yield Event(
id=Event.new_id(),
invocation_id=invocation_context.invocation_id,
author="model",
generation_complete = llm_response.generation_complete,
partial = llm_response.partial,
timestamp=time.time(),
)
return

if llm_response.input_transcription:
input_transcription_event = (
await self.transcription_manager.handle_input_transcription(
invocation_context, llm_response.input_transcription
invocation_context, llm_response
)
)
yield input_transcription_event
Expand All @@ -579,7 +640,7 @@ async def _postprocess_live(
if llm_response.output_transcription:
output_transcription_event = (
await self.transcription_manager.handle_output_transcription(
invocation_context, llm_response.output_transcription
invocation_context, llm_response
)
)
yield output_transcription_event
@@ -981,4 +1042,4 @@ async def _run_and_handle_error(
def __get_llm(self, invocation_context: InvocationContext) -> BaseLlm:
from ...agents.llm_agent import LlmAgent

return cast(LlmAgent, invocation_context.agent).canonical_model
return cast(LlmAgent, invocation_context.agent).canonical_model