Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 34 additions & 27 deletions src/google/adk/agents/agent_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,53 +14,60 @@

from __future__ import annotations

from typing import Annotated
from typing import Any
from typing import get_args
from typing import Union

from pydantic import Discriminator
from pydantic import RootModel
from pydantic import Tag

from ..utils.feature_decorator import experimental
from .base_agent import BaseAgentConfig
from .base_agent_config import BaseAgentConfig
from .llm_agent_config import LlmAgentConfig
from .loop_agent_config import LoopAgentConfig
from .parallel_agent import ParallelAgentConfig
from .sequential_agent import SequentialAgentConfig
from .parallel_agent_config import ParallelAgentConfig
from .sequential_agent_config import SequentialAgentConfig

# A discriminated union of all possible agent configurations.
ConfigsUnion = Union[
LlmAgentConfig,
LoopAgentConfig,
ParallelAgentConfig,
SequentialAgentConfig,
BaseAgentConfig,
]
_ADK_AGENT_CLASSES: set[str] = {
"LlmAgent",
"LoopAgent",
"ParallelAgent",
"SequentialAgent",
}


def agent_config_discriminator(v: Any):
def agent_config_discriminator(v: Any) -> str:
"""Discriminator function that returns the tag name for Pydantic."""
if isinstance(v, dict):
agent_class = v.get("agent_class", "LlmAgent")
if agent_class in [
"LlmAgent",
"LoopAgent",
"ParallelAgent",
"SequentialAgent",
]:
agent_class: str = v.get("agent_class", "LlmAgent")

# Look up the agent_class in our dynamically built mapping
if agent_class in _ADK_AGENT_CLASSES:
return agent_class
else:
return "BaseAgent"

# For non ADK agent classes, use BaseAgent to handle it.
return "BaseAgent"

raise ValueError(f"Invalid agent config: {v}")


# A discriminated union of all possible agent configurations.
ConfigsUnion = Annotated[
Union[
Annotated[LlmAgentConfig, Tag("LlmAgent")],
Annotated[LoopAgentConfig, Tag("LoopAgent")],
Annotated[ParallelAgentConfig, Tag("ParallelAgent")],
Annotated[SequentialAgentConfig, Tag("SequentialAgent")],
Annotated[BaseAgentConfig, Tag("BaseAgent")],
],
Discriminator(agent_config_discriminator),
]


# Use a RootModel to represent the agent directly at the top level.
# The `discriminator` is applied to the union within the RootModel.
@experimental
class AgentConfig(RootModel[ConfigsUnion]):
"""The config for the YAML schema to create an agent."""

