Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions contributing/samples/human_tool_confirmation/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from . import agent
80 changes: 80 additions & 0 deletions contributing/samples/human_tool_confirmation/agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from google.adk import Agent
from google.adk.tools.function_tool import FunctionTool
from google.adk.tools.tool_confirmation import ToolConfirmation
from google.adk.tools.tool_context import ToolContext
from google.genai import types


def reimburse(amount: int, tool_context: ToolContext) -> str:
"""Reimburse the employee for the given amount."""
return {'status': 'ok'}


def request_time_off(days: int, tool_context: ToolContext):
"""Request day off for the employee."""
if days <= 0:
return {'status': 'Invalid days to request.'}

if days <= 2:
return {
'status': 'ok',
'approved_days': days,
}

tool_confirmation = tool_context.tool_confirmation
if not tool_confirmation:
tool_context.request_confirmation(
hint=(
'Please approve or reject the tool call request_time_off() by'
' responding with a FunctionResponse with an expected'
' ToolConfirmation payload.'
),
payload={
'approved_days': 0,
},
)
return {'status': 'Manager approval is required.'}

approved_days = tool_confirmation.payload['approved_days']
approved_days = min(approved_days, days)
if approved_days == 0:
return {'status': 'The time off request is rejected.', 'approved_days': 0}
return {
'status': 'ok',
'approved_days': approved_days,
}


root_agent = Agent(
model='gemini-2.5-flash',
name='time_off_agent',
instruction="""
You are a helpful assistant that can help employees with reimbursement and time off requests.
- Use the `reimburse` tool for reimbursement requests.
- Use the `request_time_off` tool for time off requests.
- Prioritize using tools to fulfill the user's request.
- Always respond to the user with the tool results.
""",
tools=[
# Set require_confirmation to True to require user confirmation for the
# tool call. This is an easier way to get user confirmation if the tool
# just need a boolean confirmation.
FunctionTool(reimburse, require_confirmation=True),
request_time_off,
],
generate_content_config=types.GenerateContentConfig(temperature=0.1),
)
8 changes: 8 additions & 0 deletions src/google/adk/events/event_actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from __future__ import annotations

from typing import Any
from typing import Optional

from pydantic import alias_generators
Expand All @@ -22,6 +23,7 @@
from pydantic import Field

from ..auth.auth_tool import AuthConfig
from ..tools.tool_confirmation import ToolConfirmation


