Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 24 additions & 6 deletions src/google/adk/sessions/database_session_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
from sqlalchemy.orm import Mapped
from sqlalchemy.orm import mapped_column
from sqlalchemy.orm import relationship
from sqlalchemy.orm.attributes import flag_modified
from sqlalchemy.schema import MetaData
from sqlalchemy.types import DateTime
from sqlalchemy.types import PickleType
Expand Down Expand Up @@ -335,11 +336,20 @@ def from_event(cls, session: Session, event: Event) -> StorageEvent:
)
if event.custom_metadata:
storage_event.custom_metadata = event.custom_metadata
if event.usage_metadata:
storage_event.usage_metadata = event.usage_metadata.model_dump(
exclude_none=True, mode="json"
)
if event.citation_metadata:

if hasattr(event, "usage_metadata") and event.usage_metadata is not None:
try:
usage_meta = event.usage_metadata
if hasattr(usage_meta, "model_dump"):
storage_event.usage_metadata = usage_meta.model_dump(
exclude_none=False, mode="json"
)
except Exception as e:
logger.error(
f"[StorageEvent.from_event] Error while saving usage_metadata: {e}"
)

if hasattr(event, "citation_metadata") and event.citation_metadata:
storage_event.citation_metadata = event.citation_metadata.model_dump(
exclude_none=True, mode="json"
)
Expand Down Expand Up @@ -727,7 +737,15 @@ async def append_event(self, session: Session, event: Event) -> Event:
else:
update_time = datetime.fromtimestamp(event.timestamp)
storage_session.update_time = update_time
sql_session.add(StorageEvent.from_event(session, event))
storage_event = StorageEvent.from_event(session, event)

sql_session.add(storage_event)

    # Force SQLAlchemy to detect changes in MutableDict/DynamicJSON fields
if storage_event.usage_metadata is not None:
flag_modified(storage_event, "usage_metadata")
if storage_event.citation_metadata is not None:
flag_modified(storage_event, "citation_metadata")

await sql_session.commit()
await sql_session.refresh(storage_session)
Expand Down
73 changes: 64 additions & 9 deletions src/google/adk/tools/agent_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

from __future__ import annotations

import json
import logging
from typing import Any
from typing import TYPE_CHECKING

Expand All @@ -23,7 +25,9 @@

from . import _automatic_function_calling_util
from ..agents.common_configs import AgentRefConfig
from ..events.event import Event
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The `Event` and `Session` imports appear to be unused in this file — can they be removed?

from ..memory.in_memory_memory_service import InMemoryMemoryService
from ..sessions import Session
from ..utils.context_utils import Aclosing
from ._forwarding_artifact_service import ForwardingArtifactService
from .base_tool import BaseTool
Expand All @@ -34,6 +38,8 @@
if TYPE_CHECKING:
from ..agents.base_agent import BaseAgent

logger = logging.getLogger(__name__)


class AgentTool(BaseTool):
"""A tool that wraps an agent.
Expand Down Expand Up @@ -136,7 +142,11 @@ async def run_async(
else:
content = types.Content(
role='user',
parts=[types.Part.from_text(text=args['request'])],
parts=[
types.Part.from_text(
text=str(args) if isinstance(args, str) else json.dumps(args)
)
],
)
invocation_context = tool_context._invocation_context
parent_app_name = (
Expand All @@ -161,40 +171,85 @@ async def run_async(
state_dict = {
k: v
for k, v in tool_context.state.to_dict().items()
if not k.startswith('_adk') # Filter out adk internal states
if not k.startswith('_adk')
}
session = await runner.session_service.create_session(
sub_agent_session = await runner.session_service.create_session(
app_name=child_app_name,
user_id=tool_context._invocation_context.user_id,
state=state_dict,
)

last_content = None
# Collect all text chunks from streaming response instead of just last content
chunks: list[str] = []
sub_agent_events = []

def iter_text_parts(parts):
"""Safely iterate over parts and extract text, skipping None values."""
for p in parts or []:
if hasattr(p, 'text') and p.text is not None:
yield p.text

async with Aclosing(
runner.run_async(
user_id=session.user_id, session_id=session.id, new_message=content
user_id=sub_agent_session.user_id,
session_id=sub_agent_session.id,
new_message=content,
)
) as agen:
async for event in agen:
# Forward state delta to parent session.
if event.actions.state_delta:
tool_context.state.update(event.actions.state_delta)
# Collect text chunks from all events, not just the last one
if event.content:
last_content = event.content
chunks.extend(iter_text_parts(event.content.parts))
sub_agent_events.append(event)

if sub_agent_events and hasattr(tool_context, '_invocation_context'):
main_session = tool_context._invocation_context.session
if main_session and hasattr(
tool_context._invocation_context, 'session_service'
):
session_service = tool_context._invocation_context.session_service
parent_agent_name = (
tool_context._invocation_context.agent.name
if hasattr(tool_context._invocation_context, 'agent')
else 'root_agent'
)

for sub_event in sub_agent_events:

try:
if hasattr(sub_event, 'branch') and sub_event.branch:
event_branch = sub_event.branch
else:
event_branch = f'{parent_agent_name}.{self.agent.name}'

copied_event = sub_event.model_copy(update={'branch': event_branch})

await session_service.append_event(main_session, copied_event)
except Exception as e:
logger.warning(
"Error copying sub-agent event from '%s' to main session: %s",
self.agent.name,
e,
)

# Clean up runner resources (especially MCP sessions)
# to avoid "Attempted to exit cancel scope in a different task" errors
await runner.close()

if not last_content:
# Merge all collected chunks into final text
merged_text = ''.join(chunks)

if not merged_text:
return ''
merged_text = '\n'.join(p.text for p in last_content.parts if p.text)
if isinstance(self.agent, LlmAgent) and self.agent.output_schema:
tool_result = self.agent.output_schema.model_validate_json(
merged_text
).model_dump(exclude_none=True)
else:
tool_result = merged_text

return tool_result

@override
Expand Down
Loading
Loading