1 change: 0 additions & 1 deletion contributing/samples/gepa/experiment.py
@@ -43,7 +43,6 @@
from tau_bench.types import EnvRunResult
from tau_bench.types import RunConfig
import tau_bench_agent as tau_bench_agent_lib

import utils


1 change: 0 additions & 1 deletion contributing/samples/gepa/run_experiment.py
@@ -25,7 +25,6 @@
from absl import flags
import experiment
from google.genai import types

import utils

_OUTPUT_DIR = flags.DEFINE_string(
10 changes: 5 additions & 5 deletions src/google/adk/agents/llm_agent.py
@@ -35,6 +35,7 @@
from pydantic import Field
from pydantic import field_validator
from pydantic import model_validator
from pydantic import TypeAdapter
from typing_extensions import override
from typing_extensions import TypeAlias

@@ -48,6 +49,7 @@
from ..models.llm_response import LlmResponse
from ..models.registry import LLMRegistry
from ..planners.base_planner import BasePlanner
from ..tools._gemini_schema_util import validate_and_dump_schema
from ..tools.base_tool import BaseTool
from ..tools.base_toolset import BaseToolset
from ..tools.function_tool import FunctionTool
@@ -314,9 +316,9 @@ class LlmAgent(BaseAgent):
"""

# Controlled input/output configurations - Start
input_schema: Optional[type[BaseModel]] = None
input_schema: Optional[Any] = None
"""The input schema when agent is used as a tool."""
output_schema: Optional[type[BaseModel]] = None
output_schema: Optional[Any] = None
"""The output schema when agent replies.

NOTE:
@@ -833,9 +835,7 @@ def __maybe_save_output_to_state(self, event: Event):
# Do not attempt to parse it as JSON.
if not result.strip():
return
result = self.output_schema.model_validate_json(result).model_dump(
exclude_none=True
)
result = validate_and_dump_schema(self.output_schema, result)
event.actions.state_delta[self.output_key] = result

@model_validator(mode='after')
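
With input_schema and output_schema relaxed to Optional[Any] above, the schema no longer has to be a BaseModel subclass; pydantic's TypeAdapter can validate Unions, Literals, and primitives as well. A minimal sketch of that pattern (the Answer model and values here are illustrative, not part of this change):

from typing import Literal, Union
from pydantic import BaseModel, TypeAdapter

class Answer(BaseModel):
  text: str

# TypeAdapter accepts BaseModel subclasses, Unions, Literals, and primitives alike.
adapter = TypeAdapter(Union[Answer, Literal['none']])
adapter.validate_json('{"text": "hi"}')  # -> Answer(text='hi')
adapter.validate_json('"none"')          # -> 'none'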
34 changes: 0 additions & 34 deletions src/google/adk/tools/_function_parameter_parse_util.py
@@ -140,31 +140,6 @@ def _is_builtin_primitive_or_compound(
return annotation in _py_builtin_type_to_schema_type.keys()


def _raise_for_any_of_if_mldev(schema: types.Schema):
if schema.any_of:
raise ValueError(
'AnyOf is not supported in function declaration schema for Google AI.'
)


def _update_for_default_if_mldev(schema: types.Schema):
if schema.default is not None:
# TODO(kech): Remove this workaround once mldev supports default value.
schema.default = None
logger.warning(
'Default value is not supported in function declaration schema for'
' Google AI.'
)


def _raise_if_schema_unsupported(
variant: GoogleLLMVariant, schema: types.Schema
):
if variant == GoogleLLMVariant.GEMINI_API:
_raise_for_any_of_if_mldev(schema)
# _update_for_default_if_mldev(schema) # No need of this since GEMINI now supports default value


def _is_default_value_compatible(
default_value: Any, annotation: inspect.Parameter.annotation
) -> bool:
@@ -230,7 +205,6 @@ def _parse_schema_from_parameter(
raise ValueError(default_value_error_msg)
schema.default = param.default
schema.type = _py_builtin_type_to_schema_type[param.annotation]
_raise_if_schema_unsupported(variant, schema)
return schema
if isinstance(param.annotation, type) and issubclass(param.annotation, Enum):
schema.type = types.Type.STRING
@@ -244,7 +218,6 @@
if default_value not in schema.enum:
raise ValueError(default_value_error_msg)
schema.default = default_value
_raise_if_schema_unsupported(variant, schema)
return schema
if (
get_origin(param.annotation) is Union
@@ -285,7 +258,6 @@
if not _is_default_value_compatible(param.default, param.annotation):
raise ValueError(default_value_error_msg)
schema.default = param.default
_raise_if_schema_unsupported(variant, schema)
return schema
if isinstance(param.annotation, _GenericAlias) or isinstance(
param.annotation, typing_types.GenericAlias
@@ -298,7 +270,6 @@
if not _is_default_value_compatible(param.default, param.annotation):
raise ValueError(default_value_error_msg)
schema.default = param.default
_raise_if_schema_unsupported(variant, schema)
return schema
if origin is Literal:
if not all(isinstance(arg, str) for arg in args):
@@ -311,7 +282,6 @@
if not _is_default_value_compatible(param.default, param.annotation):
raise ValueError(default_value_error_msg)
schema.default = param.default
_raise_if_schema_unsupported(variant, schema)
return schema
if origin is list:
schema.type = types.Type.ARRAY
@@ -328,7 +298,6 @@
if not _is_default_value_compatible(param.default, param.annotation):
raise ValueError(default_value_error_msg)
schema.default = param.default
_raise_if_schema_unsupported(variant, schema)
return schema
if origin is Union:
schema.any_of = []
@@ -374,7 +343,6 @@
if not _is_default_value_compatible(param.default, param.annotation):
raise ValueError(default_value_error_msg)
schema.default = param.default
_raise_if_schema_unsupported(variant, schema)
return schema
# all other generic alias will be invoked in raise branch
if (
@@ -399,7 +367,6 @@
),
func_name,
)
_raise_if_schema_unsupported(variant, schema)
return schema
if inspect.isclass(param.annotation) and issubclass(
param.annotation, ToolContext
@@ -413,7 +380,6 @@
# null is not a valid type in schema, use object instead.
schema.type = types.Type.OBJECT
schema.nullable = True
_raise_if_schema_unsupported(variant, schema)
return schema
raise ValueError(
f'Failed to parse the parameter {param} of function {func_name} for'
10 changes: 10 additions & 0 deletions src/google/adk/tools/_gemini_schema_util.py
@@ -20,7 +20,9 @@

from google.genai.types import JSONSchema
from google.genai.types import Schema
from pydantic import BaseModel
from pydantic import Field
from pydantic import TypeAdapter

from ..utils.variant_utils import get_google_llm_variant

@@ -208,3 +210,11 @@ def _to_gemini_schema(openapi_schema: dict[str, Any]) -> Schema:
json_schema=_ExtendedJSONSchema.model_validate(sanitized_schema),
api_option=get_google_llm_variant(),
)


def validate_and_dump_schema(schema: Any, json_data: str) -> Any:
"""Validates json data against a schema and returns a serializable object."""
validated_result = TypeAdapter(schema).validate_json(json_data)
if isinstance(validated_result, BaseModel):
return validated_result.model_dump(exclude_none=True)
return validated_result
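
A hedged usage sketch for the new helper above (the Answer model below is a stand-in, not part of the change):

from typing import Literal, Optional, Union
from pydantic import BaseModel

class Answer(BaseModel):
  text: str
  score: Optional[float] = None

validate_and_dump_schema(Answer, '{"text": "ok"}')
# -> {'text': 'ok'}  (BaseModel results are dumped with exclude_none=True)
validate_and_dump_schema(Union[Answer, Literal['skip']], '"skip"')
# -> 'skip'  (non-BaseModel results are returned unchanged)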
21 changes: 12 additions & 9 deletions src/google/adk/tools/agent_tool.py
@@ -19,6 +19,7 @@

from google.genai import types
from pydantic import model_validator
from pydantic import TypeAdapter
from typing_extensions import override

from . import _automatic_function_calling_util
@@ -28,6 +29,7 @@
from ..memory.in_memory_memory_service import InMemoryMemoryService
from ..utils.context_utils import Aclosing
from ._forwarding_artifact_service import ForwardingArtifactService
from ._gemini_schema_util import validate_and_dump_schema
from .base_tool import BaseTool
from .tool_configs import BaseToolConfig
from .tool_configs import ToolArgsConfig
@@ -145,14 +147,15 @@ async def run_async(
tool_context.actions.skip_summarization = True

if isinstance(self.agent, LlmAgent) and self.agent.input_schema:
input_value = self.agent.input_schema.model_validate(args)
input_value = TypeAdapter(self.agent.input_schema).validate_python(args)
text = (
TypeAdapter(self.agent.input_schema)
.dump_json(input_value, exclude_none=True)
.decode('utf-8')
)
content = types.Content(
role='user',
parts=[
types.Part.from_text(
text=input_value.model_dump_json(exclude_none=True)
)
],
parts=[types.Part.from_text(text=text)],
)
else:
content = types.Content(
@@ -213,9 +216,9 @@ async def run_async(
p.text for p in last_content.parts if p.text and not p.thought
)
if isinstance(self.agent, LlmAgent) and self.agent.output_schema:
tool_result = self.agent.output_schema.model_validate_json(
merged_text
).model_dump(exclude_none=True)
tool_result = validate_and_dump_schema(
self.agent.output_schema, merged_text
)
else:
tool_result = merged_text
return tool_result
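
For reference, a minimal sketch of the TypeAdapter round-trip that run_async now performs on input args (Query is an assumed example schema, not part of this change):

from typing import Optional
from pydantic import BaseModel, TypeAdapter

class Query(BaseModel):
  question: str
  hint: Optional[str] = None

adapter = TypeAdapter(Query)
# Validate the raw tool args dict, then serialize back to JSON text for the model,
# dropping None fields.
value = adapter.validate_python({'question': 'What is ADK?'})
text = adapter.dump_json(value, exclude_none=True).decode('utf-8')
# text == '{"question":"What is ADK?"}'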
55 changes: 55 additions & 0 deletions tests/unittests/agents/test_llm_agent_fields.py
@@ -16,13 +16,17 @@

import logging
from typing import Any
from typing import cast
from typing import Literal
from typing import Optional
from typing import Union
from unittest import mock

from google.adk.agents.callback_context import CallbackContext
from google.adk.agents.invocation_context import InvocationContext
from google.adk.agents.llm_agent import LlmAgent
from google.adk.agents.readonly_context import ReadonlyContext
from google.adk.events.event import Event
from google.adk.models.anthropic_llm import Claude
from google.adk.models.google_llm import Gemini
from google.adk.models.lite_llm import LiteLlm
@@ -34,9 +38,12 @@
from google.adk.tools.google_search_tool import GoogleSearchTool
from google.adk.tools.vertex_ai_search_tool import VertexAiSearchTool
from google.genai import types
from google.genai.types import Part
from pydantic import BaseModel
import pytest

from .. import testing_utils


async def _create_readonly_context(
agent: LlmAgent, state: Optional[dict[str, Any]] = None
@@ -519,3 +526,51 @@ def test_builtin_planner_overwrite_logging(caplog):
'Overwriting `thinking_config` from `generate_content_config`'
in caplog.text
)


def test_output_schema_with_union():
"""Tests if agent can have a Union type in output_schema."""

class CustomOutput1(BaseModel):
custom_output1: str

class CustomOutput2(BaseModel):
custom_output2: str

agent = LlmAgent(
name='test_agent',
output_schema=Union[CustomOutput1, CustomOutput2, Literal['option3']],
output_key='test_output',
)

# Test with the first type
event1 = Event(
author='test_agent',
content=types.Content(
parts=[Part(text='{"custom_output1": "response1"}')]
),
)
agent._LlmAgent__maybe_save_output_to_state(event1)
assert event1.actions.state_delta['test_output'] == {
'custom_output1': 'response1'
}

# Test with the second type
event2 = Event(
author='test_agent',
content=types.Content(
parts=[Part(text='{"custom_output2": "response2"}')]
),
)
agent._LlmAgent__maybe_save_output_to_state(event2)
assert event2.actions.state_delta['test_output'] == {
'custom_output2': 'response2'
}

# Test with the literal type
event3 = Event(
author='test_agent',
content=types.Content(parts=[Part(text='"option3"')]),
)
agent._LlmAgent__maybe_save_output_to_state(event3)
assert event3.actions.state_delta['test_output'] == 'option3'