Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 27 additions & 22 deletions src/google/adk/flows/llm_flows/request_confirmation.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,27 @@ async def run_async(
if not events:
return

request_confirmation_function_responses = (
dict()
) # {function call id, tool confirmation}
request_confirmation_function_responses = dict()

confirmation_event_index = -1

# Helper to unwrap redundant response envelopes and decode the innermost JSON.
def _parse_tool_confirmation_payload(payload: 'Any') -> 'Any':
while (
isinstance(payload, dict)
and len(payload) == 1
and 'response' in payload
):
payload = payload['response']
if isinstance(payload, str):
try:
payload = json.loads(payload)
except json.JSONDecodeError as exc:
raise ValueError(
'Failed to decode tool confirmation payload.'
) from exc
return payload

for k in range(len(events) - 1, -1, -1):
event = events[k]
# Find the first event authored by user
Expand All @@ -71,22 +87,13 @@ async def run_async(
if function_response.name != REQUEST_CONFIRMATION_FUNCTION_CALL_NAME:
continue

# Find the FunctionResponse event that contains the user provided tool
# confirmation
if (
confirmation_payload = _parse_tool_confirmation_payload(
function_response.response
and len(function_response.response.values()) == 1
and 'response' in function_response.response.keys()
):
# ADK web client will send a request that is always encapsulated in a
# 'response' key.
tool_confirmation = ToolConfirmation.model_validate(
json.loads(function_response.response['response'])
)
else:
tool_confirmation = ToolConfirmation.model_validate(
function_response.response
)
)

tool_confirmation = ToolConfirmation.model_validate(
confirmation_payload
)
request_confirmation_function_responses[function_response.id] = (
tool_confirmation
)
Expand All @@ -104,10 +111,8 @@ async def run_async(
if not function_calls:
continue

tools_to_resume_with_confirmation = (
dict()
) # {Function call id, tool confirmation}
tools_to_resume_with_args = dict() # {Function call id, function calls}
tools_to_resume_with_confirmation = dict()
tools_to_resume_with_args = dict()

for function_call in function_calls:
if (
Expand Down
98 changes: 97 additions & 1 deletion tests/unittests/flows/llm_flows/test_request_confirmation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import json
from unittest.mock import patch

from google.adk.agents.llm_agent import LlmAgent
Expand Down Expand Up @@ -210,6 +209,103 @@ async def test_request_confirmation_processor_success():
) # tool_confirmation_dict


@pytest.mark.asyncio
async def test_request_confirmation_processor_doubly_wrapped_response():
"""Test confirmation parsing when responses are nested under multiple keys."""
agent = LlmAgent(name="test_agent", tools=[mock_tool])
invocation_context = await testing_utils.create_invocation_context(
agent=agent
)
llm_request = LlmRequest()

original_function_call = types.FunctionCall(
name=MOCK_TOOL_NAME, args={"param1": "test"}, id=MOCK_FUNCTION_CALL_ID
)

tool_confirmation = ToolConfirmation(confirmed=False, hint="test hint")
tool_confirmation_args = {
"originalFunctionCall": original_function_call.model_dump(
exclude_none=True, by_alias=True
),
"toolConfirmation": tool_confirmation.model_dump(
by_alias=True, exclude_none=True
),
}

invocation_context.session.events.append(
Event(
author="agent",
content=types.Content(
parts=[
types.Part(
function_call=types.FunctionCall(
name=functions.REQUEST_CONFIRMATION_FUNCTION_CALL_NAME,
args=tool_confirmation_args,
id=MOCK_CONFIRMATION_FUNCTION_CALL_ID,
)
)
]
),
)
)

user_confirmation = ToolConfirmation(confirmed=True)
invocation_context.session.events.append(
Event(
author="user",
content=types.Content(
parts=[
types.Part(
function_response=types.FunctionResponse(
name=functions.REQUEST_CONFIRMATION_FUNCTION_CALL_NAME,
id=MOCK_CONFIRMATION_FUNCTION_CALL_ID,
response={
"response": {
"response": user_confirmation.model_dump_json()
}
},
)
)
]
),
)
)

expected_event = Event(
author="agent",
content=types.Content(
parts=[
types.Part(
function_response=types.FunctionResponse(
name=MOCK_TOOL_NAME,
id=MOCK_FUNCTION_CALL_ID,
response={"result": "Mock tool result with test"},
)
)
]
),
)

with patch(
"google.adk.flows.llm_flows.functions.handle_function_call_list_async"
) as mock_handle_function_call_list_async:
mock_handle_function_call_list_async.return_value = expected_event

events = []
async for event in request_processor.run_async(
invocation_context, llm_request
):
events.append(event)

assert len(events) == 1
assert events[0] == expected_event

args, _ = mock_handle_function_call_list_async.call_args
assert (
args[4][MOCK_FUNCTION_CALL_ID] == user_confirmation
) # tool_confirmation_dict


@pytest.mark.asyncio
async def test_request_confirmation_processor_tool_not_confirmed():
"""Test when the tool execution is not confirmed by the user."""
Expand Down
Loading