class EventActions(BaseModel):
Expand Down Expand Up @@ -64,3 +66,9 @@ class EventActions(BaseModel):
identify the function call.
- Values: The requested auth config.
"""

requested_tool_confirmations: dict[str, ToolConfirmation] = Field(
default_factory=dict
)
"""A dict of tool confirmation requested by this event, keyed by
function call id."""
1 change: 1 addition & 0 deletions src/google/adk/flows/llm_flows/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,4 @@
from . import functions
from . import identity
from . import instructions
from . import request_confirmation
6 changes: 6 additions & 0 deletions src/google/adk/flows/llm_flows/base_llm_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -638,6 +638,12 @@ async def _postprocess_handle_function_calls_async(
if auth_event:
yield auth_event

tool_confirmation_event = functions.generate_request_confirmation_event(
invocation_context, function_call_event, function_response_event
)
if tool_confirmation_event:
yield tool_confirmation_event

# Always yield the function response event first
yield function_response_event

Expand Down
29 changes: 19 additions & 10 deletions src/google/adk/flows/llm_flows/contents.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from ...models.llm_request import LlmRequest
from ._base_llm_processor import BaseLlmRequestProcessor
from .functions import remove_client_function_call_id
from .functions import REQUEST_CONFIRMATION_FUNCTION_CALL_NAME
from .functions import REQUEST_EUC_FUNCTION_CALL_NAME


Expand Down Expand Up @@ -238,6 +239,9 @@ def _get_contents(
if _is_auth_event(event):
# Skip auth events.
continue
if _is_request_confirmation_event(event):
# Skip request confirmation events.
continue
filtered_events.append(
_convert_foreign_event(event)
if _is_other_agent_reply(agent_name, event)
Expand Down Expand Up @@ -431,18 +435,23 @@ def _is_event_belongs_to_branch(
return invocation_branch.startswith(event.branch)


def _is_auth_event(event: Event) -> bool:
if not event.content.parts:
def _is_function_call_event(event: Event, function_name: str) -> bool:
"""Checks if an event is a function call/response for a given function name."""
if not event.content or not event.content.parts:
return False
for part in event.content.parts:
if (
part.function_call
and part.function_call.name == REQUEST_EUC_FUNCTION_CALL_NAME
):
if part.function_call and part.function_call.name == function_name:
return True
if (
part.function_response
and part.function_response.name == REQUEST_EUC_FUNCTION_CALL_NAME
):
if part.function_response and part.function_response.name == function_name:
return True
return False


def _is_auth_event(event: Event) -> bool:
"""Checks if the event is an authentication event."""
return _is_function_call_event(event, REQUEST_EUC_FUNCTION_CALL_NAME)


def _is_request_confirmation_event(event: Event) -> bool:
"""Checks if the event is a request confirmation event."""
return _is_function_call_event(event, REQUEST_CONFIRMATION_FUNCTION_CALL_NAME)
76 changes: 74 additions & 2 deletions src/google/adk/flows/llm_flows/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from ...telemetry import trace_tool_call
from ...telemetry import tracer
from ...tools.base_tool import BaseTool
from ...tools.tool_confirmation import ToolConfirmation
from ...tools.tool_context import ToolContext
from ...utils.context_utils import Aclosing

Expand All @@ -47,6 +48,7 @@

AF_FUNCTION_CALL_ID_PREFIX = 'adk-'
REQUEST_EUC_FUNCTION_CALL_NAME = 'adk_request_credential'
REQUEST_CONFIRMATION_FUNCTION_CALL_NAME = 'adk_request_confirmation'

logger = logging.getLogger('google_adk.' + __name__)

Expand Down Expand Up @@ -130,11 +132,76 @@ def generate_auth_event(
)


def generate_request_confirmation_event(
invocation_context: InvocationContext,
function_call_event: Event,
function_response_event: Event,
) -> Optional[Event]:
"""Generates a request confirmation event from a function response event."""
if not function_response_event.actions.requested_tool_confirmations:
return None
parts = []
long_running_tool_ids = set()
function_calls = function_call_event.get_function_calls()
for (
function_call_id,
tool_confirmation,
) in function_response_event.actions.requested_tool_confirmations.items():
original_function_call = next(
(fc for fc in function_calls if fc.id == function_call_id), None
)
if not original_function_call:
continue
request_confirmation_function_call = types.FunctionCall(
name=REQUEST_CONFIRMATION_FUNCTION_CALL_NAME,
args={
'originalFunctionCall': original_function_call.model_dump(
exclude_none=True, by_alias=True
),
'toolConfirmation': tool_confirmation.model_dump(
by_alias=True, exclude_none=True
),
},
)
request_confirmation_function_call.id = generate_client_function_call_id()
long_running_tool_ids.add(request_confirmation_function_call.id)
parts.append(types.Part(function_call=request_confirmation_function_call))

return Event(
invocation_id=invocation_context.invocation_id,
author=invocation_context.agent.name,
branch=invocation_context.branch,
content=types.Content(
parts=parts, role=function_response_event.content.role
),
long_running_tool_ids=long_running_tool_ids,
)


async def handle_function_calls_async(
invocation_context: InvocationContext,
function_call_event: Event,
tools_dict: dict[str, BaseTool],
filters: Optional[set[str]] = None,
tool_confirmation_dict: Optional[dict[str, ToolConfirmation]] = None,
) -> Optional[Event]:
"""Calls the functions and returns the function response event."""
function_calls = function_call_event.get_function_calls()
return await handle_function_call_list_async(
invocation_context,
function_calls,
tools_dict,
filters,
tool_confirmation_dict,
)


async def handle_function_call_list_async(
invocation_context: InvocationContext,
function_calls: list[types.FunctionCall],
tools_dict: dict[str, BaseTool],
filters: Optional[set[str]] = None,
tool_confirmation_dict: Optional[dict[str, ToolConfirmation]] = None,
) -> Optional[Event]:
"""Calls the functions and returns the function response event."""
from ...agents.llm_agent import LlmAgent
Expand All @@ -143,8 +210,6 @@ async def handle_function_calls_async(
if not isinstance(agent, LlmAgent):
return None

function_calls = function_call_event.get_function_calls()

# Filter function calls
filtered_calls = [
fc for fc in function_calls if not filters or fc.id in filters
Expand All @@ -161,6 +226,9 @@ async def handle_function_calls_async(
function_call,
tools_dict,
agent,
tool_confirmation_dict[function_call.id]
if tool_confirmation_dict
else None,
)
)
for function_call in filtered_calls
Expand Down Expand Up @@ -198,12 +266,14 @@ async def _execute_single_function_call_async(
function_call: types.FunctionCall,
tools_dict: dict[str, BaseTool],
agent: LlmAgent,
tool_confirmation: Optional[ToolConfirmation] = None,
) -> Optional[Event]:
"""Execute a single function call with thread safety for state modifications."""
tool, tool_context = _get_tool_and_context(
invocation_context,
function_call,
tools_dict,
tool_confirmation,
)

with tracer.start_as_current_span(f'execute_tool {tool.name}'):
Expand Down Expand Up @@ -567,6 +637,7 @@ def _get_tool_and_context(
invocation_context: InvocationContext,
function_call: types.FunctionCall,
tools_dict: dict[str, BaseTool],
tool_confirmation: Optional[ToolConfirmation] = None,
):
if function_call.name not in tools_dict:
raise ValueError(
Expand All @@ -576,6 +647,7 @@ def _get_tool_and_context(
tool_context = ToolContext(
invocation_context=invocation_context,
function_call_id=function_call.id,
tool_confirmation=tool_confirmation,
)

tool = tools_dict[function_call.name]
Expand Down
Loading