class Config:
# Pydantic v2 requires this for discriminated unions on RootModel
# This tells the model to look at the 'agent_class' field of the input
# data to decide which model from the `ConfigsUnion` to use.
discriminator = Discriminator(agent_config_discriminator)
88 changes: 57 additions & 31 deletions src/google/adk/agents/config_schemas/AgentConfig.json
Original file line number Diff line number Diff line change
Expand Up @@ -681,12 +681,29 @@
"EnterpriseWebSearch": {
"additionalProperties": false,
"description": "Tool to search public web data, powered by Vertex AI Search and Sec4 compliance.",
"properties": {},
"properties": {
"excludeDomains": {
"anyOf": [
{
"items": {
"type": "string"
},
"type": "array"
},
{
"type": "null"
}
],
"default": null,
"description": "Optional. List of domains to be excluded from the search results. The default limit is 2000 domains.",
"title": "Excludedomains"
}
},
"title": "EnterpriseWebSearch",
"type": "object"
},
"Environment": {
"description": "Required. The environment being operated.",
"description": "The environment being operated.",
"enum": [
"ENVIRONMENT_UNSPECIFIED",
"ENVIRONMENT_BROWSER"
Expand Down Expand Up @@ -1475,31 +1492,31 @@
{
"$ref": "#/$defs/Content"
},
{
"type": "string"
},
{
"$ref": "#/$defs/File"
},
{
"$ref": "#/$defs/Part"
},
{
"items": {
"anyOf": [
{
"$ref": "#/$defs/File"
"type": "string"
},
{
"$ref": "#/$defs/Part"
"$ref": "#/$defs/File"
},
{
"type": "string"
"$ref": "#/$defs/Part"
}
]
},
"type": "array"
},
{
"$ref": "#/$defs/File"
},
{
"$ref": "#/$defs/Part"
},
{
"type": "string"
},
{
"type": "null"
}
Expand Down Expand Up @@ -1830,10 +1847,10 @@
"speechConfig": {
"anyOf": [
{
"$ref": "#/$defs/SpeechConfig"
"type": "string"
},
{
"type": "string"
"$ref": "#/$defs/SpeechConfig"
},
{
"type": "null"
Expand Down Expand Up @@ -1999,6 +2016,22 @@
],
"default": null,
"description": "Optional. Filter search results to a specific time range.\n If customers set a start time, they must set an end time (and vice versa).\n "
},
"excludeDomains": {
"anyOf": [
{
"items": {
"type": "string"
},
"type": "array"
},
{
"type": "null"
}
],
"default": null,
"description": "Optional. List of domains to be excluded from the search results.\n The default limit is 2000 domains.",
"title": "Excludedomains"
}
},
"title": "GoogleSearch",
Expand Down Expand Up @@ -2356,10 +2389,6 @@
"agent_class": {
"default": "LlmAgent",
"description": "The value is used to uniquely identify the LlmAgent class. If it is empty, it is by default an LlmAgent.",
"enum": [
"LlmAgent",
""
],
"title": "Agent Class",
"type": "string"
},
Expand Down Expand Up @@ -2618,7 +2647,6 @@
"description": "The config for the YAML schema of a LoopAgent.",
"properties": {
"agent_class": {
"const": "LoopAgent",
"default": "LoopAgent",
"description": "The value is used to uniquely identify the LoopAgent class.",
"title": "Agent Class",
Expand Down Expand Up @@ -2774,7 +2802,6 @@
"description": "The config for the YAML schema of a ParallelAgent.",
"properties": {
"agent_class": {
"const": "ParallelAgent",
"default": "ParallelAgent",
"description": "The value is used to uniquely identify the ParallelAgent class.",
"title": "Agent Class",
Expand Down Expand Up @@ -3680,7 +3707,6 @@
"description": "The config for the YAML schema of a SequentialAgent.",
"properties": {
"agent_class": {
"const": "SequentialAgent",
"default": "SequentialAgent",
"description": "The value is used to uniquely identify the SequentialAgent class.",
"title": "Agent Class",
Expand Down Expand Up @@ -4413,29 +4439,29 @@
"default": null,
"description": "Optional. Tool to support URL context retrieval."
},
"codeExecution": {
"computerUse": {
"anyOf": [
{
"$ref": "#/$defs/ToolCodeExecution"
"$ref": "#/$defs/ToolComputerUse"
},
{
"type": "null"
}
],
"default": null,
"description": "Optional. CodeExecution tool type. Enables the model to execute code as part of generation."
"description": "Optional. Tool to support the model interacting directly with the\n computer. If enabled, it automatically populates computer-use specific\n Function Declarations."
},
"computerUse": {
"codeExecution": {
"anyOf": [
{
"$ref": "#/$defs/ToolComputerUse"
"$ref": "#/$defs/ToolCodeExecution"
},
{
"type": "null"
}
],
"default": null,
"description": "Optional. Tool to support the model interacting directly with the computer. If enabled, it automatically populates computer-use specific Function Declarations."
"description": "Optional. CodeExecution tool type. Enables the model to execute code as part of generation."
}
},
"title": "Tool",
Expand Down Expand Up @@ -4556,7 +4582,8 @@
"type": "object"
}
},
"anyOf": [
"description": "The config for the YAML schema to create an agent.",
"oneOf": [
{
"$ref": "#/$defs/LlmAgentConfig"
},
Expand All @@ -4573,6 +4600,5 @@
"$ref": "#/$defs/BaseAgentConfig"
}
],
"description": "The config for the YAML schema to create an agent.",
"title": "AgentConfig"
}
2 changes: 1 addition & 1 deletion src/google/adk/agents/llm_agent_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class LlmAgentConfig(BaseAgentConfig):
extra='forbid',
)

agent_class: Literal['LlmAgent', ''] = Field(
agent_class: str = Field(
default='LlmAgent',
description=(
'The value is used to uniquely identify the LlmAgent class. If it is'
Expand Down
3 changes: 1 addition & 2 deletions src/google/adk/agents/loop_agent_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

from __future__ import annotations

from typing import Literal
from typing import Optional

from pydantic import ConfigDict
Expand All @@ -34,7 +33,7 @@ class LoopAgentConfig(BaseAgentConfig):
extra='forbid',
)

agent_class: Literal['LoopAgent'] = Field(
agent_class: str = Field(
default='LoopAgent',
description='The value is used to uniquely identify the LoopAgent class.',
)
Expand Down
10 changes: 4 additions & 6 deletions src/google/adk/agents/parallel_agent_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@

from __future__ import annotations

from typing import Literal

from pydantic import ConfigDict
from pydantic import Field

Expand All @@ -30,12 +28,12 @@ class ParallelAgentConfig(BaseAgentConfig):
"""The config for the YAML schema of a ParallelAgent."""

model_config = ConfigDict(
extra='forbid',
extra="forbid",
)

agent_class: Literal['ParallelAgent'] = Field(
default='ParallelAgent',
agent_class: str = Field(
default="ParallelAgent",
description=(
'The value is used to uniquely identify the ParallelAgent class.'
"The value is used to uniquely identify the ParallelAgent class."
),
)
10 changes: 4 additions & 6 deletions src/google/adk/agents/sequential_agent_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@

from __future__ import annotations

from typing import Literal

from pydantic import ConfigDict
from pydantic import Field

Expand All @@ -30,12 +28,12 @@ class SequentialAgentConfig(BaseAgentConfig):
"""The config for the YAML schema of a SequentialAgent."""

model_config = ConfigDict(
extra='forbid',
extra="forbid",
)

agent_class: Literal['SequentialAgent'] = Field(
default='SequentialAgent',
agent_class: str = Field(
default="SequentialAgent",
description=(
'The value is used to uniquely identify the SequentialAgent class.'
"The value is used to uniquely identify the SequentialAgent class."
),
)
Loading
Loading